patchbanks committed on
Commit
7deef83
·
1 Parent(s): 6438498

Upload 12 files

Browse files
Files changed (12) hide show
  1. .gitattributes +3 -0
  2. app.py +258 -0
  3. model_run.py +376 -0
  4. packages.txt +1 -0
  5. requirements.txt +4 -0
  6. sf2/.DS_Store +0 -0
  7. sf2/piano.sf2 +3 -0
  8. temp/.DS_Store +0 -0
  9. temp/output.mid +0 -0
  10. temp/output.wav +3 -0
  11. temp/output_fx.wav +3 -0
  12. utils.py +125 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sf2/piano.sf2 filter=lfs diff=lfs merge=lfs -text
37
+ temp/output_fx.wav filter=lfs diff=lfs merge=lfs -text
38
+ temp/output.wav filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+ from torch.nn import functional as F
3
+ from utils import TOKENIZER, Dataset
4
+ from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
5
+ from pedalboard.io import AudioFile
6
+ import pandas as pd
7
+ import subprocess
8
+ import pretty_midi
9
+ import gradio as gr
10
+ import time
11
+ import copy
12
+ import types
13
+ import torch
14
+ import random
15
+ import os
16
+
17
# Enable cuDNN autotuning and TF32 matmuls (no effect on the CPU path used here,
# but harmless).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# True when running inside a Hugging Face Space (SYSTEM env var set by the platform).
in_space = os.getenv("SYSTEM") == "spaces"

# Model hyperparameters; must match the checkpoint loaded below.
n_layer = 8
n_embd = 768
ctx_len = 1536
top_k = 16  # default sampling top-k (also exposed as a UI slider below)

# model_run.py reads these env vars at import time, so they must be set first.
os.environ['RWKV_FLOAT_MODE'] = 'fp32'
os.environ['RWKV_RUN_DEVICE'] = 'cpu'
model_type = 'RWKV'

MODEL_NAME = 'model'  # loads ./model.pth
# Tokens to sample per generation, rounded to a multiple of 13
# (one event line = 12 digits + '\n' = 13 tokens).
LENGTH_PER_TRIAL = round((2000) / 13) * 13
TEMPERATURE = 1.0

# Imported here rather than at the top so the env vars above are already set
# when model_run evaluates them at import time.
from model_run import RWKV_RNN
model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len)
tokenizer = TOKENIZER()

# Working directory for generated .mid/.wav artifacts.
temp_dir = 'temp'
if not os.path.exists(temp_dir):
    os.makedirs(temp_dir)
43
+
44
def clear_midi(dir):
    """Delete every .mid file directly inside *dir* (non-recursive)."""
    stale = [name for name in os.listdir(dir) if name.endswith('.mid')]
    for name in stale:
        os.remove(os.path.join(dir, name))
48
+
49
# Remove stale MIDI output left over from previous runs.
clear_midi(temp_dir)


