""" |
|
|
Optimised NeMo Parakeet-TDT streaming demo for CPU-only Hugging Face Spaces |
|
|
""" |
import os

# OMP_NUM_THREADS must be set before torch is imported, otherwise OpenMP has
# already sized its thread pool and the cap is ignored.
os.environ["OMP_NUM_THREADS"] = "2"

import logging
import queue
import threading
import time

import gradio as gr
import numpy as np
import torch
from scipy import signal

from nemo.collections.asr.models import ASRModel

torch.set_num_threads(2)
torch.backends.quantized.engine = "fbgemm"  # x86 INT8 backend ("qnnpack" is the ARM equivalent)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("asr_app")
SR = 16_000
CHUNK_SECONDS = 4
CHUNK_SAMPLES = SR * CHUNK_SECONDS
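
# Chunking note: at 16 kHz, a 4 s chunk is 64,000 samples. Shorter chunks
# surface text sooner but give the model less acoustic context and cut more
# words at chunk boundaries, so 4 s is a latency/accuracy compromise.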


class ASRApp:
    def __init__(self):
        self.audio_queue = queue.Queue(maxsize=100)  # raw float32 packets from the mic
        self.transcript_queue = queue.Queue()        # finished text from the worker
        self.transcript_list = []
        self._load_model()
        self._start_worker()

    def _log(self, func: str, msg: str):
        logger.info(
            f"{func} | audio_q={self.audio_queue.qsize():02}, "
            f"txt_q={self.transcript_queue.qsize():02} | {msg}"
        )
    def _load_model(self):
        self._log("load_model", "loading Parakeet-TDT-0.6B-V2 (CPU)…")
        t0 = time.time()
        model = ASRModel.from_pretrained(
            model_name="nvidia/parakeet-tdt-0.6b-v2",
            map_location="cpu",
        )
        model.eval()

        try:
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU},
                dtype=torch.qint8,
            )
            self._log("load_model", "INT8 quantisation applied")
        except Exception as e:
            self._log("load_model", f"quantisation skipped ({e})")
        self.asr_model = model
        self._log("load_model", f"model ready in {time.time()-t0:.1f}s")

        # Warm up on one second of silence so the first real request does not
        # pay the one-off setup cost.
        with torch.inference_mode():
            _ = self.asr_model.transcribe([np.zeros(SR, dtype=np.float32)])
        self._log("load_model", "warm-up done")
    def _start_worker(self):
        threading.Thread(
            target=self._worker,
            daemon=True,
        ).start()

    def _worker(self):
        buf = np.array([], dtype=np.float32)
        while True:
            try:
                # Block until a full chunk has accumulated.
                while len(buf) < CHUNK_SAMPLES:
                    buf = np.concatenate([buf, self.audio_queue.get()])
                    self._log("_worker", f"buffer={len(buf)}")
                chunk, buf = buf[:CHUNK_SAMPLES], buf[CHUNK_SAMPLES:]
                self._log("_worker", f"→ transcribe {len(chunk)} samples")
                t0 = time.time()
                with torch.inference_mode():
                    out = self.asr_model.transcribe([chunk])
                dur = time.time() - t0
                text = out[0].text
                self._log("_worker", f"inference {dur:.2f}s → “{text}”")
                self.transcript_queue.put(text)
            except Exception as e:
                self._log("_worker", f"ASR error: {e}")
    def _preprocess(self, audio):
        sr, y = audio  # Gradio delivers (sample_rate, ndarray) tuples
        if y.ndim > 1:
            y = y.mean(axis=1)  # downmix to mono
        if sr != SR:
            # Polyphase resampling to the model's 16 kHz rate.
            y = signal.resample_poly(y, SR, sr)
        y = y.astype(np.float32)
        y /= (np.abs(y).max() + 1e-9)  # peak-normalise; epsilon guards silent packets
        return y
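
    # Peak normalisation is applied per streamed packet, so near-silent
    # packets have their noise floor amplified; normalising per assembled
    # chunk (or gating on RMS) would be a gentler alternative.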
    def stream_fn(self, audio):
        self._log("stream_fn", "audio arrived")
        self.audio_queue.put(self._preprocess(audio))
        # Drain whatever the worker has finished so far.
        while not self.transcript_queue.empty():
            self.transcript_list.append(self.transcript_queue.get())
        return (
            " ".join(self.transcript_list)
            if self.transcript_list
            else "…listening…"
        )
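
    # stream_fn runs on Gradio's event loop every 0.5 s (see stream_every
    # below), so it must return quickly: it only enqueues audio and drains
    # finished text, while the blocking transcribe() call lives on the
    # daemon worker thread.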


asr_app = ASRApp()

with gr.Blocks() as demo:
    mic = gr.Audio(
        sources=["microphone"],
        type="numpy",
        streaming=True,
        label="Microphone",
    )
    out = gr.Textbox(label="Transcription")
    mic.stream(
        fn=asr_app.stream_fn,
        inputs=mic,
        outputs=out,
        stream_every=0.5,
    )
asr_app._log("main", "launching UI")
demo.launch()
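
# Caveat: ASRApp is a module-level singleton, so concurrent visitors share one
# audio buffer and one transcript. A multi-user deployment would need
# per-session state (e.g. queues held in gr.State) instead.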