Spaces:
Sleeping
Sleeping
Commit ·
7deef83
1
Parent(s): 6438498
Upload 12 files
Browse files- .gitattributes +3 -0
- app.py +258 -0
- model_run.py +376 -0
- packages.txt +1 -0
- requirements.txt +4 -0
- sf2/.DS_Store +0 -0
- sf2/piano.sf2 +3 -0
- temp/.DS_Store +0 -0
- temp/output.mid +0 -0
- temp/output.wav +3 -0
- temp/output_fx.wav +3 -0
- utils.py +125 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
sf2/piano.sf2 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
temp/output_fx.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
temp/output.wav filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import nullcontext
|
| 2 |
+
from torch.nn import functional as F
|
| 3 |
+
from utils import TOKENIZER, Dataset
|
| 4 |
+
from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
|
| 5 |
+
from pedalboard.io import AudioFile
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import subprocess
|
| 8 |
+
import pretty_midi
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import time
|
| 11 |
+
import copy
|
| 12 |
+
import types
|
| 13 |
+
import torch
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# cuDNN autotuning / TF32 flags are no-ops on CPU but harmless; kept in case
# the Space is ever run on a CUDA device.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# True when running inside a Hugging Face Space (the platform sets SYSTEM=spaces).
in_space = os.getenv("SYSTEM") == "spaces"

# Model hyperparameters -- must match the 'model.pth' checkpoint loaded below.
n_layer = 8
n_embd = 768
ctx_len = 1536
top_k = 16

# model_run.py reads these environment variables at import time, so they must
# be set BEFORE `from model_run import RWKV_RNN` below.
os.environ['RWKV_FLOAT_MODE'] = 'fp32'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'
model_type = 'RWKV'

MODEL_NAME = 'model'  # RWKV_RNN loads MODEL_NAME + '.pth'
# One note event is 13 tokens (12 digits + '\n'); round the generation budget
# down to a whole number of events.
LENGTH_PER_TRIAL = round((2000) / 13) * 13
TEMPERATURE = 1.0

from model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER()

# Working directory for generated MIDI/WAV artifacts.
temp_dir = 'temp'
if not os.path.exists(temp_dir):
    os.makedirs(temp_dir)
|
| 43 |
+
|
| 44 |
+
def clear_midi(dir):
    """Delete every .mid file directly inside *dir* (non-recursive)."""
    stale = (name for name in os.listdir(dir) if name.endswith('.mid'))
    for name in stale:
        os.remove(os.path.join(dir, name))
|
| 48 |
+
|
| 49 |
+
# Remove stale MIDI files left over from previous runs.
clear_midi(temp_dir)


