saumya-pailwan committed
Commit 27ce3bd · verified · 1 Parent(s): 1e25ac3

Update app.py

Files changed (1)
  1. app.py +315 -4
app.py CHANGED
@@ -1,7 +1,318 @@
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import os
+import spaces  # Enables ZeroGPU on Hugging Face
 import gradio as gr
+import torch
+from dataclasses import asdict
+from mido import MidiFile, tempo2bpm
 
+from transformers import AutoModelForCausalLM
+from anticipation.sample import generate
+from anticipation.convert import events_to_midi, midi_to_events
+from anticipation.tokenize import extract_instruments
+from anticipation import ops
 
+from pyharp.core import ModelCard, build_endpoint
+from pyharp.labels import LabelList
+
+# ---------------------------------------------------------
+# Model Choices
+# ---------------------------------------------------------
+SMALL_MODEL = "stanford-crfm/music-small-800k"
+MEDIUM_MODEL = "stanford-crfm/music-medium-800k"
+LARGE_MODEL = "stanford-crfm/music-large-800k"
+
+# ---------------------------------------------------------
+# Model Card (for HARP)
+# ---------------------------------------------------------
+model_card = ModelCard(
+    name="Anticipatory Music Transformer",
+    description=(
+        "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
+        "Input: a MIDI file with a short accompaniment followed by a melody line. "
+        "Output: a new MIDI file with extended accompaniment matching the melody. "
+        "Use the sliders to choose model size and how much of the song is used as context."
+    ),
+    author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
+    tags=["midi", "generation", "accompaniment"]
+)
+
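+# Note: these are the public Stanford CRFM AMT checkpoints; per the upstream
+# anticipation project, the "800k" suffix denotes training steps, and the three
+# sizes trade inference speed against output quality.
+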
+# ---------------------------------------------------------
+# Model Cache Loader
+# ---------------------------------------------------------
+_model_cache = {}
+
+def load_amt_model(model_choice: str):
+    """Loads and caches the AMT model inside the worker process."""
+    if model_choice in _model_cache:
+        return _model_cache[model_choice]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    print(f"Loading {model_choice} ...")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_choice,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        low_cpu_mem_usage=True
+    ).to(device)
+
+    _model_cache[model_choice] = model
+    return model
+
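+# Note: float16 roughly halves GPU memory for the checkpoint; on CPU the model
+# is kept in float32, since half-precision inference is poorly supported there.
+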
+# ---------------------------------------------------------
+# Melody Detection (Auto Program)
+# ---------------------------------------------------------
+def find_melody_program(mid, debug=False):
+    """Detect the melody track's program number using pitch, density, and duration heuristics."""
+    track_stats = []
+    total_duration = 0
+
+    for i, track in enumerate(mid.tracks):
+        pitches, times = [], []
+        current_time = 0
+        current_program = None
+        track_note_count = 0
+
+        for msg in track:
+            # Accumulate delta time for every message (not only the ones we
+            # inspect) so note timestamps stay correct.
+            current_time += msg.time
+
+            if msg.type == "program_change":
+                current_program = msg.program
+                continue
+            if msg.type != "note_on":
+                continue
+
+            if msg.velocity > 0:
+                pitches.append(msg.note)
+                times.append(current_time)
+                track_note_count += 1
+                if track_note_count >= 100:
+                    break
+
+        if not pitches:
+            continue
+
+        track_duration = max(times) - min(times)
+        total_duration = max(total_duration, current_time)
+
+        mean_pitch = sum(pitches) / len(pitches)
+        polyphony = len(set(pitches)) / len(pitches)
+        coverage = track_duration / total_duration if total_duration > 0 else 0
+
+        track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))
+
+    if not track_stats:
+        return None, False
+
+    if len(track_stats) == 1:
+        prog = track_stats[0][3]
+        if debug:
+            print(f"Single-track MIDI detected — using program {prog if prog is not None else 'None'}")
+        return prog, prog is not None
+
+    candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
+    has_valid_programs = len(candidates) > 0
+    if not candidates:
+        candidates = track_stats
+
+    max_notes = max(t[2] for t in candidates)
+    max_pitch = max(t[1] for t in candidates)
+    min_pitch = min(t[1] for t in candidates)
+    pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
+
+    best_score = -1
+    best_program = None
+    best_track = None
+    best_pitch = None
+
+    for t in candidates:
+        idx, pitch, notes, prog, poly, coverage = t
+        pitch_norm = (pitch - min_pitch) / pitch_span
+        notes_norm = notes / max_notes
+        score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
+
+        if poly < 0.15:
+            score *= 0.95
+        if 55 <= pitch <= 75:
+            score *= 1.1
+        if notes >= 30:
+            score *= 1.05
+        if coverage > 0.7:
+            score *= 1.15
+
+        if score > best_score:
+            best_score = score
+            best_program = prog
+            best_track = idx
+            best_pitch = pitch
+
+    return best_program, has_valid_programs
+
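+# The score above favors tracks that are higher-pitched, denser, and cover more
+# of the piece, with bonuses for mid-register pitch (55–75), at least 30 notes,
+# and >70% coverage, plus a slight penalty for very low pitch variety.
+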
+def auto_extract_melody(mid, debug=False):
+    """Extract melody events from a MIDI object (optimized for direct input)."""
+    events = midi_to_events(mid)
+    melody_program, has_valid_program = find_melody_program(mid, debug=debug)
+
+    if not has_valid_program or melody_program is None or melody_program == 0:
+        if debug:
+            print("No valid program changes found; using all events as melody.")
+        return events, events
+
+    events, melody = extract_instruments(events, [melody_program])
+
+    if len(melody) == 0:
+        if debug:
+            print("No melody events found for program — reverting to all events.")
+        return events, events
+
+    if debug:
+        print(f"Extracted {len(melody)} melody events from program {melody_program}")
+
+    return events, melody
+
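+# auto_extract_melody always returns a usable pair: (remaining_events, melody)
+# on success, or (events, events) when no melody program can be isolated.
+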
+# ---------------------------------------------------------
+# Core Generation Logic
+# ---------------------------------------------------------
+@spaces.GPU
+def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
+    """Generate accompaniment conditioned on context history and melody."""
+    model = load_amt_model(model_choice)
+
+    mid = MidiFile(midi_path)
+    print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
+
+    all_events, melody = auto_extract_melody(mid, debug=True)
+    if len(melody) == 0:
+        melody = all_events
+
+    mid_time = mid.length or 0
+    ops_time = ops.max_time(all_events, seconds=True)
+    total_time = round(max(mid_time, ops_time))
+
+    # Events before history_length form the prompt; melody events after it are
+    # the anticipated controls the model must accompany.
+    melody_history = ops.clip(all_events, 0, history_length, clip_duration=False)
+    melody_future = ops.clip(melody, history_length, total_time, clip_duration=False)
+
+    print(f"Generating from {history_length:.2f}s → {total_time:.2f}s "
+          f"(duration = {total_time - history_length:.2f}s)")
+
+    accompaniment = generate(
+        model,
+        start_time=history_length,
+        end_time=total_time,
+        inputs=melody_history,
+        controls=melody_future,
+        top_p=0.95,
+        debug=False
+    )
+
+    output_events = ops.clip(
+        ops.combine(accompaniment, melody),
+        0,
+        total_time,
+        clip_duration=True
+    )
+
+    output_midi = "generated_accompaniment_huggingface.mid"
+    events_to_midi(output_events).save(output_midi)
+    return output_midi, None
+
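+# A minimal sketch of calling the generation step directly, outside HARP
+# (the file name "vamp.mid" is illustrative, not part of the app):
+#
+#   out_path, err = generate_accompaniment("vamp.mid", MEDIUM_MODEL, 5.0)
+#   if err is None:
+#       print(f"Accompaniment written to {out_path}")
+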
+# ---------------------------------------------------------
+# HARP process_fn — with tempo-aware bar→seconds conversion
+# ---------------------------------------------------------
+def process_fn(input_midi_path, model_choice, history_length, use_bars):
+    """Convert bars to seconds (tempo-aware) before generation."""
+    if use_bars:
+        BEATS_PER_BAR = 4
+        bpm = 120  # fallback when the file carries no tempo event
+        try:
+            mid = MidiFile(input_midi_path)
+            # Take the first set_tempo event anywhere in the file.
+            tempo_msg = next(
+                (msg for tr in mid.tracks for msg in tr if msg.type == "set_tempo"),
+                None
+            )
+            if tempo_msg is not None:
+                bpm = round(tempo2bpm(tempo_msg.tempo))
+        except Exception:
+            pass
+        seconds_per_bar = (60.0 / bpm) * BEATS_PER_BAR
+        history_length = history_length * seconds_per_bar
+        print(f"[INFO] Converted to {history_length:.2f}s from bars @ {bpm} BPM")
+
+    output_midi, error_message = generate_accompaniment(
+        input_midi_path,
+        model_choice,
+        float(history_length)
+    )
+
+    if error_message:
+        return {"message": error_message}, None
+
+    return asdict(LabelList()), output_midi
+
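+# Worked example of the bar→seconds conversion above: in 4/4 at 120 BPM,
+# seconds_per_bar = (60 / 120) * 4 = 2.0, so a 4-bar history is 8.0 seconds.
+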
+# ---------------------------------------------------------
+# Gradio + HARP UI
+# ---------------------------------------------------------
+with gr.Blocks() as demo:
+    gr.Markdown("## 🎼 Anticipatory Music Transformer")
+
+    input_midi = gr.File(file_types=[".mid", ".midi"], label="Input MIDI File", type="filepath").harp_required(True)
+
+    model_dropdown = gr.Dropdown(
+        choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
+        value=MEDIUM_MODEL,
+        label="Select AMT Model (Faster vs. Higher Quality)"
+    )
+
+    with gr.Row():
+        history_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Select History Length (seconds)")
+        use_bars = gr.Checkbox(
+            value=False,
+            label="Use Musical Bars Instead of Seconds",
+            info="If enabled, context length is interpreted as bars based on the MIDI tempo."
+        )
+
+    def get_midi_tempo(midi_path):
+        try:
+            mid = MidiFile(midi_path)
+            for track in mid.tracks:
+                for msg in track:
+                    if msg.type == "set_tempo":
+                        return round(tempo2bpm(msg.tempo))
+        except Exception:
+            pass
+        return 120
+
+    BEATS_PER_BAR = 4
+    def bars_to_seconds(bars, bpm, beats_per_bar=BEATS_PER_BAR):
+        return bars * beats_per_bar * (60.0 / bpm)
+
+    def toggle_label_and_range(use_bars):
+        if use_bars:
+            return gr.update(label="Select History Length (bars)", minimum=2, maximum=8, step=2, value=4)
+        else:
+            return gr.update(label="Select History Length (seconds)", minimum=1, maximum=10, step=1, value=5)
+
+    def update_bar_label(history_value, midi_path, use_bars):
+        if not use_bars:
+            return gr.update(label="Select History Length (seconds)")
+        bpm = get_midi_tempo(midi_path)
+        approx_sec = bars_to_seconds(history_value, bpm)
+        return gr.update(label=f"Select History Length ({history_value} bars ≈ {approx_sec:.1f}s @ {bpm} BPM)")
+
+    use_bars.change(fn=toggle_label_and_range, inputs=use_bars, outputs=history_slider)
+    history_slider.change(fn=update_bar_label, inputs=[history_slider, input_midi, use_bars],
+                          outputs=history_slider, queue=False)
+
+    output_labels = gr.JSON(label="Labels / Metadata")
+    output_midi = gr.File(file_types=[".mid", ".midi"], label="Generated MIDI Output", type="filepath")
+
+    _ = build_endpoint(
+        model_card=model_card,
+        input_components=[input_midi, model_dropdown, history_slider, use_bars],
+        output_components=[output_labels, output_midi],
+        process_fn=process_fn
+    )
+
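+# build_endpoint wires the model card, components, and process_fn into a
+# HARP-compatible endpoint, so the Space can be driven from HARP hosts (e.g.
+# the DAW plugin) as well as from this Gradio UI.
+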
+# ---------------------------------------------------------
+# Launch App
+# ---------------------------------------------------------
+demo.launch(share=True, show_error=True, debug=True)