import re
from contextlib import asynccontextmanager

import onnxruntime as ort
import torch
from fastapi import FastAPI
from pydantic import BaseModel

from config import *  # provides MODELS
from inference_onnx import get_transcription


# Model artifacts are loaded once at startup and attached to app.state,
# so every request reuses the same tokenizer and ONNX session.
@asynccontextmanager
async def lifespan(app: FastAPI):
    print("🔧 Loading model...")
    app.state.device = torch.device('cpu')
    app.state.tokenizer = MODELS["./distilbert-base-multilingual-cased"][1].from_pretrained(
        "./distilbert-base-multilingual-cased"
    )
    app.state.token_style = MODELS["./distilbert-base-multilingual-cased"][3]

    onnx_model_path = "./poc_onnx_model_punctuation_batch.onnx"
    providers = ['CPUExecutionProvider']
    # providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']  # uncomment for GPU inference
    sess_options = ort.SessionOptions()
    app.state.session = ort.InferenceSession(
        onnx_model_path, sess_options=sess_options, providers=providers
    )
    print("✅ ONNX model loaded into memory.")
    yield
    print("🧹 Shutting down...")


app = FastAPI(lifespan=lifespan)

# Punctuation marks the model can restore, mapped to their label names.
punc_dict = {
    '!': 'EXCLAMATION',
    '?': 'QUESTION',
    ',': 'COMMA',
    ';': 'SEMICOLON',
    ':': 'COLON',
    '-': 'HYPHEN',
    '।': 'DARI',
}
allowed_punctuations = set(punc_dict.keys())


def clean_and_normalize_text(text, remove_punctuations=False):
    """Clean and normalize Bangla text with correct spacing."""
    if remove_punctuations:
        # Strip all allowed punctuation marks, then collapse whitespace
        cleaned_text = re.sub(f"[{re.escape(''.join(allowed_punctuations))}]", "", text)
        cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
        return cleaned_text

    # Keep only allowed punctuations and Bangla letters/digits.
    # Splitting on a capturing group keeps the punctuation marks as separate chunks.
    chunks = re.split(f"([{re.escape(''.join(allowed_punctuations))}])", text)
    filtered_chunks = []
    for chunk in chunks:
        if chunk in allowed_punctuations:
            filtered_chunks.append(chunk)
        else:
            # Drop everything outside the Bengali Unicode block (U+0980–U+09FF,
            # which already includes the Bangla digits U+09E6–U+09EF) and whitespace
            clean_chunk = re.sub(r"[^\u0980-\u09FF\s]", "", chunk)
            clean_chunk = re.sub(r'\s+', ' ', clean_chunk).strip()  # normalize internal spacing
            if clean_chunk:
                filtered_chunks.append(' ' + clean_chunk)  # add space before word chunks

    # Join and clean up spacing
    result = ''.join(filtered_chunks)
    result = re.sub(r'\s+', ' ', result).strip()
    return result


class TextInput(BaseModel):
    text: str


@app.post("/punctuate")
async def punctuate_text(data: TextInput):
    # First pass removes disallowed characters; second pass strips any remaining
    # punctuation so the model restores punctuation from plain text only.
    input_normalized = clean_and_normalize_text(data.text)
    input_normalized = clean_and_normalize_text(input_normalized, remove_punctuations=True)
    restored_text = get_transcription(
        input_normalized,
        app.state.session,
        app.state.tokenizer,
        app.state.device,
        app.state.token_style,
    )
    return {"restored_text": restored_text}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("api_onnx:app", host="0.0.0.0", port=5685, workers=1)
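
# --- Example usage (sketch, not part of the service) ---
# Assuming the server is running locally on port 5685, the /punctuate endpoint
# can be exercised with a short client like the one below. The sample Bangla
# sentence is illustrative only; it requires the `requests` package.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:5685/punctuate",
#       json={"text": "আমি ভালো আছি তুমি কেমন আছো"},
#   )
#   print(resp.json()["restored_text"])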