# Seed context: the all-zero 12-digit event plus newline. '000000000000'
# is a dedicated vocabulary token (id 11 in utils.stoi) marking the start
# of a piece, so the model is primed with exactly one start marker.
ctx_seed = "000000000000\n"
ctx = tokenizer.encode(ctx_seed)
src_len = len(ctx)
src_ctx = ctx.copy()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def generate_midi(LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, TEMPERATURE, top_k, tokenizer, ctx_seed, bpm):
    """Autoregressively sample up to LENGTH_PER_TRIAL tokens from the RWKV RNN,
    parse the digit stream into note events, and write them to temp/output.mid.

    Each event line is 12 digits: pitch[0:2] velocity[2:4] start[4:8] end[8:12]
    (tick values at 96 ticks/quarter), terminated by '\n'. Returns the path of
    the written MIDI file.
    """
    midi_seq = []

    # Loop kept from a multi-trial version; currently a single trial.
    for TRIAL in range(1):
        t_begin = time.time_ns()

        if TRIAL > 0:
            midi_seq.append("\n")

        ctx = src_ctx.copy()
        model.clear()
        midi_tokens = []  # sliding window of the last two sampled ids (stop detection)

        if TRIAL == 0:
            # Feed the seed prompt token by token and snapshot the RNN state
            # so later trials could resume from it without re-running the prompt.
            init_state = types.SimpleNamespace()
            for i in range(src_len):
                x = ctx[:i+1]
                if i == src_len - 1:
                    init_state.out = model.run(x)  # logits after the full seed
                else:
                    model.run(x)
            model.save(init_state)
        else:
            model.load(init_state)

        midi_seq.append(ctx_seed)

        for i in range(src_len, src_len + LENGTH_PER_TRIAL):
            x = ctx[:i+1]
            x = x[-ctx_len:]  # model.run only consumes x[-1]; kept for sampler repetition logic

            if i == src_len:
                # First step reuses the cached seed logits instead of re-running.
                out = copy.deepcopy(init_state.out)
            else:
                out = model.run(x)

            char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, top_k=top_k).item()
            midi_tokens.append(char)

            if len(midi_tokens) > 2:
                midi_tokens.pop(0)

            # [11, 10] == start-marker token followed by newline: the model is
            # beginning a new piece, so stop generating this one.
            if midi_tokens == [11, 10]: # stop token pattern
                break

            midi_seq.append(tokenizer.decode([int(char)]))

            if midi_tokens != [11, 10]:
                ctx += [char]

        t_end = time.time_ns()  # timing captured but currently unused

    # Decode the sampled text back into per-piece lists of note-event dicts.
    trim_seq = "".join(midi_seq)
    events = trim_seq.split("\n")

    midi_events = []
    sequence = []
    rndm_num = 895645  # file_name suffix for the first piece

    for event in events:
        if event.strip() == "":
            # Blank line ends the current piece; start a new one with a fresh id.
            midi_events.append(sequence)
            sequence = []
            rndm_num = random.randint(100000, 999999)
        try:
            pitch = int(event[0:2])
            velocity = int(event[2:4])
            start = int(event[4:8])
            end = int(event[8:12])
        except ValueError:
            # Malformed / short line (including the blank separator itself)
            # degrades to an all-zero event rather than aborting.
            pitch = 0
            velocity = 0
            start = 0
            end = 0

        sequence.append({'file_name': f'rwkv_{rndm_num}', 'pitch': pitch, 'velocity': velocity, 'start': start, 'end': end})

    if sequence:
        midi_events.append(sequence)

    midi_events = pd.DataFrame([pd.Series(event) for sequence in midi_events for event in sequence])
    midi_events = midi_events[['file_name', 'pitch', 'velocity', 'start', 'end']]
    midi_events = midi_events.sort_values(by=['file_name', 'start']).reset_index(drop=True)
    # Keep events inside 3072 ticks (= 8 bars at 96 ticks/quarter, 4/4).
    midi_events = midi_events[(midi_events['start'] < 3072) & (midi_events['end'] <= 3072)]

    for file_name, events in midi_events.groupby('file_name'):
        midi_obj = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
        instrument = pretty_midi.Instrument(0)  # program 0 = acoustic grand piano
        midi_obj.instruments.append(instrument)

        for _, event in events.iterrows():
            note = pretty_midi.Note(
                pitch=event['pitch'],
                velocity=event['velocity'],
                start=midi_obj.tick_to_time(event['start']),
                end=midi_obj.tick_to_time(event['end'])
            )
            instrument.notes.append(note)

        # NOTE(review): each group overwrites the same path, so only the last
        # piece survives; and if no events pass the filter, midi_path is never
        # bound and the return raises NameError -- confirm intended.
        midi_path = os.path.join(temp_dir, 'output.mid')
        midi_obj.write(midi_path)

    return midi_path
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def render_wav(midi_file, uploaded_sf2=None):
    """Render *midi_file* to temp/output.wav with the fluidsynth CLI.

    Uses *uploaded_sf2* when given, otherwise a random .sf2 from the local
    'sf2' directory.

    Returns the path of the rendered WAV file.
    Raises ValueError when no SoundFont is available, and RuntimeError when
    fluidsynth exits non-zero or produces no output file.
    """
    sf2_dir = 'sf2'
    audio_format = 's16'   # 16-bit signed PCM
    sample_rate = '44100'
    gain = '2.0'

    if uploaded_sf2:
        sf2_file = uploaded_sf2
    else:
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
        if not sf2_files:
            raise ValueError("No SoundFont (.sf2) file found in directory.")
        sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))

    print(f"Using SoundFont: {sf2_file}")
    output_wav = os.path.join(temp_dir, 'output.wav')

    command = [
        'fluidsynth', '-ni', sf2_file, midi_file, '-F', output_wav, '-r', str(sample_rate),
        '-o', f'audio.file.format={audio_format}', '-g', str(gain)
    ]
    # subprocess.DEVNULL replaces the manual open(os.devnull) -- no handle to manage.
    returncode = subprocess.call(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    # Previously the exit status was ignored, so a missing/failing fluidsynth
    # silently returned a path to a nonexistent or stale wav file.
    if returncode != 0 or not os.path.exists(output_wav):
        raise RuntimeError(f"fluidsynth failed (exit code {returncode}); no audio rendered.")

    return output_wav
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None):
    """Gradio callback: generate a melody, render it, apply effects.

    Returns (midi_path, fx_wav_path) for the File and Audio output widgets.
    """
    midi_events = generate_midi(
        LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, temperature, top_k,
        tokenizer, ctx_seed, bpm
    )

    midi_file = 'temp/output.mid'  # same path generate_midi writes (and returns)
    wav_raw = render_wav(midi_file, uploaded_sf2)
    wav_fx = os.path.join(temp_dir, 'output_fx.wav')

    # Single mastering chain; list-of-dicts form kept so more presets can be added.
    sfx_settings = [
        {
            'board': Pedalboard([
                Reverb(room_size=0.50, wet_level=0.40, dry_level=0.70, width=1.0),
                Compressor(threshold_db=-3.0, ratio=8.0, attack_ms=0.0, release_ms=300.0),
            ])
        }
    ]

    for setting in sfx_settings:
        board = setting['board']

        # Stream the wav through the effects chain one second at a time;
        # reset=False keeps reverb tails continuous across chunk boundaries.
        with AudioFile(wav_raw) as f:
            with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
                while f.tell() < f.frames:
                    chunk = f.read(int(f.samplerate))
                    effected = board(chunk, f.samplerate, reset=False)
                    o.write(effected)

    return midi_file, wav_fx
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Styling for the generate button (referenced via elem_id below).
custom_css = """
#generate-btn {
background-color: #6366f1 !important;
color: white !important;
border: none !important;
font-size: 16px;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
}
#generate-btn:hover {
background-color: #4f51c5 !important;
}
"""