# Seed context: the sequence-start marker token followed by a newline
# (encodes to [11, 10] under the utils vocabulary).
ctx_seed = "000000000000\n"
ctx = tokenizer.encode(ctx_seed)
src_len = len(ctx)
src_ctx = ctx.copy()
56
+
57
+
58
def generate_midi(LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, TEMPERATURE, top_k, tokenizer, ctx_seed, bpm):
    """Sample an event sequence from the RWKV model and write it to temp/output.mid.

    Each generated event line is 12 digits — 2 pitch, 2 velocity, 4 start tick,
    4 end tick — separated by newlines. Returns the path of the written MIDI file.
    """
    midi_seq = []

    for TRIAL in range(1):  # single trial; loop kept for easy multi-trial use
        t_begin = time.time_ns()

        if TRIAL > 0:
            midi_seq.append("\n")

        ctx = src_ctx.copy()
        model.clear()
        midi_tokens = []  # sliding window of the last two sampled token ids

        if TRIAL == 0:
            # Prime the RNN on the seed tokens and snapshot its state so later
            # trials can restore it instead of re-running the seed.
            init_state = types.SimpleNamespace()
            for i in range(src_len):
                x = ctx[:i+1]
                if i == src_len - 1:
                    init_state.out = model.run(x)
                else:
                    model.run(x)
            model.save(init_state)
        else:
            model.load(init_state)

        midi_seq.append(ctx_seed)

        for i in range(src_len, src_len + LENGTH_PER_TRIAL):
            x = ctx[:i+1]
            x = x[-ctx_len:]

            if i == src_len:
                out = copy.deepcopy(init_state.out)
            else:
                out = model.run(x)

            char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, top_k=top_k).item()
            midi_tokens.append(char)

            if len(midi_tokens) > 2:
                midi_tokens.pop(0)

            if midi_tokens == [11, 10]:  # stop token pattern: start marker then newline
                break

            midi_seq.append(tokenizer.decode([int(char)]))

            if midi_tokens != [11, 10]:
                ctx += [char]

        t_end = time.time_ns()

    trim_seq = "".join(midi_seq)
    events = trim_seq.split("\n")

    midi_events = []
    sequence = []
    rndm_num = 895645

    for event in events:
        if event.strip() == "":
            # Blank line = segment boundary: flush current segment, start a new
            # pseudo-file with a fresh random name.
            midi_events.append(sequence)
            sequence = []
            rndm_num = random.randint(100000, 999999)
            # FIX: skip the separator itself. Previously it fell through into the
            # parser, hit ValueError on int(''), and injected a spurious all-zero
            # note at the head of every segment (and, for the trailing newline, a
            # whole spurious one-note group that could overwrite the real output).
            continue

        # Event layout: PP VV SSSS EEEE (pitch, velocity, start tick, end tick).
        try:
            pitch = int(event[0:2])
            velocity = int(event[2:4])
            start = int(event[4:8])
            end = int(event[8:12])
        except ValueError:
            pitch = 0
            velocity = 0
            start = 0
            end = 0

        sequence.append({'file_name': f'rwkv_{rndm_num}', 'pitch': pitch, 'velocity': velocity, 'start': start, 'end': end})

    if sequence:
        midi_events.append(sequence)

    midi_events = pd.DataFrame([pd.Series(event) for sequence in midi_events for event in sequence])
    midi_events = midi_events[['file_name', 'pitch', 'velocity', 'start', 'end']]
    midi_events = midi_events.sort_values(by=['file_name', 'start']).reset_index(drop=True)
    # Drop events that run past 3072 ticks (32 bars at 96 ticks/quarter).
    midi_events = midi_events[(midi_events['start'] < 3072) & (midi_events['end'] <= 3072)]

    # Every group writes to the same path (the last group wins); hoisting the
    # path also keeps the return defined when no group survives the filter.
    midi_path = os.path.join(temp_dir, 'output.mid')

    for file_name, file_events in midi_events.groupby('file_name'):
        midi_obj = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
        instrument = pretty_midi.Instrument(0)
        midi_obj.instruments.append(instrument)

        for _, event in file_events.iterrows():
            note = pretty_midi.Note(
                pitch=event['pitch'],
                velocity=event['velocity'],
                start=midi_obj.tick_to_time(event['start']),
                end=midi_obj.tick_to_time(event['end'])
            )
            instrument.notes.append(note)

        midi_obj.write(midi_path)

    return midi_path
161
+
162
+
163
def render_wav(midi_file, uploaded_sf2=None):
    """Render *midi_file* to temp/output.wav with the external FluidSynth binary.

    Uses *uploaded_sf2* when given; otherwise picks a random .sf2 from the sf2/
    directory. Returns the rendered wav path. Raises ValueError when no
    SoundFont file is available.
    """
    sf2_dir = 'sf2'
    audio_format = 's16'   # 16-bit PCM output
    sample_rate = '44100'
    gain = '2.0'

    if uploaded_sf2:
        sf2_file = uploaded_sf2
    else:
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
        if not sf2_files:
            raise ValueError("No SoundFont (.sf2) file found in directory.")
        sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))

    print(f"Using SoundFont: {sf2_file}")
    output_wav = os.path.join(temp_dir, 'output.wav')

    command = [
        'fluidsynth', '-ni', sf2_file, midi_file, '-F', output_wav, '-r', str(sample_rate),
        '-o', f'audio.file.format={audio_format}', '-g', str(gain)
    ]
    # subprocess.DEVNULL replaces the manually-opened os.devnull file handle;
    # like the original subprocess.call, the exit code is not checked.
    subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    return output_wav
