|
@@ -11,6 +11,9 @@ from funasr import AutoModel
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
asr_model: Optional[AutoModel] = None
|
|
asr_model: Optional[AutoModel] = None
|
|
|
is_model_ready = False
|
|
is_model_ready = False
|
|
|
|
|
+model_name1 = 'paraformer-zh'
|
|
|
|
|
+model_name2 = 'fsmn-vad'
|
|
|
|
|
+model_name3 = 'ct-punc'
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_asr_model():
|
|
def get_asr_model():
|
|
@@ -39,13 +42,12 @@ async def init_funasr():
|
|
|
# 使用 run_in_executor 避免阻塞主事件循环
|
|
# 使用 run_in_executor 避免阻塞主事件循环
|
|
|
loop = asyncio.get_event_loop()
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
- # 定义具体的加载逻辑
|
|
|
|
|
def load():
|
|
def load():
|
|
|
return AutoModel(
|
|
return AutoModel(
|
|
|
- model="paraformer-zh",
|
|
|
|
|
- vad_model="fsmn-vad",
|
|
|
|
|
|
|
+ model=model_name1,
|
|
|
|
|
+ vad_model=model_name2,
|
|
|
vad_kwargs={"max_single_segment_time": 30000},
|
|
vad_kwargs={"max_single_segment_time": 30000},
|
|
|
- punc_model="ct-punc",
|
|
|
|
|
|
|
+ punc_model=model_name3,
|
|
|
device="cuda:0" if torch.cuda.is_available() else "cpu",
|
|
device="cuda:0" if torch.cuda.is_available() else "cpu",
|
|
|
disable_update=True
|
|
disable_update=True
|
|
|
)
|
|
)
|