ai_asr.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import logging
  2. import asyncio
  3. import re
  4. import time
  5. from typing import Optional
  6. import service.pyav as pyav
  7. import torch
  8. from funasr import AutoModel
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Shared model instance; stays None until init_funasr() assigns a loaded model.
asr_model: Optional[AutoModel] = None
# Flipped to True by init_funasr() once the model has been loaded successfully.
is_model_ready = False
  12. def get_asr_model():
  13. """获取模型实例的接口"""
  14. global asr_model
  15. return asr_model
  16. def check_ready():
  17. """检查模型是否加载完成"""
  18. global is_model_ready
  19. return is_model_ready
  20. async def init_funasr():
  21. """异步初始化函数"""
  22. global asr_model, is_model_ready
  23. if is_model_ready:
  24. return
  25. logger.info("⏳ [ASR] 开始异步加载 funasr 模型...")
  26. start_time = time.time()
  27. try:
  28. # 使用 run_in_executor 避免阻塞主事件循环
  29. loop = asyncio.get_event_loop()
  30. # 定义具体的加载逻辑
  31. def load():
  32. return AutoModel(
  33. model="paraformer-zh",
  34. vad_model="fsmn-vad",
  35. vad_kwargs={"max_single_segment_time": 30000},
  36. punc_model="ct-punc",
  37. device="cuda:0" if torch.cuda.is_available() else "cpu",
  38. disable_update=True
  39. )
  40. asr_model = await loop.run_in_executor(None, load)
  41. is_model_ready = True
  42. logger.info(f"✅ [ASR] 模型加载成功!耗时 {(time.time() - start_time):.2f}s")
  43. except Exception as e:
  44. logger.error(f"❌ [ASR] 模型加载失败: {e}")
  45. is_model_ready = False
  46. def get_text(audio_path):
  47. start_time = time.time()
  48. logger.info("⏳ 开始进行音频识别...")
  49. result = asr_model.generate(input=[audio_path], cache={}, batch_size_s=300)
  50. logger.info(f"✅ 音频识别完成, 耗时 {(time.time() - start_time):.2f}秒")
  51. # 清理文本中的空字符
  52. text = result[0]['text'].replace(" ", "")
  53. timestamps = result[0]['timestamp']
  54. return {
  55. 'text': text,
  56. 'timestamps': timestamps
  57. }
  58. def generate_srt(audio_path, srt_path):
  59. result = get_text(audio_path)
  60. text = result['text']
  61. timestamp_list = result['timestamps']
  62. # 使用正则表达式将文本按标点切分为一个 list,保留标点
  63. text_list = re.split(r"([。!?;,])", text)
  64. srt_list = pyav.get_precise_srt(text_list, timestamp_list)
  65. pyav.save_srt_file(srt_list, srt_path)