from fastapi import FastAPI, File, UploadFile from fastapi.responses import FileResponse from pydantic import BaseModel from transformers import AutoModelForCausalLM, AutoTokenizer import whisper import torch from gtts import gTTS import os import yt_dlp import re hf_token = os.getenv("HF_TOKEN") app = FastAPI() # Load Qwen model model_name = "Qwen/Qwen3-4B-Instruct-2507" tokenizer = AutoTokenizer.from_pretrained(model_name,token=hf_token) model = AutoModelForCausalLM.from_pretrained( model_name, token=hf_token, device_map={"": "cpu"}, dtype=torch.float32 ) # Load Whisper model whisper_model = whisper.load_model("base") # Lưu hội thoại conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}] # Hàm trích xuất tên bài hát từ văn bản def extract_song_name(text): import re match = re.search(r"(bài|bài hát|nghe nhạc|mở nhạc)\s+(.*)", text.lower()) if match: return match.group(2).strip() return None def download_youtube_as_wav(song_name, output_path="song.wav"): search_query = f"ytsearch1:{song_name}" ydl_opts = { 'format': 'bestaudio/best', 'outtmpl': 'temp_audio.%(ext)s', 'postprocessors': [{ 'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav', 'preferredquality': '192', }], 'quiet': True, } with yt_dlp.YoutubeDL(ydl_opts) as ydl: ydl.download([search_query]) if os.path.exists("temp_audio.wav"): os.rename("temp_audio.wav", output_path) return output_path return None class ChatRequest(BaseModel): message: str @app.get("/") def read_root(): return {"message": "Ứng dụng đang chạy!"} # Endpoint chat text @app.post("/chat") async def chat(request: ChatRequest): conversation.append({"role": "user", "content": request.message}) text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) response_text = generate_full_response(model_inputs) conversation.append({"role": "assistant", "content": response_text}) return {"response": response_text} # Endpoint voice chat + TTS @app.post("/voice_chat") async def voice_chat(file: UploadFile = File(...)): file_location = f"temp_{file.filename}" with open(file_location, "wb") as f: f.write(await file.read()) result = whisper_model.transcribe(file_location, language="vi") user_text = result["text"] os.remove(file_location) # Kiểm tra yêu cầu mở nhạc if any(kw in user_text.lower() for kw in ["nghe nhạc", "mở bài hát", "bài hát", "bài"]): song_name = extract_song_name(user_text) if song_name: wav_path = download_youtube_as_wav(song_name) if wav_path: return FileResponse(wav_path, media_type="audio/wav", filename="song.wav") else: return {"error": "Không tìm thấy hoặc tải được bài hát."} # Nếu không phải yêu cầu mở nhạc → xử lý như cũ conversation.append({"role": "user", "content": user_text}) text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) response_text = generate_full_response(model_inputs) conversation.append({"role": "assistant", "content": response_text}) tts = gTTS(response_text, lang="vi") audio_file = "response.mp3" tts.save(audio_file) return { "user_text": user_text, "response": response_text, "audio_url": f"/get_audio" } # Endpoint trả về file âm thanh @app.get("/get_audio") async def get_audio(): return FileResponse("response.mp3", media_type="audio/mpeg") # Hàm sinh phản hồi def generate_full_response(model_inputs, max_new_tokens=64): with torch.inference_mode(): generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens) output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() response_text = tokenizer.decode(output_ids, skip_special_tokens=True) return response_text.strip()