Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -208,27 +208,35 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
|
|
| 208 |
print(f"Duration: {duration} seconds")
|
| 209 |
# inference
|
| 210 |
with torch.inference_mode():
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
# Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
|
| 233 |
# generated_mel_spec = generated_mel_spec.to(dtype=torch.float32) # Convert to float32 if it's in bfloat16
|
| 234 |
|
|
|
|
| 208 |
print(f"Duration: {duration} seconds")
|
| 209 |
# inference
|
| 210 |
with torch.inference_mode():
|
| 211 |
+
# Ensure all inputs are on the same device as ema_model
|
| 212 |
+
audio = audio.to(ema_model.device) # Match ema_model's device
|
| 213 |
+
final_text_list = [t.to(ema_model.device) if isinstance(t, torch.Tensor) else t for t in final_text_list]
|
| 214 |
+
generated, _ = ema_model.sample(
|
| 215 |
+
cond=audio,
|
| 216 |
+
text=final_text_list,
|
| 217 |
+
duration=duration,
|
| 218 |
+
steps=nfe_step,
|
| 219 |
+
cfg_strength=cfg_strength,
|
| 220 |
+
sway_sampling_coef=sway_sampling_coef,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Process generated tensor
|
| 224 |
+
generated = generated[:, ref_audio_len:, :]
|
| 225 |
+
generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
|
| 226 |
+
|
| 227 |
+
# Convert to appropriate dtype and device
|
| 228 |
+
generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device) # Ensure device matches vocos
|
| 229 |
+
generated_wave = vocos.decode(generated_mel_spec)
|
| 230 |
+
|
| 231 |
+
# Adjust wave RMS if needed
|
| 232 |
+
if rms < target_rms:
|
| 233 |
+
generated_wave = generated_wave * rms / target_rms
|
| 234 |
+
|
| 235 |
+
# Convert to numpy
|
| 236 |
+
generated_wave = generated_wave.squeeze().cpu().numpy()
|
| 237 |
+
|
| 238 |
+
# Append to list
|
| 239 |
+
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
| 240 |
# Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
|
| 241 |
# generated_mel_spec = generated_mel_spec.to(dtype=torch.float32) # Convert to float32 if it's in bfloat16
|
| 242 |
|