# ai_text.py
  1. import logging
  2. import os
  3. os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
  4. import time
  5. from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
  6. import torch
  7. logger = logging.getLogger(__name__)
  8. model_name = "Qwen/Qwen2.5-1.5B-Instruct"
  9. start_time = time.time()
  10. logger.info(f"⏳ 开始加载 {model_name} 模型...")
  11. try:
  12. # 1. 定义量化配置
  13. quantization_config = BitsAndBytesConfig(
  14. load_in_4bit=True,
  15. bnb_4bit_compute_dtype=torch.float16, # 1650 显卡建议设为 fp16
  16. bnb_4bit_quant_type="nf4", # 高精度量化类型
  17. bnb_4bit_use_double_quant=True # 进一步压缩显存
  18. )
  19. # 2. 加载模型
  20. model = AutoModelForCausalLM.from_pretrained(
  21. model_name,
  22. quantization_config=quantization_config, # 使用配置对象
  23. device_map="auto" # 自动分配到 GPU
  24. )
  25. tokenizer = AutoTokenizer.from_pretrained(model_name)
  26. logger.info(f"✅ {model_name} 模型加载成功!耗时 {(time.time() - start_time):.2f}秒")
  27. except Exception as e:
  28. logger.info(f"模型加载失败: {e}")
  29. logger.info(f"✅ {model_name} 模型加载失败: {e}")
  30. raise e
  31. def translate2zh(text):
  32. # 构建适合 Moondream 场景的 Prompt
  33. prompt = f"你是一个专业的图像描述翻译官。请将下面这段英文描述翻译成自然、地道的中文,直接输出结果,不要解释:\n{text}"
  34. messages = [{"role": "user", "content": prompt}]
  35. input_text = tokenizer.apply_chat_template(
  36. messages,
  37. tokenize=False,
  38. add_generation_prompt=True
  39. )
  40. model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
  41. # 推理
  42. max_tokens = 128
  43. with torch.no_grad():
  44. generated_ids = model.generate(
  45. **model_inputs,
  46. max_new_tokens=max_tokens,
  47. do_sample=False # 翻译建议关闭随机性,保证结果稳定
  48. )
  49. # 解码
  50. response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  51. # 提取助手回答的部分
  52. final_result = response.split("assistant\n")[-1].strip()
  53. return final_result
  54. def summarize():
  55. pass