188
+
189
+
190
def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None):
    """Gradio callback: generate a melody, render it, apply FX, return file paths.

    Returns (midi_path, fx_wav_path) for the File and Audio output components.
    """
    # FIX: use the path returned by generate_midi instead of discarding it and
    # re-hardcoding 'temp/output.mid'.
    midi_file = generate_midi(
        LENGTH_PER_TRIAL, src_ctx, model, src_len, ctx_len, temperature, top_k,
        tokenizer, ctx_seed, bpm
    )

    wav_raw = render_wav(midi_file, uploaded_sf2)
    wav_fx = os.path.join(temp_dir, 'output_fx.wav')

    # Single FX preset; kept as a list so more chains can be added later.
    sfx_settings = [
        {
            'board': Pedalboard([
                Reverb(room_size=0.50, wet_level=0.40, dry_level=0.70, width=1.0),
                Compressor(threshold_db=-3.0, ratio=8.0, attack_ms=0.0, release_ms=300.0),
            ])
        }
    ]

    for setting in sfx_settings:
        board = setting['board']

        # Stream the raw wav through the pedalboard one second at a time;
        # reset=False keeps effect tails continuous across chunks.
        with AudioFile(wav_raw) as f:
            with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
                while f.tell() < f.frames:
                    chunk = f.read(int(f.samplerate))
                    effected = board(chunk, f.samplerate, reset=False)
                    o.write(effected)

    return midi_file, wav_fx
220
+
221
+
222
# Styling for the generate button (referenced by elem_id below).
custom_css = """
#generate-btn {
    background-color: #6366f1 !important;
    color: white !important;
    border: none !important;
    font-size: 16px;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
}
#generate-btn:hover {
    background-color: #4f51c5 !important;
}
"""

# Two-column UI: sampling controls on the left, outputs + button on the right.
with gr.Blocks(css=custom_css, theme="soft") as iface:
    gr.Markdown("<h1 style='font-weight: bold; text-align: center;'>Pop-K</h1>")
    gr.Markdown("<p style='text-align:center;'>Pop-K is a small RWKV model that generates pop melodies in C major and A minor.</p>")

    with gr.Row():
        with gr.Column(scale=1):
            bpm = gr.Slider(minimum=50, maximum=200, step=1, value=120, label="BPM")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=1.0, label="Temperature")
            # NOTE: this slider shadows the module-level top_k default.
            top_k = gr.Slider(minimum=1, maximum=32, step=1, value=16, label="Top-K")

        with gr.Column(scale=1):
            midi_file = gr.File(label="MIDI File Output")
            audio_file = gr.Audio(label="Generated Audio Output", type="filepath")
            generate_button = gr.Button("Generate", elem_id="generate-btn")

    # uploaded_sf2 is not wired to an input, so the callback's default (None)
    # always applies and a random bundled SoundFont is used.
    generate_button.click(
        fn=generate_and_return_files,
        inputs=[bpm, temperature, top_k],
        outputs=[midi_file, audio_file]
    )

iface.launch(share=True)
model_run.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import copy
3
+ import torch
4
+ import math, os
5
+ from torch.nn import functional as F
6
+ import torch.nn as nn
7
+
8
# Dimension of the extra head-QK "copy" attention; 0 disables it entirely.
RWKV_HEAD_QK_DIM = 1536
DEBUG_TIME = False  # when True, print time_* weights as they are loaded

