Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,318 @@
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import spaces # Enables ZeroGPU on Hugging Face
|
| 3 |
import gradio as gr
|
| 4 |
+
import torch
|
| 5 |
+
from dataclasses import asdict
|
| 6 |
+
from mido import MidiFile, tempo2bpm
|
| 7 |
|
| 8 |
+
from transformers import AutoModelForCausalLM
|
| 9 |
+
from anticipation.sample import generate
|
| 10 |
+
from anticipation.convert import events_to_midi, midi_to_events
|
| 11 |
+
from anticipation.tokenize import extract_instruments
|
| 12 |
+
from anticipation import ops
|
| 13 |
|
| 14 |
+
from pyharp.core import ModelCard, build_endpoint
|
| 15 |
+
from pyharp.labels import LabelList
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------
|
| 18 |
+
# Model Choices
|
| 19 |
+
# ---------------------------------------------------------
|
| 20 |
+
SMALL_MODEL = "stanford-crfm/music-small-800k"
|
| 21 |
+
MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
|
| 22 |
+
LARGE_MODEL = "stanford-crfm/music-large-800k"
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------
|
| 25 |
+
# Model Card (for HARP)
|
| 26 |
+
# ---------------------------------------------------------
|
| 27 |
+
model_card = ModelCard(
|
| 28 |
+
name="Anticipatory Music Transformer",
|
| 29 |
+
description=(
|
| 30 |
+
"Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
|
| 31 |
+
"Input: a MIDI file with a short accompaniment followed by a melody line. "
|
| 32 |
+
"Output: a new MIDI file with extended accompaniment matching the melody. "
|
| 33 |
+
"Use the sliders to choose model size and how much of the song is used as context."
|
| 34 |
+
),
|
| 35 |
+
author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
|
| 36 |
+
tags=["midi", "generation", "accompaniment"]
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------
|
| 40 |
+
# Model Cache Loader
|
| 41 |
+
# ---------------------------------------------------------
|
| 42 |
+
_model_cache = {}
|
| 43 |
+
|
| 44 |
+
def load_amt_model(model_choice: str):
|
| 45 |
+
"""Loads and caches the AMT model inside the worker process."""
|
| 46 |
+
if model_choice in _model_cache:
|
| 47 |
+
return _model_cache[model_choice]
|
| 48 |
+
|
| 49 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
+
|
| 51 |
+
print(f"Loading {model_choice} ...")
|
| 52 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 53 |
+
model_choice,
|
| 54 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 55 |
+
low_cpu_mem_usage=True
|
| 56 |
+
).to(device)
|
| 57 |
+
|
| 58 |
+
_model_cache[model_choice] = model
|
| 59 |
+
return model
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------
|
| 62 |
+
# Melody Detection (Auto Program)
|
| 63 |
+
# ---------------------------------------------------------
|
| 64 |
+
def find_melody_program(mid, debug=False):
|
| 65 |
+
"""Detect melody track’s program number using pitch, density, and duration heuristics."""
|
| 66 |
+
track_stats = []
|
| 67 |
+
total_duration = 0
|
| 68 |
+
|
| 69 |
+
for i, track in enumerate(mid.tracks):
|
| 70 |
+
pitches, times = [], []
|
| 71 |
+
current_time = 0
|
| 72 |
+
current_program = None
|
| 73 |
+
track_note_count = 0
|
| 74 |
+
|
| 75 |
+
for msg in track:
|
| 76 |
+
if msg.type not in ("note_on", "program_change"):
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
current_time += msg.time
|
| 80 |
+
if msg.type == "program_change":
|
| 81 |
+
current_program = msg.program
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
if msg.velocity > 0:
|
| 85 |
+
pitches.append(msg.note)
|
| 86 |
+
times.append(current_time)
|
| 87 |
+
track_note_count += 1
|
| 88 |
+
if track_note_count >= 100:
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
if not pitches:
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
track_duration = max(times) - min(times)
|
| 95 |
+
total_duration = max(total_duration, current_time)
|
| 96 |
+
|
| 97 |
+
mean_pitch = sum(pitches) / len(pitches)
|
| 98 |
+
polyphony = len(set(pitches)) / len(pitches)
|
| 99 |
+
coverage = track_duration / total_duration if total_duration > 0 else 0
|
| 100 |
+
|
| 101 |
+
track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))
|
| 102 |
+
|
| 103 |
+
if not track_stats:
|
| 104 |
+
return None, False
|
| 105 |
+
|
| 106 |
+
if len(track_stats) == 1:
|
| 107 |
+
prog = track_stats[0][3]
|
| 108 |
+
if debug:
|
| 109 |
+
print(f"Single-track MIDI detected — using program {prog or 'None'}")
|
| 110 |
+
return prog, prog is not None
|
| 111 |
+
|
| 112 |
+
candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
|
| 113 |
+
has_valid_programs = len(candidates) > 0
|
| 114 |
+
if not candidates:
|
| 115 |
+
candidates = track_stats
|
| 116 |
+
|
| 117 |
+
max_notes = max(t[2] for t in candidates)
|
| 118 |
+
max_pitch = max(t[1] for t in candidates)
|
| 119 |
+
min_pitch = min(t[1] for t in candidates)
|
| 120 |
+
pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
|
| 121 |
+
|
| 122 |
+
best_score = -1
|
| 123 |
+
best_program = None
|
| 124 |
+
best_track = None
|
| 125 |
+
best_pitch = None
|
| 126 |
+
|
| 127 |
+
for t in candidates:
|
| 128 |
+
idx, pitch, notes, prog, poly, coverage = t
|
| 129 |
+
pitch_norm = (pitch - min_pitch) / pitch_span
|
| 130 |
+
notes_norm = notes / max_notes
|
| 131 |
+
score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
|
| 132 |
+
|
| 133 |
+
if poly < 0.15:
|
| 134 |
+
score *= 0.95
|
| 135 |
+
if 55 <= pitch <= 75:
|
| 136 |
+
score *= 1.1
|
| 137 |
+
if notes >= 30:
|
| 138 |
+
score *= 1.05
|
| 139 |
+
if coverage > 0.7:
|
| 140 |
+
score *= 1.15
|
| 141 |
+
|
| 142 |
+
if score > best_score:
|
| 143 |
+
best_score = score
|
| 144 |
+
best_program = prog
|
| 145 |
+
best_track = idx
|
| 146 |
+
best_pitch = pitch
|
| 147 |
+
|
| 148 |
+
return best_program, has_valid_programs
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def auto_extract_melody(mid, debug=False):
|
| 152 |
+
"""Extract melody events from MIDI object (optimized for direct input)."""
|
| 153 |
+
events = midi_to_events(mid)
|
| 154 |
+
melody_program, has_valid_program = find_melody_program(mid, debug=debug)
|
| 155 |
+
|
| 156 |
+
if not has_valid_program or melody_program is None or melody_program == 0:
|
| 157 |
+
if debug:
|
| 158 |
+
print("No valid program changes found; using all events as melody.")
|
| 159 |
+
return events, events
|
| 160 |
+
|
| 161 |
+
events, melody = extract_instruments(events, [melody_program])
|
| 162 |
+
|
| 163 |
+
if len(melody) == 0:
|
| 164 |
+
if debug:
|
| 165 |
+
print("No melody events found for program — reverting to all events.")
|
| 166 |
+
return events, events
|
| 167 |
+
|
| 168 |
+
if debug:
|
| 169 |
+
print(f"Extracted {len(melody)} melody events from program {melody_program}")
|
| 170 |
+
|
| 171 |
+
return events, melody
|
| 172 |
+
|
| 173 |
+
# ---------------------------------------------------------
|
| 174 |
+
# Core Generation Logic
|
| 175 |
+
# ---------------------------------------------------------
|
| 176 |
+
spaces.GPU
|
| 177 |
+
def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
|
| 178 |
+
"""Generate accompaniment conditioned on context history and melody."""
|
| 179 |
+
model = load_amt_model(model_choice)
|
| 180 |
+
|
| 181 |
+
mid = MidiFile(midi_path)
|
| 182 |
+
print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
|
| 183 |
+
|
| 184 |
+
all_events, melody = auto_extract_melody(mid, debug=True)
|
| 185 |
+
if len(melody) == 0:
|
| 186 |
+
melody = all_events
|
| 187 |
+
|
| 188 |
+
mid_time = mid.length or 0
|
| 189 |
+
ops_time = ops.max_time(all_events, seconds=True)
|
| 190 |
+
total_time = round(max(mid_time, ops_time))
|
| 191 |
+
|
| 192 |
+
melody_history = ops.clip(all_events, 0, history_length, clip_duration=False)
|
| 193 |
+
melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)
|
| 194 |
+
|
| 195 |
+
accompaniment = generate(
|
| 196 |
+
model,
|
| 197 |
+
start_time=history_length,
|
| 198 |
+
end_time=total_time,
|
| 199 |
+
inputs=melody_history,
|
| 200 |
+
controls=melody_future,
|
| 201 |
+
top_p=0.95,
|
| 202 |
+
debug=False
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
output_events = ops.clip(
|
| 206 |
+
ops.combine(accompaniment, melody),
|
| 207 |
+
0,
|
| 208 |
+
total_time,
|
| 209 |
+
clip_duration=True
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
print(f"Generating from {history_length:.2f}s → {total_time:.2f}s "
|
| 213 |
+
f"(duration = {total_time - history_length:.2f}s)")
|
| 214 |
+
|
| 215 |
+
output_midi = "generated_accompaniment_huggingface.mid"
|
| 216 |
+
events_to_midi(output_events).save(output_midi)
|
| 217 |
+
return output_midi, None
|
| 218 |
+
|
| 219 |
+
# ---------------------------------------------------------
|
| 220 |
+
# HARP process_fn — with tempo-aware bar→seconds conversion
|
| 221 |
+
# ---------------------------------------------------------
|
| 222 |
+
def process_fn(input_midi_path, model_choice, history_length, use_bars):
|
| 223 |
+
"""Convert bars to seconds (tempo-aware) before generation."""
|
| 224 |
+
if use_bars:
|
| 225 |
+
BEATS_PER_BAR = 4
|
| 226 |
+
bpm = 120
|
| 227 |
+
try:
|
| 228 |
+
mid = MidiFile(input_midi_path)
|
| 229 |
+
for tr in mid.tracks:
|
| 230 |
+
for msg in tr:
|
| 231 |
+
if msg.type == "set_tempo":
|
| 232 |
+
bpm = round(tempo2bpm(msg.tempo))
|
| 233 |
+
break
|
| 234 |
+
except Exception:
|
| 235 |
+
pass
|
| 236 |
+
seconds_per_bar = (60.0 / bpm) * BEATS_PER_BAR
|
| 237 |
+
history_length = history_length * seconds_per_bar
|
| 238 |
+
print(f"[INFO] Converted to {history_length:.2f}s from bars @ {bpm} BPM")
|
| 239 |
+
|
| 240 |
+
output_midi, error_message = generate_accompaniment(
|
| 241 |
+
input_midi_path,
|
| 242 |
+
model_choice,
|
| 243 |
+
float(history_length)
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if error_message:
|
| 247 |
+
return {"message": error_message}, None
|
| 248 |
+
|
| 249 |
+
return asdict(LabelList()), output_midi
|
| 250 |
+
|
| 251 |
+
# ---------------------------------------------------------
|
| 252 |
+
# Gradio + HARP UI
|
| 253 |
+
# ---------------------------------------------------------
|
| 254 |
+
with gr.Blocks() as demo:
|
| 255 |
+
gr.Markdown("## 🎼 Anticipatory Music Transformer")
|
| 256 |
+
|
| 257 |
+
input_midi = gr.File(file_types=[".mid", ".midi"], label="Input MIDI File", type="filepath").harp_required(True)
|
| 258 |
+
|
| 259 |
+
model_dropdown = gr.Dropdown(
|
| 260 |
+
choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
|
| 261 |
+
value=MEDIUM_MODEL,
|
| 262 |
+
label="Select AMT Model (Faster vs. Higher Quality)"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
with gr.Row():
|
| 266 |
+
history_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Select History Length (seconds)")
|
| 267 |
+
use_bars = gr.Checkbox(
|
| 268 |
+
value=False,
|
| 269 |
+
label="Use Musical Bars Instead of Seconds",
|
| 270 |
+
info="If enabled, context length is interpreted as bars based on the MIDI tempo."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def get_midi_tempo(midi_path):
|
| 274 |
+
try:
|
| 275 |
+
mid = MidiFile(midi_path)
|
| 276 |
+
for track in mid.tracks:
|
| 277 |
+
for msg in track:
|
| 278 |
+
if msg.type == "set_tempo":
|
| 279 |
+
return round(tempo2bpm(msg.tempo))
|
| 280 |
+
except Exception:
|
| 281 |
+
pass
|
| 282 |
+
return 120
|
| 283 |
+
|
| 284 |
+
BEATS_PER_BAR = 4
|
| 285 |
+
def bars_to_seconds(bars, bpm, beats_per_bar=BEATS_PER_BAR):
|
| 286 |
+
return bars * beats_per_bar * (60.0 / bpm)
|
| 287 |
+
|
| 288 |
+
def toggle_label_and_range(use_bars):
|
| 289 |
+
if use_bars:
|
| 290 |
+
return gr.update(label="Select History Length (bars)", minimum=2, maximum=8, step=2, value=4)
|
| 291 |
+
else:
|
| 292 |
+
return gr.update(label="Select History Length (seconds)", minimum=1, maximum=10, step=1, value=5)
|
| 293 |
+
|
| 294 |
+
def update_bar_label(history_value, midi_path, use_bars):
|
| 295 |
+
if not use_bars:
|
| 296 |
+
return gr.update(label="Select History Length (seconds)")
|
| 297 |
+
bpm = get_midi_tempo(midi_path)
|
| 298 |
+
approx_sec = bars_to_seconds(history_value, bpm)
|
| 299 |
+
return gr.update(label=f"Select History Length ({history_value} bars ≈ {approx_sec:.1f}s @ {bpm} BPM)")
|
| 300 |
+
|
| 301 |
+
use_bars.change(fn=toggle_label_and_range, inputs=use_bars, outputs=history_slider)
|
| 302 |
+
history_slider.change(fn=update_bar_label, inputs=[history_slider, input_midi, use_bars],
|
| 303 |
+
outputs=history_slider, queue=False)
|
| 304 |
+
|
| 305 |
+
output_labels = gr.JSON(label="Labels / Metadata")
|
| 306 |
+
output_midi = gr.File(file_types=[".mid", ".midi"], label="Generated MIDI Output", type="filepath")
|
| 307 |
+
|
| 308 |
+
_ = build_endpoint(
|
| 309 |
+
model_card=model_card,
|
| 310 |
+
input_components=[input_midi, model_dropdown, history_slider, use_bars],
|
| 311 |
+
output_components=[output_labels, output_midi],
|
| 312 |
+
process_fn=process_fn
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# ---------------------------------------------------------
|
| 316 |
+
# Launch App
|
| 317 |
+
# ---------------------------------------------------------
|
| 318 |
+
demo.launch(share=True, show_error=True, debug=True)
|