# ai_asr.py (2.3 KB)

  1. import logging
  2. import asyncio
  3. import re
  4. import time
  5. from typing import Optional
  6. import service.pyav as pyav
  7. import torch
  8. from funasr import AutoModel
  # Model identifiers handed to AutoModel in init_funasr() as the
  # ``model``, ``vad_model`` and ``punc_model`` arguments respectively.
  model_name1 = 'paraformer-zh'
  model_name2 = 'fsmn-vad'
  model_name3 = 'ct-punc'

  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)

  # Shared model instance; stays None until init_funasr() assigns a loaded model.
  asr_model: Optional[AutoModel] = None
  # Flipped to True by init_funasr() after a successful load.
  is_model_ready = False
  def get_asr_model():
      """Return the module-level ASR model instance.

      Returns:
          Whatever ``asr_model`` currently holds: ``None`` until
          ``init_funasr`` has assigned a loaded model.
      """
      # Reading a module global needs no ``global`` declaration.
      return asr_model
  def check_ready():
      """Report whether the ASR model has finished loading.

      Returns:
          The current value of the module-level ``is_model_ready`` flag.
      """
      # Reading a module global needs no ``global`` declaration.
      return is_model_ready
  async def init_funasr():
      """Load the FunASR model once, without blocking the event loop.

      If ``is_model_ready`` is already true the call returns immediately.
      Otherwise the model is constructed in the loop's default executor and
      stored in the module-level ``asr_model``; ``is_model_ready`` is set to
      ``True`` on success.

      Load failures are logged (with traceback) and leave ``is_model_ready``
      false; they are not re-raised, so callers should consult
      ``check_ready()`` afterwards.
      """
      global asr_model, is_model_ready
      if is_model_ready:
          return

      logger.info("⏳ [ASR] 开始异步加载 funasr 模型...")
      start_time = time.time()

      def load():
          """Construct the model; executed in a worker thread."""
          return AutoModel(
              model=model_name1,
              vad_model=model_name2,
              vad_kwargs={"max_single_segment_time": 30000},
              punc_model=model_name3,
              # Prefer the first CUDA device when one is available.
              device="cuda:0" if torch.cuda.is_available() else "cpu",
              disable_update=True,
          )

      try:
          # Inside a coroutine the running loop is the one to use;
          # run_in_executor(None, ...) keeps the blocking load off the loop.
          loop = asyncio.get_running_loop()
          asr_model = await loop.run_in_executor(None, load)
      except Exception as e:
          # logger.exception keeps the traceback, which the plain error
          # message alone would lose.
          logger.exception(f"❌ [ASR] 模型加载失败: {e}")
          is_model_ready = False
      else:
          is_model_ready = True
          logger.info(f"✅ [ASR] 模型加载成功!耗时 {(time.time() - start_time):.2f}s")
  def get_text(audio_path):
      """Run speech recognition on one audio input and return text + timestamps.

      Args:
          audio_path: Audio input; passed to ``asr_model.generate`` as a
              one-element list.

      Returns:
          A dict with two keys:
          ``'text'`` - the first result's text with all space characters removed;
          ``'timestamps'`` - the first result's ``'timestamp'`` entry.
          Both are empty (``''`` / ``[]``) when the model returns no result or
          the corresponding key is missing.

      Raises:
          RuntimeError: If the model has not been loaded yet.
      """
      # Fail with a clear message instead of an AttributeError on None.
      if asr_model is None:
          raise RuntimeError("ASR model is not loaded; await init_funasr() first")

      start_time = time.time()
      logger.info("⏳ 开始进行音频识别...")
      result = asr_model.generate(input=[audio_path], cache={}, batch_size_s=300)
      logger.info(f"✅ 音频识别完成, 耗时 {(time.time() - start_time):.2f}秒")

      # An empty result list would otherwise raise IndexError below.
      if not result:
          logger.warning("⚠️ 音频识别结果为空: %s", audio_path)
          return {'text': '', 'timestamps': []}

      first = result[0]
      # Strip the space characters from the recognized text.
      text = first.get('text', '').replace(" ", "")
      # Tolerate a missing 'timestamp' entry instead of raising KeyError.
      timestamps = first.get('timestamp', [])
      return {
          'text': text,
          'timestamps': timestamps,
      }
  def generate_srt(audio_path, srt_path):
      """Recognize ``audio_path`` and write the result as an SRT file.

      Args:
          audio_path: Audio input forwarded to ``get_text``.
          srt_path: Destination forwarded to ``pyav.save_srt_file``.
      """
      result = get_text(audio_path)
      text = result['text']
      timestamp_list = result['timestamps']

      # Split on sentence punctuation, keeping each mark as its own list item
      # (capturing group). Both the full-width marks found in Chinese text and
      # their ASCII forms are accepted; an ASCII-only class would leave text
      # punctuated with full-width marks unsplit.
      # NOTE(review): the list can contain empty strings (e.g. after a trailing
      # mark); presumably pyav.get_precise_srt tolerates them - confirm.
      text_list = re.split(r"([。！？；，!?;,])", text)
      srt_list = pyav.get_precise_srt(text_list, timestamp_list)
      pyav.save_srt_file(srt_list, srt_path)