# Speech-Translation.axera / ax_speech_translate_demo_qwen_api.py
import subprocess
import tempfile
import os
import json
import shutil
import time
import librosa
import torch
import argparse
import soundfile as sf
from pathlib import Path
import cn2an
import requests
import re
import numpy as np
import onnxruntime as ort
import axengine as axe
# SenseVoice (ASR) modules
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
# MeloTTS modules
from libmelotts.python.split_utils import split_sentence
from libmelotts.python.text import cleaned_text_to_sequence
from libmelotts.python.text.cleaner import clean_text
from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP
# Configuration parameters
# TTS parameters
TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
"g": "g-zh_mix_en.bin",
"encoder": "encoder-zh.onnx",
"decoder": "decoder-zh.axmodel"
}
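# Note (inferred from how these files are used below): "g" is the static speaker
# embedding loaded with np.fromfile; the ONNX encoder runs on CPU via onnxruntime,
# and the .axmodel decoder runs on the Axera NPU via axengine.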
# Qwen LLM translation API parameters
QWEN_API_URL = ""  # API server address, e.g. http://10.126.29.158:8000 (overridden by --api_url)
# TTS helper functions (ported from melotts.py)
def intersperse(lst, item):
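    # e.g. intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]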
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
"""
def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
norm_text, phone, tone, word2ph = clean_text(text, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
phone = intersperse(phone, 0)
tone = intersperse(tone, 0)
language = intersperse(language, 0)
phone = np.array(phone, dtype=np.int32)
tone = np.array(tone, dtype=np.int32)
language = np.array(language, dtype=np.int32)
word2ph = np.array(word2ph, dtype=np.int32) * 2
word2ph[0] += 1
return phone, tone, language, norm_text, word2ph
"""
# Handle phonemes that the symbol table cannot recognize
def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Fixed phoneme processing: keep all output arrays the same length."""
try:
norm_text, phone, tone, word2ph = clean_text(text, language_str)
        # Special IPA phonemes to drop (mapped to empty strings)
phone_mapping = {
'ɛ': '', 'æ': '', 'ʌ': '', 'ʊ': '', 'ɔ': '', 'ɪ': '', 'ɝ': '', 'ɚ': '', 'ɑ': '',
'ʒ': '', 'θ': '', 'ð': '', 'ŋ': '', 'ʃ': '', 'ʧ': '', 'ʤ': '', 'ː': '', 'ˈ': '',
'ˌ': '', 'ʰ': '', 'ʲ': '', 'ʷ': '', 'ʔ': '', 'ɾ': '', 'ɹ': '', 'ɫ': '', 'ɡ': '',
}
        # Filter phone and tone together so they stay the same length
processed_phone = []
processed_tone = []
removed_symbols = set()
        for p, t in zip(phone, tone):
            if p in phone_mapping:
                # Drop the special phoneme together with its tone
                removed_symbols.add(p)
            elif p in symbol_to_id:
                # Keep regular phonemes and their tones
                processed_phone.append(p)
                processed_tone.append(t)
            else:
                # Drop any other unknown phonemes as well
                removed_symbols.add(p)
        # Report dropped phonemes
        if removed_symbols:
            print(f"[Phoneme filter] Dropped {len(removed_symbols)} special phonemes: {sorted(removed_symbols)}")
            print(f"[Phoneme filter] Filtered phoneme sequence length: {len(processed_phone)}")
            print(f"[Phoneme filter] Filtered tone sequence length: {len(processed_tone)}")
        # If no valid phonemes remain, fall back to default Chinese phonemes
        if not processed_phone:
            print("[Warning] No valid phonemes; using default Chinese phonemes")
            processed_phone = ['ni', 'hao']
            processed_tone = [1, 3]  # tones must be ints so cleaned_text_to_sequence can offset them
            word2ph = [1, 1]
        # Keep word2ph consistent with the filtered phoneme sequence
        if len(processed_phone) != len(phone):
            print(f"[Warning] Phoneme sequence length changed: {len(phone)} -> {len(processed_phone)}")
            # Simple fallback: recompute word2ph as one phoneme per word
            word2ph = [1] * len(processed_phone)
phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)
phone = intersperse(phone, 0)
tone = intersperse(tone, 0)
language = intersperse(language, 0)
phone = np.array(phone, dtype=np.int32)
tone = np.array(tone, dtype=np.int32)
language = np.array(language, dtype=np.int32)
word2ph = np.array(word2ph, dtype=np.int32) * 2
word2ph[0] += 1
return phone, tone, language, norm_text, word2ph
    except Exception as e:
        print(f"[Error] Text processing failed: {e}")
        import traceback
        traceback.print_exc()
        raise
def audio_numpy_concat(segment_data_list, sr, speed=1.):
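    # Concatenate per-sentence audio, inserting ~50 ms of silence between segments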
audio_segments = []
for segment_data in segment_data_list:
audio_segments += segment_data.reshape(-1).tolist()
audio_segments += [0] * int((sr * 0.05) / speed)
audio_segments = np.array(audio_segments).astype(np.float32)
return audio_segments
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
# Average pad part
if pad_size > 0:
for i in range(len(sub_audio_list) - 1):
sub_audio_list[i][-pad_size:] += sub_audio_list[i+1][:pad_size]
sub_audio_list[i][-pad_size:] /= 2
if i > 0:
sub_audio_list[i] = sub_audio_list[i][pad_size:]
sub_audio = np.concatenate(sub_audio_list, axis=-1)
return sub_audio[:audio_len]
def calc_word2pronoun(word2ph, pronoun_lens):
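    # word2ph[i] is the phoneme count of word i; pronoun_lens[j] is the number of
    # decoder frames for phoneme j. Returns the total decoder frames per word.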
indice = [0]
for ph in word2ph[:-1]:
indice.append(indice[-1] + ph)
word2pronoun = []
for i, ph in zip(indice, word2ph):
word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
return word2pronoun
def generate_slices(word2pronoun, dec_len):
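    # Split the pronunciation sequence into decoder-sized windows (at most dec_len
    # frames each), overlapping adjacent windows by two words where possible so the
    # overlapped audio can be trimmed later.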
pn_start, pn_end = 0, 0
zp_start, zp_end = 0, 0
zp_len = 0
pn_slices = []
zp_slices = []
while pn_end < len(word2pronoun):
        # If the previous slice spans more than 2 words and adding the current word
        # still fits within dec_len, overlap backwards by two words
if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
zp_start = zp_end - zp_len
pn_start = pn_end - 2
else:
zp_len = 0
zp_start = zp_end
pn_start = pn_end
while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
zp_len += word2pronoun[pn_end]
pn_end += 1
zp_end = zp_start + zp_len
pn_slices.append(slice(pn_start, pn_end))
zp_slices.append(slice(zp_start, zp_end))
return pn_slices, zp_slices
# Detect Chinese vs. English text
def lang_detect_with_regex(text):
"""
    Detect the dominant language of the text: 'chinese', 'english', or 'unknown'.
"""
    # Strip all digits
    text_without_digits = re.sub(r'\d+', '', text)
    if not text_without_digits:
        return 'unknown'
    # Check for Chinese characters first (Chinese takes priority)
    if re.search(r'[\u4e00-\u9fff]', text_without_digits):
        return 'chinese'
    # Otherwise check for English letters
    if re.search(r'[a-zA-Z]', text_without_digits):
        return 'english'
    return 'unknown'
class QwenTranslationAPI:
def __init__(self, api_url=QWEN_API_URL):
self.api_url = api_url
self.session_id = f"speech_translate_{int(time.time())}"
def translate(self, text_content, max_retries=3, timeout=120):
"""调用千问API进行翻译"""
if not text_content or text_content.strip() == "":
return "输入文本为空"
#prompt = f"翻译成中文:{text_content}"
if lang_detect_with_regex(text_content)=='chinese':
prompt_f = "翻译成英文"
else:
prompt_f= "翻译成中文"
prompt = f"{prompt_f}{text_content}"
print(f"[翻译API] 发送请求: {prompt}")
for attempt in range(max_retries):
try:
                # Step 1: submit the generation request
                generate_url = f"{self.api_url}/api/generate"
                payload = {
                    "prompt": prompt,
                    "temperature": 0.1,  # low temperature for more deterministic translations
                    "repetition_penalty": 1.0,
                    "top-p": 0.9,
                    "top-k": 40,
                    "max_new_tokens": 512
                }
                print(f"[Translate API] Submitting generation request (attempt {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[Translate API] Generation request accepted")
                # Step 2: poll for results and concatenate all chunks
                result_url = f"{self.api_url}/api/generate_provider"
                start_time = time.time()
                full_translation = ""
                while time.time() - start_time < timeout:
                    try:
                        result_response = requests.get(result_url, timeout=10)
                        result_data = result_response.json()
                        # Fetch the current chunk
                        current_chunk = result_data.get("response", "")  # .strip()
                        full_translation += current_chunk
                        # Check whether generation is done
                        if result_data.get("done", False):
                            # Full translation assembled
                            print(f"[Translate API] Translation complete: {full_translation}")
                            return full_translation
                        time.sleep(0.05)
                    except requests.exceptions.RequestException as e:
                        print(f"[Translate API] Polling request failed: {e}")
                        if time.time() - start_time > timeout:
                            break
                        continue
                print(f"[Translate API] Polling timed out; retrying (attempt {attempt + 1})")
            except requests.exceptions.RequestException as e:
                print(f"[Translate API] Request failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # exponential backoff
                    print(f"[Translate API] Waiting {wait_time}s before retrying...")
                    time.sleep(wait_time)
                else:
                    return f"Translation failed: {str(e)}"
            except Exception as e:
                print(f"[Translate API] Translation error: {e}")
                return f"Translation failed: {str(e)}"
        return "Translation timed out; check the API server status"
class SpeechTranslationPipeline:
def __init__(self,
tts_model_dir, tts_model_files,
asr_model_dir="ax_model", seq_len=132,
tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
qwen_api_url=QWEN_API_URL):
self.tts_model_dir = tts_model_dir
self.tts_model_files = tts_model_files
self.asr_model_dir = asr_model_dir
self.seq_len = seq_len
self.tts_dec_len = tts_dec_len
self.sample_rate = sample_rate
self.tts_speed = tts_speed
self.qwen_api_url = qwen_api_url
        # Initialize ASR models
        self._init_asr_models()
        # Initialize TTS models
        self._init_tts_models()
        # Initialize the translation API client
        self.translator = QwenTranslationAPI(api_url=qwen_api_url)
        # Verify that all required files exist
self._validate_files()
def _init_asr_models(self):
"""初始化语音识别相关模型"""
print("Initializing SenseVoice models...")
        # VAD model
self.model_vad = AX_Fsmn_vad(self.asr_model_dir)
        # Positional encoding
self.embed = SinusoidalPositionEncoder()
self.position_encoding = self.embed.get_position_encoding(
torch.randn(1, self.seq_len, 560)).numpy()
        # ASR model
self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)
# Tokenizer
tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)
print("SenseVoice models initialized successfully.")
def _init_tts_models(self):
"""初始化TTS相关模型"""
print("Initializing MeloTTS models...")
init_start = time.time()
        # Load the encoder and decoder models
enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])
model_load_start = time.time()
self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
self.sess_dec = axe.InferenceSession(dec_model)
print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")
        # Load the static input g (speaker embedding)
g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)
        # Set the language and symbol mapping (defaults to mixed Chinese-English)
self.tts_language = "ZH_MIX_EN"
self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}
        # Warm-up: preload all lazily loaded modules (the main init cost)
print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")
warmup_start = time.time()
        # # Chinese warm-up - triggers loading of pypinyin, jieba, the Chinese tokenizer, etc.
# try:
# warmup_text_zh = "测试文本,用于预加载模块。"
# _, _, _, _, _ = get_text_for_tts_infer(warmup_text_zh, self.tts_language, symbol_to_id=self.symbol_to_id)
# print(f" Chinese module warm-up: {(time.time() - warmup_start)*1000:.2f}ms")
# except Exception as e:
# print(f" Warning: Chinese warm-up failed: {e}")
        # # English warm-up - triggers loading of g2p_en, english_utils, etc.
# try:
# warmup_start_en = time.time()
# warmup_text_en = "Hello world, test loading modules."
# _, _, _, _, _ = get_text_for_tts_infer(warmup_text_en, "EN", symbol_to_id=self.symbol_to_id)
# print(f" English module warm-up: {(time.time() - warmup_start_en)*1000:.2f}ms")
# except Exception as e:
# print(f" Warning: English warm-up failed: {e}")
        # Mixed Chinese-English warm-up
try:
warmup_start_mix = time.time()
warmup_text_mix = "这是一个test测试。"
_, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
except Exception as e:
print(f" Warning: Mixed warm-up failed: {e}")
total_init_time = (time.time() - init_start) * 1000
print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")
    def _validate_files(self):
        """Verify that all required files exist."""
        # Check TTS model files
        for key, filename in self.tts_model_files.items():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS model file not found: {filepath}")
        # Check whether the API server is reachable (optional)
        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API check] Qwen API server is reachable")
        except requests.exceptions.RequestException:
            print("[API warning] Cannot reach the Qwen API server; make sure it is running")
def speech_recognition(self, speech, fs):
"""
第一步:语音识别(ASR)
"""
speech_lengths = len(speech)
        # Run VAD
print("Running VAD...")
vad_start_time = time.time()
res_vad = self.model_vad(speech)[0]
vad_segments = merge_vad(res_vad, 15 * 1000)
vad_time_cost = time.time() - vad_start_time
print(f"VAD processing time: {vad_time_cost:.2f} seconds")
print(f"VAD segments detected: {len(vad_segments)}")
# ASR处理
print("Running ASR...")
asr_start_time = time.time()
all_results = ""
        # Process each VAD segment
for i, segment in enumerate(vad_segments):
segment_start, segment_end = segment
start_sample = int(segment_start / 1000 * fs)
end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
segment_speech = speech[start_sample:end_sample]
            # Write the current segment to a temporary file
segment_filename = f"temp_segment_{i}.wav"
sf.write(segment_filename, segment_speech, fs)
            # Run ASR on the segment
try:
segment_res = self.model_bin(
segment_filename,
"auto", # 语言自动检测
True, # withitn
self.position_encoding,
tokenizer=self.tokenizer,
)
all_results += segment_res
                # Clean up the temporary file
if os.path.exists(segment_filename):
os.remove(segment_filename)
except Exception as e:
if os.path.exists(segment_filename):
os.remove(segment_filename)
print(f"Error processing segment {i}: {e}")
continue
asr_time_cost = time.time() - asr_start_time
print(f"ASR processing time: {asr_time_cost:.2f} seconds")
print(f"ASR Result: {all_results}")
return all_results.strip()
def run_translation(self, text_content):
"""
第二步:调用Qwen大模型API中英互译
"""
print("Starting translation via API...")
translation_start_time = time.time()
        # Translate via the API
translate_content = self.translator.translate(text_content)
translation_time_cost = time.time() - translation_start_time
print(f"Translation processing time: {translation_time_cost:.2f} seconds")
print(f"Translation Result: {translate_content}")
return translate_content
def run_tts(self, translate_content, output_dir, output_wav=None):
"""
第三步:使用TTS模型合成语音
"""
output_path = os.path.join(output_dir, output_wav)
try:
            # Convert Arabic numerals in Chinese text to Chinese numerals
if lang_detect_with_regex(translate_content) == "chinese":
translate_content = cn2an.transform(translate_content, "an2cn")
print(f"TTS synthesis for text: {translate_content}")
            # Split into sentences
sens = split_sentence(translate_content, language_str=self.tts_language)
print(f"Text split into {len(sens)} sentences")
            # Final audio segments
audio_list = []
            # Process each sentence
for n, se in enumerate(sens):
                # Insert a space between camelCase English words
if self.tts_language in ['EN', 'ZH_MIX_EN']:
se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)
print(f"Processing sentence[{n}]: {se}")
                # Convert text to phonemes and tones
phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
se, self.tts_language, symbol_to_id=self.symbol_to_id)
                # Run the encoder
encoder_start = time.time()
z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
'phone': phones, 'g': self.tts_g,
'tone': tones, 'language': lang_ids,
'noise_scale': np.array([0], dtype=np.float32),
'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
'noise_scale_w': np.array([0], dtype=np.float32),
'sdp_ratio': np.array([0], dtype=np.float32)})
print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")
                # Compute the pronunciation length (decoder frames) of each word
word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
                # Generate decoder slices
pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)
audio_len = audio_len[0]
sub_audio_list = []
for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
zp_slice = z_p[..., zs]
                    # Length of z_p before padding
sub_dec_len = zp_slice.shape[-1]
                    # Output audio length before padding (the decoder emits 512 samples per frame)
sub_audio_len = 512 * sub_dec_len
                    # Pad to dec_len
if zp_slice.shape[-1] < self.tts_dec_len:
zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)
decoder_start = time.time()
audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()
                    # Handle overlap between adjacent slices
audio_start = 0
if len(sub_audio_list) > 0:
if pn_slices[i - 1].stop > ps.start:
                            # Drop the first (overlapped) word
audio_start = 512 * word2pronoun[ps.start]
audio_end = sub_audio_len
if i < len(pn_slices) - 1:
if ps.stop > pn_slices[i + 1].start:
                            # Drop the last (overlapped) word
audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]
audio = audio[audio_start:audio_end]
print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
sub_audio_list.append(audio)
                # Merge sub-audio chunks
sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
audio_list.append(sub_audio)
            # Concatenate the audio of all sentences
audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)
            # Save the audio file
sf.write(output_path, audio, self.sample_rate)
print(f"TTS audio saved to {output_path}")
return output_path
        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise
def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
"""
完整Pipeline:语音识别 -> 翻译 -> TTS合成
"""
        # Step 1: speech recognition
        print("\n----------------------VAD+ASR----------------------------\n")
        start_time = time.time()
        text_content = self.speech_recognition(speech, fs)
        asr_time = time.time() - start_time
        print(f"ASR took {asr_time:.2f} s")
        if not text_content or text_content.strip() == "":
            raise ValueError("ASR produced no valid text")
        # Step 2: translation
        print("\n---------------------Qwen Translation---------------------------\n")
        start_time = time.time()
        translate_content = self.run_translation(text_content)
        translate_time = time.time() - start_time
        print(f"Translation took {translate_time:.2f} s")
        # Step 3: TTS synthesis
        print("-------------------------TTS-------------------------------\n")
        start_time = time.time()
        output_path = self.run_tts(translate_content, output_dir, output_tts)
        tts_time = time.time() - start_time
        print(f"TTS synthesis took {tts_time:.2f} s")
return {
"original_text": text_content,
"translated_text": translate_content,
"audio_path": output_path
}
def main():
parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
parser.add_argument("--audio_file", type=str, default="./wav/en.mp3", help="Input audio file path")
parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
parser.add_argument("--output_tts", type=str, default="output.wav", help="Output TTS file name")
parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
args = parser.parse_args()
print("-------------------START------------------------\n")
os.makedirs(args.output_dir, exist_ok=True)
print(f"Processing audio file: {args.audio_file}")
    # Load the audio
speech, fs = librosa.load(args.audio_file, sr=None)
if fs != 16000:
print(f"Resampling audio from {fs}Hz to 16000Hz")
speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
fs = 16000
audio_duration = librosa.get_duration(y=speech, sr=fs)
    # Initialize the pipeline
pipeline = SpeechTranslationPipeline(
tts_model_dir=TTS_MODEL_DIR,
tts_model_files=TTS_MODEL_FILES,
asr_model_dir="ax_model",
seq_len=132,
tts_dec_len=128,
sample_rate=44100,
tts_speed=0.8,
qwen_api_url=args.api_url
)
start_time = time.time()
try:
        # Run the full pipeline
        result = pipeline.full_pipeline(speech, fs, args.output_dir, args.output_tts)
        print("\n" + "="*50)
        print("Speech translation complete!")
        print("="*50 + "\n")
        print(f"Input audio: {args.audio_file}")
        print(f"Recognized text: {result['original_text']}")
        print(f"Translated text: {result['translated_text']}")
        print(f"Generated audio: {result['audio_path']}")
        # Save results to a file
        result_file = os.path.join(args.output_dir, "pipeline_result.txt")
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write(f"Input audio: {args.audio_file}\n")
            f.write(f"Recognized text: {result['original_text']}\n")
            f.write(f"Translation: {result['translated_text']}\n")
            f.write(f"Synthesized audio: {result['audio_path']}\n")
time_cost = time.time() - start_time
rtf = time_cost / audio_duration
print(f"Inference time for {args.audio_file}: {time_cost:.2f} seconds")
print(f"Audio duration: {audio_duration:.2f} seconds")
print(f"RTF: {rtf:.2f}\n")
except Exception as e:
print(f"Pipeline执行失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()