# The fused WKV CUDA kernel is only compiled when running on GPU; the CPU
# path (RWKV_RNN below) never touches RUN_CUDA.
if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
    T_MAX = 1536  # kernel compile-time max sequence length (Tmax)

    from torch.utils.cpp_extension import load
    wkv_cuda = load(name="wkv", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

    class WKV(torch.autograd.Function):
        """Autograd wrapper around the fused WKV forward/backward CUDA kernels.

        The kernel always computes in fp32; inputs are upcast here and outputs
        downcast to match RWKV_FLOAT_MODE.
        """

        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                # Decay is stored as log magnitude; the kernel expects -exp(w).
                w = -torch.exp(w.contiguous())
                u = u.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            else:
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return y
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return y.half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return y.bfloat16()

        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda').contiguous()
            gu = torch.zeros((B, C), device='cuda').contiguous()
            gk = torch.zeros((B, T, C), device='cuda').contiguous()
            gv = torch.zeros((B, T, C), device='cuda').contiguous()
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
            else:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            # w and u are per-channel parameters shared across the batch.
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return (None, None, None, gw, gu, gk, gv)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

    def RUN_CUDA(B, T, C, w, u, k, v):
        # Convenience wrapper: move params to GPU and invoke the autograd op.
        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())


# Global model config, populated by RWKV_GPT.__init__ before blocks are built.
RWKV_CFG = types.SimpleNamespace()
76
+
77
class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mixing (feed-forward) sub-block.

    Each position is blended with its time-shifted predecessor, passed through
    a squared-ReLU MLP, and gated by a sigmoid "receptance" signal.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # Shift the sequence one step back in time: pad one row on top and
        # drop one at the bottom of the (B, T, C) tensor.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        prev = self.time_shift(x)
        mixed_k = x * self.time_mix_k + prev * (1 - self.time_mix_k)
        mixed_r = x * self.time_mix_r + prev * (1 - self.time_mix_r)

        hidden = torch.relu(self.key(mixed_k)).square()
        gate = torch.sigmoid(self.receptance(mixed_r))
        return gate * self.value(hidden)
