File size: 1,690 Bytes
05a0146
d868109
 
05a0146
 
 
 
 
 
d868109
05a0146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d868109
 
 
f76c865
05a0146
f76c865
05a0146
f76c865
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# app_music_api.py

import os
import warnings
from tempfile import NamedTemporaryFile

import numpy as np
import torch
from fastapi import FastAPI, Query
from fastapi.responses import FileResponse

from core.music_generator import MusicGenerator

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ARTIFACTS_DIR = "artifacts"

app = FastAPI(title="MVI-AI Music API")


def _load_compatible_weights(model: torch.nn.Module, path: str, device: str) -> None:
    """
    Copy into ``model`` every checkpoint tensor whose name and shape match.

    Checkpoint entries the model does not have, or whose shape differs, are
    skipped instead of raising, so a checkpoint from a slightly different
    architecture can still warm-start the model. Skipped checkpoint entries
    and model entries left untouched are reported with ``warnings.warn`` so a
    partial (or completely empty) load is visible instead of silent.

    Args:
        model: Module whose state dict is updated in place.
        path: File expected to hold a mapping of state-dict key -> tensor.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    # NOTE(review): torch.load unpickles the file, which can execute code.
    # Only load checkpoints from a trusted location; consider
    # ``weights_only=True`` on torch versions that support it.
    checkpoint = torch.load(path, map_location=device)
    model_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in checkpoint.items()
        if name in model_state and tensor.shape == model_state[name].shape
    }

    skipped = sorted(set(checkpoint) - set(compatible))
    not_loaded = sorted(set(model_state) - set(compatible))
    if skipped:
        warnings.warn(
            f"{path}: ignored {len(skipped)} checkpoint entries with no "
            f"matching model key/shape, e.g. {skipped[:5]}"
        )
    if not_loaded:
        warnings.warn(
            f"{path}: {len(not_loaded)} model entries were not found in the "
            f"checkpoint and keep their initial values, e.g. {not_loaded[:5]}"
        )

    # Start from the model's own state so unmatched entries keep their values.
    model_state.update(compatible)
    model.load_state_dict(model_state)


# Build the generator once at import time so every request reuses it.
music_model = MusicGenerator().to(DEVICE)
checkpoint_path = os.path.join(ARTIFACTS_DIR, "music_generator.pt")
if os.path.exists(checkpoint_path):
    _load_compatible_weights(music_model, checkpoint_path, DEVICE)
else:
    warnings.warn(
        f"No checkpoint found at {checkpoint_path}; "
        "serving an untrained MusicGenerator."
    )
music_model.eval()


def sequence_to_wav(sequence: np.ndarray, sample_rate=16000) -> str:
    """
    Write a sequence of samples to a temporary mono 16-bit PCM WAV file.

    The samples are peak-normalised so the largest absolute value maps to
    full scale (32767); scaled values are truncated toward zero when cast to
    int16. A silent (all-zero) or empty sequence is written unchanged instead
    of being divided by a zero peak.

    Args:
        sequence: Array-like of real-valued samples; flattened in C order.
        sample_rate: Frame rate stored in the WAV header.

    Returns:
        Path of the created ``.wav`` file. It is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    import wave

    samples = np.asarray(sequence, dtype=np.float64).ravel()
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples = samples / peak
    audio_int16 = (samples * 32767).astype(np.int16)

    # Reserve a unique path, then close the handle before writing by name
    # (the original left it open, leaking the descriptor and breaking on Windows).
    with NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        path = tmp_file.name

    # The stdlib wave writer expects native-order frames and handles the
    # little-endian conversion itself.
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return path


@app.get("/generate")
def generate(
    seq_len: int = Query(
        64,
        ge=1,
        # NOTE(review): 4096 is an operational safety cap chosen here so a
        # single request cannot demand unbounded generation work; tune it to
        # the model's practical limits.
        le=4096,
        description="Number of tokens to generate (1-4096).",
    ),
):
    """
    Generate a token sequence with the preloaded music model.

    FastAPI validates ``seq_len`` before the handler runs. Values outside
    [1, 4096] are rejected with HTTP 422 instead of being passed to the model.

    Returns:
        JSON object with the requested ``seq_len`` and, under ``output``, the
        result of ``music_model.generate`` moved to CPU and converted with
        ``.tolist()`` (presumably a tensor of token ids; confirm against
        MusicGenerator).
    """
    with torch.no_grad():
        tokens = music_model.generate(seq_len=seq_len, device=DEVICE)

    return {
        "seq_len": seq_len,
        "output": tokens.cpu().tolist(),
    }