# Two-column layout: generation controls on the left, outputs on the right.
with gr.Blocks(css=custom_css, theme="soft") as iface:
    gr.Markdown("<h1 style='font-weight: bold; text-align: center;'>Pop-K</h1>")
    gr.Markdown("<p style='text-align:center;'>Pop-K is a small RWKV model that generates pop melodies in C major and A minor.</p>")

    with gr.Row():
        with gr.Column(scale=1):
            bpm = gr.Slider(minimum=50, maximum=200, step=1, value=120, label="BPM")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=1.0, label="Temperature")
            top_k = gr.Slider(minimum=1, maximum=32, step=1, value=16, label="Top-K")

        with gr.Column(scale=1):
            midi_file = gr.File(label="MIDI File Output")
            audio_file = gr.Audio(label="Generated Audio Output", type="filepath")
            generate_button = gr.Button("Generate", elem_id="generate-btn")

    # uploaded_sf2 is not wired to an input, so the bundled SoundFont is used.
    generate_button.click(
        fn=generate_and_return_files,
        inputs=[bpm, temperature, top_k],
        outputs=[midi_file, audio_file]
    )

iface.launch(share=True)
|
model_run.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
import copy
|
| 3 |
+
import torch
|
| 4 |
+
import math, os
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
# Extra "head QK" attention dimension (0 disables it); must match training.
RWKV_HEAD_QK_DIM = 1536
DEBUG_TIME = False  # when True, print time-mixing parameters at load