102
+
103
class RWKV_TimeMix(nn.Module):
    """RWKV time-mixing (attention-replacement) sub-block.

    NOTE(review): forward calls RUN_CUDA, which is only defined when
    RWKV_RUN_DEVICE == 'cuda'; on CPU this module would raise NameError.
    CPU inference goes through RWKV_RNN instead.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # Per-channel decay, stored as log magnitude (negated/exponentiated
        # inside the WKV kernel).
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        # Bonus weight for the current token ("u" in the WKV formula).
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))
        # Shift sequence one step back in time (pad top, crop bottom).
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        # x: (B, T, C) — B batch, T sequence length, C = n_embd.
        B, T, C = x.size()

        # Blend each position with its predecessor before the projections.
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        # Sigmoid-gated WKV recurrence, computed by the fused CUDA kernel.
        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv
134
+
135
class Block(nn.Module):
    """One RWKV layer: (time-mix, or ffnPre in the variant's layer 0) plus a
    channel-mix, each behind a pre-LayerNorm residual connection. Layer 0
    additionally normalizes the raw embeddings with ln0."""

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        embd = RWKV_CFG.n_embd
        self.ln1 = nn.LayerNorm(embd)
        self.ln2 = nn.LayerNorm(embd)
        if layer_id == 0:
            # Extra normalization applied once to the raw embedding output.
            self.ln0 = nn.LayerNorm(embd)

        if layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            # ffnPre variant: layer 0 uses a channel-mix instead of attention.
            self.ffnPre = RWKV_ChannelMix(layer_id + 1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        is_first = self.layer_id == 0
        if is_first:
            x = self.ln0(x)

        if is_first and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
161
+
162
class RWKV_GPT(nn.Module):
    """Parallel (GPT-style) RWKV model: embeddings -> Blocks -> head.

    Populates the global RWKV_CFG so Block/ChannelMix/TimeMix constructors can
    read the configuration, then loads weights from MODEL_NAME + '.pth'.
    """

    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        # Publish config globally before any Block is constructed.
        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            # Extra "copy" attention head: lets logits boost tokens already
            # present in the context via a causal q·k score.
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            # Lower-triangular mask enforcing causality for the copy head.
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()  # NOTE(review): second eval() is redundant but harmless

    def forward(self, idx):
        # idx: (B, T) token ids.
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # One-hot matmul scatters the copy scores onto the vocabulary,
            # cast to match the active float mode.
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
222
+
223
+
224
class RWKV_RNN():
    """Sequential (RNN-mode) RWKV inference engine used on CPU.

    Loads the checkpoint into a nested attribute tree (self.w) and evaluates
    the model one token at a time, carrying per-layer recurrent state in
    self.xx / self.aa / self.bb / self.pp.
    """

    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        # Nested namespace mirroring the checkpoint's dotted key structure,
        # e.g. self.w.blocks[0].att.key.weight.
        self.w = types.SimpleNamespace()

        #w = torch.load(MODEL_NAME + '.pth',map_location=torch.device(RUN_DEVICE))
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE), weights_only=True)
        for x in w.keys():
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                # Pre-apply the -exp transform expected by the WKV recurrence.
                w[x] = -torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # Build the attribute tree: numeric path parts become dict keys,
            # other parts become SimpleNamespace attributes.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Reset all recurrent state (start of a fresh sequence)."""
        self.xx = {}  # previous input per sub-block (for time-mix blending)
        self.aa = {}  # WKV numerator accumulator
        self.bb = {}  # WKV denominator accumulator
        self.pp = {}  # running max exponent (numerical stability)
        self.hk = None  # history of head_k projections for the copy head

    def save(self, target):
        """Deep-copy the current recurrent state onto *target*."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        """Restore recurrent state previously captured with save()."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        """LayerNorm over the embedding dimension with the given weights."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """Channel-mix (FFN) step; *name* keys this sub-block's state."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """Time-mix (WKV attention) step with log-space-stabilized recurrence."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            # -1e30 acts as -inf so the first real exponent dominates.
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        # Current output: include this token with the time_first bonus,
        # rescaling both accumulators by the shared running max exponent.
        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        # State update: decay history by time_decay (already -exp'd at load),
        # then fold in the current k/v.
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Feed the last token of *ctx* through the network.

        Returns the next-token logits as a plain Python list. The full *ctx*
        is only used by the head-QK copy mechanism at the end.
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            # Copy head: keep a rolling history of head_k projections and add
            # q·k scores onto the logits of tokens seen in the context.
            if self.hk == None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pretty_midi==0.2.10
2
+ pedalboard==0.9.3
3
+ torch
4
+ gradio
sf2/.DS_Store ADDED
Binary file (6.15 kB). View file
 
sf2/piano.sf2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39568d475db895ab5e372dfbb1611d90b4a267306595dd7d619e99c0816ae1f9
3
+ size 74921906
temp/.DS_Store ADDED
Binary file (6.15 kB). View file
 
temp/output.mid ADDED
Binary file (256 Bytes). View file
 
temp/output.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae0bc1d9d63ad001452ef8414d0c91c3248867d126992d8f48c4276fb5cc0c36
3
+ size 2823724
temp/output_fx.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60bb94895397db75ff8913f3e0aad7da1d7f917b18db71c3db19609944b8096d
3
+ size 2823784
utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import functional as F
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import re
7
+
8
+
9
# Vocabulary: the ten digit characters, newline as the event separator, and
# the 12-zero string used as the sequence-start marker (single token id 11).
stoi = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '\n': 10, '000000000000': 11}
# Inverse of stoi, used for decoding ids back to text.
itos = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '\n', 11: '000000000000'}

# Longest alternative first so the start marker wins over twelve single digits.
tok_chars = re.compile(r'000000000000|\d{1}|\n')
13
+
14
def encode(text, stoi, tokenizer):
    """Tokenize *text* with *tokenizer* and map recognised tokens through *stoi*.

    Tokens absent from *stoi* are silently dropped.
    """
    ids = []
    for tok in tokenizer.findall(text):
        if tok in stoi:
            ids.append(stoi[tok])
    return ids
