Spaces:
Running
Running
| # app_music_api.py | |
| from fastapi import FastAPI, Query | |
| from fastapi.responses import FileResponse | |
| import torch | |
| import numpy as np | |
| import os | |
| from tempfile import NamedTemporaryFile | |
| from core.music_generator import MusicGenerator | |
# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ARTIFACTS_DIR = "artifacts"

app = FastAPI(title="MVI-AI Music API")


def _matching_weights(saved_state, target_state):
    """Return the entries of saved_state whose key and shape match target_state."""
    compatible = {}
    for key, tensor in saved_state.items():
        if key not in target_state:
            continue
        if tensor.shape == target_state[key].shape:
            compatible[key] = tensor
    return compatible


# Build the generator and, when a checkpoint file exists, overlay the
# compatible part of it on the freshly initialized weights.
music_model = MusicGenerator().to(DEVICE)
checkpoint_path = os.path.join(ARTIFACTS_DIR, "music_generator.pt")
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model_dict = music_model.state_dict()
    # Entries with an unknown key or a different shape are skipped, so the
    # corresponding parameters keep their initial values.
    filtered = _matching_weights(checkpoint, model_dict)
    model_dict.update(filtered)
    music_model.load_state_dict(model_dict)
# NOTE(review): the original indentation is not recoverable from SOURCE; this
# assumes eval() runs whether or not a checkpoint was found - confirm.
music_model.eval()
def sequence_to_wav(sequence: np.ndarray, sample_rate=16000) -> str:
    """
    Convert a sequence of floats to a temporary mono 16-bit PCM WAV file.

    The samples are peak-normalized so that the largest absolute value maps
    to 32767, truncated to integers, and written with the standard-library
    ``wave`` module.

    Args:
        sequence: Audio samples as floats (anything ``np.asarray`` accepts).
            All values are written, in C order, as a single channel.
        sample_rate: Frame rate in Hz recorded in the WAV header.

    Returns:
        The path of the WAV file. The file is created with ``delete=False``,
        so the caller is responsible for removing it.
    """
    import wave

    samples = np.asarray(sequence)
    peak = float(np.max(np.abs(samples))) if samples.size else 0.0
    if peak > 0.0:
        scaled = samples / peak * 32767
    else:
        # Empty input, all-zero input, or a NaN peak: dividing by the peak
        # would raise or produce NaN samples, so emit silence instead.
        scaled = np.zeros(samples.shape)

    # WAV stores 16-bit samples little-endian regardless of the host CPU.
    pcm_bytes = scaled.astype("<i2").tobytes()

    # Reserve a unique path and close the handle right away so that no file
    # descriptor is leaked and the path can be re-opened on every platform.
    with NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        wav_path = tmp_file.name

    with wave.open(wav_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
    return wav_path
def generate(seq_len: int = 64):
    """
    Run the music model once and return its output as plain Python data.

    Args:
        seq_len: Value forwarded to ``music_model.generate`` as ``seq_len``.

    Returns:
        A dict with the requested ``seq_len`` and, under ``output``, the
        generated result converted with ``.cpu().tolist()``.
    """
    # Inference only: gradient tracking is disabled for the model call.
    with torch.no_grad():
        generated = music_model.generate(seq_len=seq_len, device=DEVICE)

    # NOTE(review): presumably `generated` is a torch tensor - confirm
    # against MusicGenerator.generate.
    payload = {"seq_len": seq_len}
    payload["output"] = generated.cpu().tolist()
    return payload