# The fused WKV CUDA kernel is only compiled and exposed on GPU runs; on CPU
# the RWKV_RNN class below does the recurrence in pure PyTorch instead.
# NOTE(review): nesting of WKV/RUN_CUDA under this guard follows the upstream
# RWKV-LM layout -- confirm against the original file.
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
    T_MAX = 1536  # kernel's compile-time max sequence length

    from torch.utils.cpp_extension import load
    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

    class WKV(torch.autograd.Function):
        """Autograd wrapper around the fused WKV forward/backward CUDA kernels.

        The kernel computes in fp32, so fp16/bf16 inputs are up-cast on entry
        and gradients/outputs are cast back per RWKV_FLOAT_MODE.
        """
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                # Kernel expects the decay as -exp(w) (always-negative log decay).
                w = -torch.exp(w.contiguous())
                u = u.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            else:
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return y
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return y.half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return y.bfloat16()

        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda').contiguous()
            gu = torch.zeros((B, C), device='cuda').contiguous()
            gk = torch.zeros((B, T, C), device='cuda').contiguous()
            gv = torch.zeros((B, T, C), device='cuda').contiguous()
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
            else:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            # w and u are shared across the batch: sum their per-sample grads.
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return (None, None, None, gw, gu, gk, gv)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

    def RUN_CUDA(B, T, C, w, u, k, v):
        """Invoke the fused WKV kernel, moving parameters to the GPU."""
        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


# Module-level config namespace populated by RWKV_GPT.__init__ and read by the
# layer classes below.
RWKV_CFG = types.SimpleNamespace()
|
| 76 |
+
|
| 77 |
+
class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mixing (feed-forward) block.

    Mixes each position with its time-shifted predecessor, then applies a
    squared-ReLU MLP gated by a sigmoid "receptance" signal. Dimensions come
    from the module-level RWKV_CFG namespace.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # Pads one step at the front and drops the last step: each position
        # sees the previous token's features.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        shifted = self.time_shift(x)
        mixed_k = x * self.time_mix_k + shifted * (1 - self.time_mix_k)
        mixed_r = x * self.time_mix_r + shifted * (1 - self.time_mix_r)

        gate = torch.sigmoid(self.receptance(mixed_r))
        hidden = torch.square(torch.relu(self.key(mixed_k)))
        return gate * self.value(hidden)
|
| 102 |
+
|
| 103 |
+
class RWKV_TimeMix(nn.Module):
    """RWKV time-mixing ("attention") block for batched GPT-style training.

    Projects time-shift-mixed inputs to r/k/v and runs the WKV recurrence via
    the fused CUDA kernel, so this class is only usable when RWKV_RUN_DEVICE
    is 'cuda' (RUN_CUDA is undefined otherwise).
    """
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # Per-channel decay (stored as log; kernel applies -exp) and
        # first-token bonus, as in the RWKV formulation.
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        # Blend each position with the previous token's features.
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        # Sigmoid receptance gates the WKV-weighted values.
        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv
|
| 134 |
+
|
| 135 |
+
class Block(nn.Module):
    """One RWKV layer: time-mixing attention + channel-mixing FFN, each behind
    a pre-LayerNorm residual. Layer 0 additionally normalizes the raw embedding
    (ln0), and in the 'RWKV-ffnPre' variant replaces its attention with an FFN.
    """
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            # +1000 gives the pre-FFN a distinct layer_id (used upstream for init).
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
|
| 161 |
+
|
| 162 |
+
class RWKV_GPT(nn.Module):
    """Batched (GPT-mode) RWKV model: embedding -> n_layer Blocks -> LayerNorm
    -> vocab head, with an optional tiny "head QK" copy-attention mechanism.

    Loads weights from MODEL_NAME + '.pth' and writes its hyperparameters into
    the module-level RWKV_CFG so the layer classes can read them.
    """
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            # Small q/k projections plus a causal mask implementing a
            # token-copy bias added to the vocabulary logits.
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()  # redundant second call kept as in the original

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            # Causal copy-attention over the input tokens themselves.
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # one_hot scatters the attention weight onto each source token's id,
            # cast to match the active float mode.
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class RWKV_RNN():
    """Recurrent (one-token-at-a-time) RWKV inference in pure PyTorch.

    Loads the checkpoint into a nested SimpleNamespace mirror of the state
    dict, and keeps per-layer recurrent state in dicts keyed by layer name.
    Not an nn.Module: weights are plain tensors, no autograd.
    """
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        # Nested namespace holding the weights, e.g. self.w.blocks[0].att.key.weight
        self.w = types.SimpleNamespace()

        #w = torch.load(MODEL_NAME + '.pth',map_location=torch.device(RUN_DEVICE))
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE), weights_only=True)
        for x in w.keys():
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()  # drop the (1,1,C) training shape
            if '.time_decay' in x:
                w[x] = -torch.exp(w[x])  # pre-apply the -exp used by the recurrence
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # Walk 'a.b.0.c' style keys, creating namespaces (or dicts for
            # numeric segments like block indices) along the way.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Reset all recurrent state (start of a new sequence)."""
        self.xx = {}           # previous token's features, per layer name
        self.aa = {}           # WKV numerator accumulator
        self.bb = {}           # WKV denominator accumulator
        self.pp = {}           # running max exponent (log-space stabilizer)
        self.hk = None         # history of head_k vectors for copy-attention

    def save(self, target):
        """Deep-copy the recurrent state onto *target* (any attr container)."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        """Restore recurrent state previously captured with save()."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        """LayerNorm with the given weight namespace."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """Channel-mixing step; *name* keys this layer's recurrent state."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """Time-mixing step: one token of the WKV recurrence, computed in
        log-space (pp tracks the max exponent) for numerical stability."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        # Output for the current token: include the time_first bonus for k.
        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        # State update for the next token: decay the accumulators instead.
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Feed the last token of *ctx* through the network.

        Returns the vocabulary logits as a plain Python list. When
        RWKV_HEAD_QK_DIM > 0, a copy-attention bonus over the ctx tokens is
        added using the accumulated head_k history (self.hk).
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]  # recurrence only consumes the newest token

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            # Append this token's head_k vector; truncate history to ctx_len.
            if self.hk == None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            # Add copy bias: each past position's similarity boosts its token id.
            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
fluidsynth
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pretty_midi==0.2.10
|
| 2 |
+
pedalboard==0.9.3
|
| 3 |
+
torch
|
| 4 |
+
gradio
|
sf2/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
sf2/piano.sf2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39568d475db895ab5e372dfbb1611d90b4a267306595dd7d619e99c0816ae1f9
|
| 3 |
+
size 74921906
|
temp/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
temp/output.mid
ADDED
|
Binary file (256 Bytes). View file
|
|
|
temp/output.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae0bc1d9d63ad001452ef8414d0c91c3248867d126992d8f48c4276fb5cc0c36
|
| 3 |
+
size 2823724
|
temp/output_fx.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60bb94895397db75ff8913f3e0aad7da1d7f917b18db71c3db19609944b8096d
|
| 3 |
+
size 2823784
|
utils.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.nn import functional as F
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# 12-token vocabulary: the ten digits, newline (event separator), and the
# 12-zero piece-start marker. stoi/itos are inverse mappings.
stoi = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '\n': 10, '000000000000': 11}
itos = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '\n', 11: '000000000000'}

# Longest-match-first tokenizer: the start marker, else one digit, else newline.
tok_chars = re.compile(r'000000000000|\d{1}|\n')
|
| 13 |
+
|
| 14 |
+
def encode(text, stoi, tokenizer):
    """Tokenize *text* with the compiled regex *tokenizer* and map each match
    to its id via *stoi*; tokens absent from *stoi* are silently dropped."""
    return [stoi[token] for token in tokenizer.findall(text) if token in stoi]