17
+
18
def decode(encoded, itos):
    """Map token ids back to their symbols and concatenate into a string."""
    return ''.join(map(itos.__getitem__, encoded))
20
+
21
+
22
class Dataset:
    """Training dataset over the digit-encoded MIDI text.

    NOTE(review): this shadows the torch.utils.data.Dataset imported above
    (and does not inherit from it); the app only imports it by name.
    """

    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        # time_aug is currently unused; augmentation is decided per item below.
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')

    def __len__(self):
        # Epoch length is fixed, independent of the corpus size.
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        # idx is ignored: each item is sampled from a random position.
        cues = []
        idx_randm = random.randint(0, len(self.data) - (self.ctx_len) * 4)
        i = idx_randm

        # Scan forward (wrapping) for the next sequence-start marker so every
        # sample begins on an event boundary.
        while True:
            if self.data[i] == self.stoi[self.start_token]:
                cues = [i]
                break
            else:
                i = (i + 1) % len(self.data)

        if not cues:
            return None

        start_idx = cues[0]
        # +2 extra tokens so both the augmented and plain paths can slice
        # ctx_len inputs and targets.
        dix = self.data[start_idx : start_idx + self.ctx_len + 2]

        # 96 tick resolution
        # Per-digit offsets repeated over each 13-token event line; only
        # positions 2..11 (velocity/start/end digits) are ever applied.
        time_shift = [
            [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
            [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
        ]

        # 50/50 chance of applying time-shift augmentation to this sample.
        data_aug = random.choice([True, False])

        t = dix[2:2 + self.ctx_len] # testing

        if data_aug:
            ts_rndm = random.choice(time_shift)
            # Tile the 13-long pattern to cover ctx_len tokens, then truncate.
            ts = ts_rndm * ((self.ctx_len - 1) // len(ts_rndm) + 1)
            tsx = torch.tensor(ts[:self.ctx_len])

            # Walk backwards so decimal carries propagate into already-visited
            # higher digits (j-1) correctly.
            for j in reversed(range(len(t))):
                if j % 13 not in range(2, 12):
                    continue

                aug_int = t[j] + tsx[j]
                # Carry when the shifted digit overflows, except where ids
                # 10/11 at positions 9/10 are legitimate non-digit tokens.
                if aug_int >= 10 and (aug_int not in [10, 11] or j not in [9, 10]):
                    left_int = aug_int // 10
                    right_int = aug_int % 10
                    if j > 0:
                        t[j - 1] += left_int
                    t[j] = right_int
                else:
                    t[j] = aug_int

            x = t
            # Target = input shifted by one; last target repeats the final token.
            y = t[1:] + [t[-1]]
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]

        x = torch.tensor(x, dtype=torch.int64)
        y = torch.tensor(y, dtype=torch.int64)

        return x, y
96
+
97
+
98
class TOKENIZER():
    """Tokenizer over the 12-symbol MIDI-text vocabulary, plus logit sampling."""

    def __init__(self):
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Return ids for every recognised token in *text* (others dropped)."""
        found = self.tokenizer.findall(text)
        return [self.stoi[tok] for tok in found if tok in self.stoi]

    def decode(self, encoded):
        """Inverse of encode: concatenate the symbols for the given ids."""
        return ''.join(self.itos[i] for i in encoded)

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        """Sample one token id from logits *out* using top-k + temperature.

        *x* and *ctx_len* are accepted for interface compatibility but unused.
        Returns a 0-d tensor holding the sampled id.
        """
        probs = F.softmax(torch.tensor(out), dim=-1)

        if top_k > 0:
            # Zero out everything except the k most probable entries.
            k = min(top_k, probs.size(-1))
            keep_p, keep_i = torch.topk(probs, k)
            probs = torch.zeros_like(probs).scatter_(dim=-1, index=keep_i, src=keep_p)

        if temperature != 1.0:
            # Sharpen/flatten; multinomial renormalizes internally.
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]
+