|
| 17 |
+
|
| 18 |
+
def decode(encoded, itos):
    """Map each token id in *encoded* back to its string via *itos* and
    concatenate the result."""
    return ''.join(itos[token_id] for token_id in encoded)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Dataset:
    """Training dataset over the tokenized MIDI-event text.

    Samples a random window starting at a piece-start marker and optionally
    applies a pitch/time-shift data augmentation with manual decimal carry.
    NOTE(review): this shadows the torch.utils.data.Dataset imported above.
    """
    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed  # samples reported per epoch
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')

    def __len__(self):
        # Fixed epoch length, independent of data size (samples are random anyway).
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # idx is ignored: pick a random position, then scan forward (wrapping)
        # to the next piece-start marker.
        cues = []
        idx_randm = random.randint(0, len(self.data) - (self.ctx_len) * 4)
        i = idx_randm

        while True:
            if self.data[i] == self.stoi[self.start_token]:
                cues = [i]
                break
            else:
                i = (i + 1) % len(self.data)

        # Unreachable in practice (the loop only exits after filling cues);
        # kept as a guard.
        if not cues:
            return None

        start_idx = cues[0]
        dix = self.data[start_idx : start_idx + self.ctx_len + 2]

        # 96 tick resolution
        # Per-position digit offsets, repeated to cover the window; the 13-slot
        # pattern lines up with one 12-digit event + newline.
        time_shift = [
            [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
            [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
        ]

        data_aug = random.choice([True, False])

        t = dix[2:2 + self.ctx_len] # testing

        if data_aug:
            ts_rndm = random.choice(time_shift)
            ts = ts_rndm * ((self.ctx_len - 1) // len(ts_rndm) + 1)
            tsx = torch.tensor(ts[:self.ctx_len])

            # Add the offsets digit by digit, propagating decimal carries to
            # the left; processed in reverse so carries are seen immediately.
            for j in reversed(range(len(t))):
                if j % 13 not in range(2, 12):
                    continue  # only positions 2-11 of each event get shifted

                aug_int = t[j] + tsx[j]
                if aug_int >= 10 and (aug_int not in [10, 11] or j not in [9, 10]):
                    left_int = aug_int // 10
                    right_int = aug_int % 10
                    if j > 0:
                        t[j - 1] += left_int
                    t[j] = right_int
                else:
                    t[j] = aug_int

            x = t
            y = t[1:] + [t[-1]]  # next-token targets; last target repeats
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]

        x = torch.tensor(x, dtype=torch.int64)
        y = torch.tensor(y, dtype=torch.int64)

        return x, y
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class TOKENIZER():
    """Tokenizer over the 12-symbol MIDI-event vocabulary defined at module
    level (digits, newline, and the 12-zero start marker)."""

    def __init__(self):
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Regex-tokenize *text* and map matches to ids, dropping unknowns."""
        return [self.stoi[token] for token in self.tokenizer.findall(text) if token in self.stoi]

    def decode(self, encoded):
        """Concatenate the strings for the given token ids."""
        return ''.join(self.itos[token_id] for token_id in encoded)

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        """Sample one token id from the logits *out*.

        Applies softmax, zeroes everything outside the top_k probabilities,
        optionally sharpens/flattens with *temperature*, then draws one sample
        (torch.multinomial renormalizes, so no explicit normalization needed).
        *x* and *ctx_len* are accepted for interface compatibility but unused.
        """
        probs = F.softmax(torch.tensor(out), dim=-1)

        if top_k > 0:
            keep = min(top_k, probs.size(-1))
            kept_probs, kept_idx = torch.topk(probs, keep)
            probs.fill_(0)
            probs.scatter_(dim=-1, index=kept_idx, src=kept_probs)

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
|
| 125 |
+
|