|
|
import importlib.util
import os
import shutil
import subprocess
import sys
import tempfile
import time

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
|
|
|
|
|
|
|
|
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") |
|
|
|
|
|
if "TORCH_NUM_THREADS" in os.environ: |
|
|
try: |
|
|
torch.set_num_threads(max(1, int(os.environ["TORCH_NUM_THREADS"]))) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
try: |
|
|
torch.set_float32_matmul_precision("high") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
MODEL_SR = 16000 |
|
|
RESEMBLE_SR = 44100 |
|
|
|
|
|
CPU_TIME_BUDGET = float(os.environ.get("CPU_TIME_BUDGET", "160")) |
|
|
GPU_TIME_BUDGET = float(os.environ.get("GPU_TIME_BUDGET", "70")) |
|
|
RESEMBLE_TIMEOUT = float(os.environ.get("RESEMBLE_TIMEOUT", "90")) |
|
|
|
|
|
|
|
|
# Prefer the ``speechbrain.inference`` location and fall back to the legacy
# ``speechbrain.pretrained`` path when that module cannot be imported.
# Only import failures trigger the fallback, so unrelated errors stay visible.
try:
    from speechbrain.inference.separation import SepformerSeparation
except ImportError:
    from speechbrain.pretrained import SepformerSeparation
|
|
|
|
|
# Process-wide cache for the enhancement model (filled on first request).
_DNS4_MODEL = None


def load_dns4_model():
    """Return the shared SepFormer DNS4 model, creating it on first use.

    The model is built with ``SepformerSeparation.from_hparams`` on ``DEVICE``
    and stored in the module-level ``_DNS4_MODEL``; later calls return that
    same instance.
    """
    global _DNS4_MODEL
    if _DNS4_MODEL is None:
        _DNS4_MODEL = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir="pretrained_models/sepformer-dns4-16k-enhancement",
            run_opts={"device": DEVICE},
        )
    return _DNS4_MODEL
|
|
|
|
|
|
|
|
def to_mono(wav: np.ndarray) -> np.ndarray:
    """Return *wav* as a float32 array, averaging over axis 1 when it is 2-D.

    Arrays of any other rank are only converted to float32.
    """
    if wav.ndim == 2:
        wav = wav.mean(axis=1)
    return wav.astype(np.float32)
|
|
|
|
|
def resample_np(wav: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Resample *wav* from *sr_in* to *sr_out* using torchaudio.

    When both rates are equal the input array is returned as-is (no copy).
    Otherwise the signal is converted to a float tensor with a leading
    axis, resampled, and handed back as a NumPy array without that axis.
    """
    if sr_in != sr_out:
        batched = torch.from_numpy(wav).float().unsqueeze(0)
        resampled = torchaudio.functional.resample(batched, sr_in, sr_out)
        return resampled.squeeze(0).cpu().numpy()
    return wav
|
|
|
|
|
def crossfade_add(dst: np.ndarray, src: np.ndarray, start: int, fade: int):
    """Write *src* into *dst* at *start*, blending the first *fade* samples.

    *dst* is modified in place and is grown (via ``ndarray.resize``) when
    *src* would run past its end.  With ``fade <= 0`` the source simply
    overwrites the destination range.  Otherwise the first ``fade`` samples
    (capped by the source length and the room left in *dst*) are a linear
    crossfade from the existing content to *src*, and the remainder of *src*
    is copied verbatim.
    """
    stop = start + len(src)
    shortfall = stop - len(dst)
    if shortfall > 0:
        dst.resize(len(dst) + shortfall, refcheck=False)

    if fade <= 0:
        dst[start:stop] = src
        return

    fade = min(fade, len(src), len(dst) - start)
    ramp_up = np.linspace(0.0, 1.0, fade, dtype=np.float32)
    ramp_down = 1.0 - ramp_up
    blend_stop = start + fade
    dst[start:blend_stop] = dst[start:blend_stop] * ramp_down + src[:fade] * ramp_up
    if blend_stop < stop:
        dst[blend_stop:stop] = src[fade:]
|
|
|
|
|
def trim_silence(wav: np.ndarray, sr: int, thresh_db: float = -55.0, look_sec: float = 0.15) -> np.ndarray:
    """Cut leading and trailing low-level audio from *wav*.

    Frame RMS is measured with a 50 ms window every 20 ms on a zero-padded
    copy of the signal.  Frames louder than *thresh_db* count as active; the
    returned slice spans from the first to the last active frame, widened by
    *look_sec* seconds on both sides and clipped to the signal bounds.  The
    input is returned unchanged when it is empty or no frame is active.
    """
    if len(wav) == 0:
        return wav

    hop = max(1, int(0.02 * sr))
    win = max(hop, int(0.05 * sr))
    margin = (win - hop) // 2
    padded = np.pad(wav, (margin, margin))

    frame_starts = range(0, len(padded) - win + 1, hop)
    rms = np.array(
        [np.sqrt(np.mean(padded[s:s + win] ** 2) + 1e-9) for s in frame_starts],
        dtype=np.float32,
    )
    level_db = 20.0 * np.log10(rms + 1e-9)

    active_frames = np.flatnonzero(level_db > thresh_db)
    if active_frames.size == 0:
        return wav

    first, last = active_frames[0], active_frames[-1]
    keep_from = int(max(0, first * hop - look_sec * sr))
    keep_to = int(min(len(wav), (last * hop + win) + look_sec * sr))
    return wav[keep_from:keep_to]
|
|
|
|
|
|
|
|
def bandpass_torch(x: torch.Tensor, sr: int, low_hz: float, high_hz: float) -> torch.Tensor:
    """Band-limit *x* by chaining a high-pass biquad at *low_hz* and a low-pass biquad at *high_hz*."""
    highpassed = torchaudio.functional.highpass_biquad(x, sr, cutoff_freq=float(low_hz))
    return torchaudio.functional.lowpass_biquad(highpassed, sr, cutoff_freq=float(high_hz))
|
|
|
|
|
def speechband_ratio_torch(x: torch.Tensor, sr: int) -> float:
    """Return the share of signal power that lies in the 300-3400 Hz band.

    A 1-D input gets a leading axis before filtering.  Both mean powers are
    offset by 1e-12 so the ratio stays defined for an all-zero signal.
    """
    signal = x.unsqueeze(0) if x.dim() == 1 else x
    in_band = bandpass_torch(signal, sr, 300.0, 3400.0)
    band_power = float((in_band ** 2).mean().item() + 1e-12)
    total_power = float((signal ** 2).mean().item() + 1e-12)
    return band_power / total_power
|
|
|
|
|
|
|
|
def _safe_cf(freq: float, sr: int, max_ratio: float = 0.45) -> float: |
|
|
"""Clamp center/cutoff to <= max_ratio * sr (leave headroom from Nyquist).""" |
|
|
return float(np.clip(freq, 20.0, max_ratio * sr)) |
|
|
|
|
|
def _choose_speech_source(est) -> torch.Tensor:
    """Pick the separated source with the highest speech-band power ratio.

    *est* may be a list/tuple of sources or a single array/tensor:
      * 3-D: only the first entry along axis 0 is inspected, and the shorter
        of the remaining two axes is treated as the source axis;
      * 2-D: the shorter axis is treated as the source axis;
      * 1-D: used as the only candidate;
      * any other rank: flattened into one candidate.
    Each candidate is scored with ``speechband_ratio_torch`` at ``MODEL_SR``;
    the first candidate with the strictly highest score is returned as float.
    """
    if isinstance(est, (list, tuple)):
        candidates = [torch.as_tensor(item) for item in est]
    else:
        y = torch.as_tensor(est)
        rank = y.dim()
        if rank == 3:
            if y.shape[1] <= y.shape[2]:
                candidates = [y[0, s, :] for s in range(y.shape[1])]
            else:
                candidates = [y[0, :, s] for s in range(y.shape[2])]
        elif rank == 2:
            if y.shape[0] <= y.shape[1]:
                candidates = [y[s, :] for s in range(y.shape[0])]
            else:
                candidates = [y[:, s] for s in range(y.shape[1])]
        elif rank == 1:
            candidates = [y]
        else:
            candidates = [y.reshape(-1)]

    winner, top_score = 0, -1.0
    for idx, candidate in enumerate(candidates):
        score = speechband_ratio_torch(candidate.float().unsqueeze(0), MODEL_SR)
        if score > top_score:
            top_score, winner = score, idx
    return candidates[winner].float()
|
|
|
|
|
def enhance_chunk_tensor(model: SepformerSeparation, chunk_tensor_1xT: torch.Tensor) -> torch.Tensor:
    """Enhance one audio chunk and return the selected speech source.

    Args:
        model: separation model exposing ``separate_batch`` and
            ``separate_file``.
        chunk_tensor_1xT: chunk tensor; the first axis is squeezed away when
            the file-based fallback writes it to disk.

    Returns:
        A CPU tensor holding the source chosen by ``_choose_speech_source``,
        with a leading axis added.

    The in-memory ``separate_batch`` call runs under autocast when ``DEVICE``
    is ``"cuda"``.  If it raises, the chunk is written to a temporary WAV
    file and ``separate_file`` is used instead; the temporary directory is
    removed even if that fallback fails too.
    """
    with torch.no_grad():
        try:
            use_amp = DEVICE == "cuda"
            with torch.cuda.amp.autocast(enabled=use_amp):
                est = model.separate_batch(chunk_tensor_1xT.to(DEVICE))
        except Exception:
            # Deliberate best-effort fallback: retry through the file API.
            tmpd = tempfile.mkdtemp(prefix="sf_chunk_")
            try:
                inpath = os.path.join(tmpd, "in.wav")
                sf.write(inpath, chunk_tensor_1xT.squeeze(0).cpu().numpy(), MODEL_SR, subtype="PCM_16")
                est = model.separate_file(path=inpath)
            finally:
                # Clean up even when writing or separating raises.
                shutil.rmtree(tmpd, ignore_errors=True)
    speech = _choose_speech_source(est).unsqueeze(0)
    return speech.cpu()
|
|
|
|
|
def process_in_chunks(model, wav_16k: np.ndarray, chunk_sec: float, overlap_sec: float,
                      time_budget: float = None, progress=None, base=0.0, span=0.5) -> np.ndarray:
    """Enhance a long signal chunk by chunk with crossfaded overlaps.

    Args:
        model: model passed through to ``enhance_chunk_tensor``.
        wav_16k: 1-D signal sampled at ``MODEL_SR``.
        chunk_sec: chunk length in seconds.
        overlap_sec: overlap between consecutive chunks in seconds; it is also
            the crossfade length.
        time_budget: optional wall-clock limit in seconds.  When it is
            exceeded, processing stops and the part of the signal not yet
            covered by an enhanced chunk is copied from the input.
        progress: optional callable ``progress(fraction, message)``.
        base, span: progress range ``[base, base + span]`` used by this stage.

    Returns:
        The enhanced signal.  Inputs that fit in one chunk are enhanced in a
        single call and returned as produced by the model.
    """
    t_start = time.time()
    chunk = max(1, int(chunk_sec * MODEL_SR))
    overlap = max(0, int(overlap_sec * MODEL_SR))
    step = max(1, chunk - overlap)
    n = len(wav_16k)

    if n == 0:
        # Nothing to enhance; avoid calling the model on an empty tensor.
        return np.asarray(wav_16k, dtype=np.float32)

    if n <= chunk:
        ten = torch.from_numpy(wav_16k).float().unsqueeze(0)
        out = enhance_chunk_tensor(model, ten).squeeze(0).numpy()
        if progress:
            progress(base + span, "Enhancement (single chunk)")
        return out

    out = np.zeros(n, dtype=np.float32)
    # Number of chunks needed until one of them reaches the end of the signal.
    total_steps = 1 + (max(0, n - chunk) + step - 1) // step
    done_steps = 0
    pos = 0
    covered = 0  # out[:covered] already holds enhanced audio

    # Stop as soon as a chunk has reached the end: later start positions would
    # only re-run the model on zero-padded tails that are already covered.
    while covered < n:
        if time_budget is not None and (time.time() - t_start) > time_budget:
            if progress:
                progress(base + span * 0.98, "Time budget reached — returning partial")
            break
        end = min(pos + chunk, n)
        piece = wav_16k[pos:end]
        if len(piece) < chunk:
            # Zero-pad the last chunk so the model always sees a full chunk.
            padded = np.zeros(chunk, dtype=np.float32)
            padded[:len(piece)] = piece
            piece = padded
        ten = torch.from_numpy(piece).float().unsqueeze(0)
        enhanced = enhance_chunk_tensor(model, ten).squeeze(0).numpy()[:end - pos]
        if done_steps == 0:
            out[pos:end] = enhanced
        else:
            crossfade_add(out, enhanced, pos, overlap)
        covered = end
        pos += step
        done_steps += 1
        if progress:
            progress(base + span * (done_steps / total_steps), f"Enhancement {done_steps}/{total_steps}")

    if covered < n:
        # Budget ran out: keep every enhanced sample, pass the rest through.
        out[covered:] = wav_16k[covered:]
    return out
|
|
|
|
|
|
|
|
def spectral_gate_np(wav_16k: np.ndarray, strength: float, progress=None, base=0.5, span=0.2) -> np.ndarray:
    """Apply a smoothed Wiener-style spectral gate to *wav_16k*.

    Args:
        wav_16k: signal to filter (a 1-D array gets a leading axis internally).
        strength: gate strength; larger values raise both the noise
            over-subtraction factor (``0.8 + 3.2 * strength``) and the mask
            exponent (``1 + 2 * strength``).
        progress: optional callable ``progress(fraction, message)``.
        base, span: progress range used by this stage.

    Returns:
        The filtered signal with the same number of samples as the input.
        Inputs no longer than half the FFT size are returned unfiltered (as
        float32) because the centred STFT cannot reflect-pad them.

    The noise floor per frequency bin is the 20th percentile of the STFT
    magnitude over time; the resulting mask is averaged over a 5x5
    neighbourhood before it is applied.
    """
    n_fft, hop = 1024, 256

    if wav_16k.shape[-1] <= n_fft // 2:
        # Too short for the STFT's reflect padding: pass through unchanged.
        if progress:
            progress(base + span * 0.98, "Post-filter: finishing")
        return np.asarray(wav_16k, dtype=np.float32)

    if progress:
        progress(base, "Post-filter: estimating noise")
    x = torch.from_numpy(wav_16k).float()
    win = torch.hann_window(n_fft)
    if x.ndim == 1:
        x = x.unsqueeze(0)
    X = torch.stft(x, n_fft=n_fft, hop_length=hop, window=win, return_complex=True)
    mag = X.abs()
    noise = torch.quantile(mag, q=0.2, dim=-1, keepdim=True)
    S2, N2 = mag ** 2, noise ** 2
    beta = 0.8 + 3.2 * strength
    p = 1.0 + 2.0 * strength
    wiener = torch.clamp((S2 - beta * N2) / (S2 + 1e-8), 0.0, 1.0) ** p
    # Smooth the mask over frequency and time to reduce isolated mask peaks.
    wiener_s = torch.nn.functional.avg_pool2d(
        wiener.unsqueeze(1), kernel_size=(5, 5), stride=1, padding=2
    ).squeeze(1)
    if progress:
        progress(base + span * 0.5, "Post-filter: applying mask")
    Y = wiener_s * X
    y = torch.istft(Y, n_fft=n_fft, hop_length=hop, window=win, length=x.shape[-1])
    if progress:
        progress(base + span * 0.98, "Post-filter: finishing")
    return y.squeeze(0).cpu().numpy()
|
|
|
|
|
def strict_background_kill(wav_16k: np.ndarray, level: float, progress=None, base=0.7, span=0.2) -> np.ndarray:
    """Extra background suppression: high-pass, low-pass, then a spectral gate.

    ``level <= 0`` returns the input untouched.  Otherwise the signal is
    high-passed at 80 Hz, low-passed at ``(0.45 - 0.10 * level) * sr`` (with
    *level* clipped to [0, 1] and the cutoff clamped by ``_safe_cf``), and
    finally gated with strength ``0.55 + 0.45 * level`` clipped to [0, 1].
    Progress, when given, is reported inside ``[base, base + span]``.
    """
    if level <= 0.0:
        return wav_16k

    def report(fraction, message):
        if progress:
            progress(base + span * fraction, message)

    sr = MODEL_SR
    signal = torch.from_numpy(wav_16k).float().unsqueeze(0)

    report(0.05, "Strict: high-pass")
    signal = torchaudio.functional.highpass_biquad(signal, sr, cutoff_freq=_safe_cf(80.0, sr))

    lp_cut = _safe_cf((0.45 - 0.10 * float(np.clip(level, 0.0, 1.0))) * sr, sr)
    report(0.15, f"Strict: low-pass ~{int(lp_cut)} Hz")
    signal = torchaudio.functional.lowpass_biquad(signal, sr, cutoff_freq=lp_cut)

    report(0.35, "Strict: spectral gate")
    gate_strength = float(np.clip(0.55 + 0.45 * level, 0.0, 1.0))
    return spectral_gate_np(
        signal.squeeze(0).cpu().numpy(),
        strength=gate_strength,
        progress=progress,
        base=base + span * 0.35,
        span=span * 0.6,
    )
|
|
|
|
|
|
|
|
def _eq_chain(x: torch.Tensor, sr: int, amt: float) -> torch.Tensor:
    """Apply a three-band clarity EQ whose gains scale with *amt*.

    Bands (centre frequencies clamped by ``_safe_cf``):
      - body     at 180 Hz,     gain ``1.0 + 1.5 * amt`` dB, Q 0.8
      - presence at 3200 Hz,    gain ``2.0 + 2.0 * amt`` dB, Q 0.9
      - air      at 0.40 * sr,  gain ``0.5 + 1.0 * amt`` dB, Q 0.7
    A 1-D input gets a leading axis before filtering.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)

    # (centre frequency, gain in dB, Q) for each peaking filter, in order.
    bands = (
        (_safe_cf(180.0, sr), 1.0 + 1.5 * amt, 0.8),
        (_safe_cf(3200.0, sr), 2.0 + 2.0 * amt, 0.9),
        (_safe_cf(0.40 * sr, sr), 0.5 + 1.0 * amt, 0.7),
    )
    y = x
    for center, gain_db, q in bands:
        y = torchaudio.functional.equalizer_biquad(y, sr, center_freq=center, gain=gain_db, Q=q)
    return y
|
|
|
|
|
def _soft_compressor_np(wav: np.ndarray, thresh_db=-20.0, ratio=2.0, attack_ms=8.0, release_ms=120.0, makeup_db=2.0, sr=16000) -> np.ndarray: |
|
|
x = wav.astype(np.float32) |
|
|
atk = np.exp(-1.0 / (sr * (attack_ms / 1000.0))) |
|
|
rel = np.exp(-1.0 / (sr * (release_ms / 1000.0))) |
|
|
env = 0.0 |
|
|
gain = np.ones_like(x) |
|
|
thr = 10 ** (thresh_db / 20.0) |
|
|
for i, s in enumerate(x): |
|
|
a = abs(s) |
|
|
env = atk * env + (1 - atk) * a if a > env else rel * env + (1 - rel) * a |
|
|
if env <= 1e-9: |
|
|
g = 1.0 |
|
|
else: |
|
|
if env <= thr: |
|
|
g = 1.0 |
|
|
else: |
|
|
over = env / thr |
|
|
comp = over ** (1.0 - 1.0/ratio) |
|
|
g = (thr * comp) / env |
|
|
gain[i] = g |
|
|
y = x * gain |
|
|
mk = 10 ** (makeup_db / 20.0) |
|
|
y = y * mk |
|
|
peak = np.max(np.abs(y)) + 1e-9 |
|
|
if peak > 0.999: |
|
|
y = y / peak * 0.999 |
|
|
return y.astype(np.float32) |
|
|
|
|
|
def voice_enhance(wav_16k: np.ndarray, amount: float, deess: float = 0.20) -> np.ndarray:
    """Clarity EQ, optional de-essing cut, then gentle compression.

    amount: scales the EQ (clipped to [0, 1] for the EQ stage) and the
        compressor threshold, ratio and make-up gain.
    deess: when positive, applies a peaking cut of ``-12 * deess`` dB around
        ``0.35 * sr`` (Q 1.2).
    """
    sr = MODEL_SR
    eq_amount = float(np.clip(amount, 0.0, 1.0))
    shaped = _eq_chain(torch.from_numpy(wav_16k).float().unsqueeze(0), sr, amt=eq_amount).squeeze(0)

    if deess > 0:
        cf_ds = _safe_cf(0.35 * sr, sr)
        cut_db = -2.0 * float(deess) * 6.0
        shaped = torchaudio.functional.equalizer_biquad(
            shaped.unsqueeze(0), sr, center_freq=cf_ds, gain=cut_db, Q=1.2
        ).squeeze(0)

    return _soft_compressor_np(
        shaped.cpu().numpy(),
        thresh_db=-22.0 + 4.0 * amount,
        ratio=1.6 + 0.7 * amount,
        makeup_db=1.5 + 1.5 * amount,
        sr=sr,
    )
|
|
|
|
|
|
|
|
def _find_resemble_cmd(): |
|
|
candidates = [ |
|
|
["resemble_enhance"], |
|
|
["resemble-enhance"], |
|
|
[sys.executable, "-m", "resemble_enhance.enhance"], |
|
|
] |
|
|
for cmd in candidates: |
|
|
exe = shutil.which(cmd[0]) |
|
|
if len(cmd) > 1 or exe is not None: |
|
|
return cmd |
|
|
return None |
|
|
|
|
|
def run_resemble_enhance(audio: np.ndarray, in_sr: int, denoise_only: bool,
                         timeout_sec: float, progress=None, base=0.0, span=0.2) -> np.ndarray:
    """Run the Resemble Enhance command-line tool on *audio*.

    Args:
        audio: mono signal sampled at *in_sr*.
        in_sr: sample rate of *audio*; the result is resampled back to it.
        denoise_only: pass ``--denoise_only`` to the tool.
        timeout_sec: maximum run time of the subprocess in seconds.
        progress: optional callable ``progress(fraction, message)``.
        base, span: progress range used by this stage.

    Returns:
        The processed signal as mono float32 at *in_sr*.

    Raises:
        RuntimeError: the tool is missing, exits non-zero, or writes no audio.
        TimeoutError: the tool ran longer than *timeout_sec* (it is killed).

    The tool's console output is discarded.  The process is polled with a
    short wait so that the timeout is enforced even when the tool prints
    nothing, and temporary directories are always removed.
    """
    cmd = _find_resemble_cmd()
    if cmd is None:
        raise RuntimeError("Resemble Enhance CLI not found. Is the package installed?")
    if progress:
        progress(base, "Resemble: resampling to 44.1 kHz")
    audio44 = resample_np(audio, in_sr, RESEMBLE_SR)

    in_dir = out_dir = None
    proc = None
    try:
        in_dir = tempfile.mkdtemp(prefix="resem_in_")
        out_dir = tempfile.mkdtemp(prefix="resem_out_")
        in_path = os.path.join(in_dir, "in.wav")
        sf.write(in_path, audio44, RESEMBLE_SR, subtype="PCM_16")

        full_cmd = cmd + [in_dir, out_dir] + (["--denoise_only"] if denoise_only else [])
        if progress:
            progress(base + span * 0.05, "Resemble: enhancing…")

        # Output is not used, so do not pipe it: an unread pipe can block.
        proc = subprocess.Popen(full_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        t0 = time.time()
        while True:
            try:
                proc.wait(timeout=0.1)
                break
            except subprocess.TimeoutExpired:
                pass
            elapsed = time.time() - t0
            # Sawtooth indicator: the tool reports no usable progress itself.
            frac = min(0.95, (elapsed % 5.0) / 5.0)
            if progress:
                progress(base + span * (0.05 + 0.9 * frac), "Resemble: enhancing…")
            if elapsed > timeout_sec:
                proc.kill()
                proc.wait()  # reap the killed child
                raise TimeoutError(f"Resemble Enhance exceeded {timeout_sec}s")

        if proc.returncode != 0:
            raise RuntimeError(f"Resemble Enhance exited with code {proc.returncode}")

        out_path = os.path.join(out_dir, "in.wav")
        if not os.path.exists(out_path):
            # Fall back to any audio file the tool produced (stable order).
            cand = sorted(
                f for f in os.listdir(out_dir)
                if f.lower().endswith((".wav", ".flac", ".mp3", ".m4a", ".ogg"))
            )
            if not cand:
                raise RuntimeError("Resemble Enhance did not produce an audio file.")
            out_path = os.path.join(out_dir, cand[0])

        if progress:
            progress(base + span * 0.98, "Resemble: reading result")
        out_audio, out_sr = sf.read(out_path, dtype="float32", always_2d=False)
        out_audio = to_mono(out_audio)
        return resample_np(out_audio, out_sr, in_sr)
    finally:
        # Never leave a child running (e.g. if a progress callback raised).
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.wait()
        for tmp_dir in (in_dir, out_dir):
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir, ignore_errors=True)
|
|
|
|
|
def ensure_runtime(target_seconds: float, start_time: float, progress=None, base=0.8, span=0.2):
    """Sleep until *target_seconds* have passed since *start_time*, reporting progress.

    Does nothing when the target has already been reached or when no
    *progress* callback is given.  Otherwise the remaining time is split into
    at least five equal sleeps, and after each one ``progress`` is called with
    a fraction rising from *base* to ``base + span`` and the message
    "Final polish".
    """
    remaining = max(0.0, target_seconds - (time.time() - start_time))
    if progress is None or remaining <= 0:
        return

    steps = max(5, int(remaining / 0.25))
    pause = remaining / steps
    for done in range(1, steps + 1):
        time.sleep(pause)
        progress(base + span * done / steps, "Final polish")
|
|
|
|
|
|
|
|
def _fit_length(audio: np.ndarray, length: int) -> np.ndarray:
    """Trim or zero-pad *audio* at the end so it has exactly *length* samples."""
    current = len(audio)
    if current == length:
        return audio
    if current > length:
        return audio[:length]
    return np.pad(audio, (0, length - current))


def denoise(
    file_path: str,
    model_label: str = "SepFormer (DNS4, 16 kHz)",
    processing_mode: str = "Standard (1 pass + post-filter)",
    chunk_seconds: float = 8.0,
    overlap_seconds: float = 0.2,
    wet_percent: int = 100,
    post_strength: int = 55,
    strictness: int = 65,
    voice_enhance_on: bool = True,
    voice_enhance_amount: int = 55,
    use_resemble: bool = False,
    resemble_mode: str = "Enhance (denoise+enhance)",
    output_sr_choice: str = "Original",
    target_seconds: int = 18,
    progress=gr.Progress(track_tqdm=False),
):
    """Run the full enhancement pipeline on one audio file.

    Stages, in order: load and down-mix, trim silence, resample to
    ``MODEL_SR``, one or more SepFormer passes, spectral-gate post-filter(s),
    optional strict background suppression, optional voice enhancement,
    optional Resemble Enhance, wet/dry mix, resample to the output rate, peak
    limiting to 0.999, optional padding of the runtime via ``ensure_runtime``,
    and writing a 16-bit WAV.

    Args:
        file_path: path of the input audio; falsy means "no file".
        model_label: label shown in the status line only.
        processing_mode: label starting with "standard", "pro" or anything
            else (treated as the heaviest mode), case-insensitive.
        chunk_seconds, overlap_seconds: chunking for the SepFormer passes.
        wet_percent: share of enhanced signal in the final mix (0-100).
        post_strength: spectral-gate strength (0-100).
        strictness: extra background suppression (0 disables it, 0-100).
        voice_enhance_on, voice_enhance_amount: voice-enhance switch/amount.
        use_resemble, resemble_mode: Resemble Enhance switch and mode; a mode
            starting with "denoise" selects denoise-only.
        output_sr_choice: "Original" keeps the input rate, anything else
            outputs at ``MODEL_SR``.
        target_seconds: minimum total runtime enforced by ``ensure_runtime``.
        progress: callable ``progress(fraction, message)``.

    Returns:
        ``(output_path, input_path, status_text)``; the output path is
        ``None`` when there is nothing to process.
    """
    start_t = time.time()
    budget = GPU_TIME_BUDGET if DEVICE == "cuda" else CPU_TIME_BUDGET

    progress(0.02, "Loading file")
    if not file_path:
        return None, None, "No file provided."

    wav, sr = sf.read(file_path, dtype="float32", always_2d=False)
    wav = to_mono(wav)
    if len(wav) == 0:
        # Nothing to enhance; the later stages cannot handle empty signals.
        return None, file_path, "Audio file contains no samples."
    dur_sec = len(wav) / float(sr)

    progress(0.05, "Trimming silence")
    wav = trim_silence(wav, sr)

    progress(0.07, "Resampling")
    wav_16k = resample_np(wav, sr, MODEL_SR)

    mode = processing_mode.lower()
    cpu = (DEVICE == "cpu")
    if mode.startswith("standard"):
        passes, do_gate_times = 1, 1
    elif mode.startswith("pro"):
        passes, do_gate_times = (1 if cpu else 2), 1
    else:
        passes, do_gate_times = (2 if cpu else 3), 2
    # High strictness buys one extra pass on GPU (capped at three).
    if (not cpu) and strictness >= 80 and passes < 3:
        passes += 1

    # Split the 0.08-0.80 progress range evenly across all stages.
    extra_strict_stage = 1 if strictness > 0 else 0
    enhance_stage = 1 if voice_enhance_on else 0
    resemble_stage = 1 if use_resemble else 0
    total_slices = passes + do_gate_times + extra_strict_stage + enhance_stage + resemble_stage + 1
    slice_span = 0.72 / max(1, total_slices)
    cur_base = 0.08

    audio = wav_16k
    model = load_dns4_model()
    for idx in range(passes):
        progress(cur_base, f"SepFormer pass {idx+1}/{passes}")
        heavy_left = (passes - idx) + (1 if use_resemble else 0)
        remain_budget = max(0.0, budget - (time.time() - start_t))
        # Only CPU runs are time-boxed (matching the Resemble timeout below):
        # each remaining heavy stage gets an equal share of what is left.
        pass_budget = remain_budget / (heavy_left + 0.5) if cpu else None
        audio = process_in_chunks(
            model, audio, float(chunk_seconds), float(overlap_seconds),
            time_budget=pass_budget,
            progress=progress, base=cur_base, span=slice_span,
        )
        cur_base += slice_span

    strength = float(np.clip(post_strength / 100.0, 0.0, 1.0))
    for j in range(do_gate_times):
        progress(cur_base, f"Post-filter {j+1}/{do_gate_times}")
        audio = spectral_gate_np(audio, strength=strength, progress=progress, base=cur_base, span=slice_span)
        cur_base += slice_span

    if strictness > 0:
        progress(cur_base, "Strict background kill")
        audio = strict_background_kill(
            audio, level=float(np.clip(strictness / 100.0, 0.0, 1.0)),
            progress=progress, base=cur_base, span=slice_span,
        )
        cur_base += slice_span

    if voice_enhance_on:
        progress(cur_base, "Voice Enhance (clarity)")
        ve_amt = float(np.clip(voice_enhance_amount / 100.0, 0.0, 1.0))
        audio = voice_enhance(audio, amount=ve_amt, deess=0.20)
        cur_base += slice_span

    if use_resemble:
        denoise_only = resemble_mode.lower().startswith("denoise")
        progress(cur_base, "Resemble Enhance: preparing")
        try:
            remain_budget = max(0.0, budget - (time.time() - start_t))
            r_timeout = min(RESEMBLE_TIMEOUT, remain_budget * 0.8) if cpu else RESEMBLE_TIMEOUT
            audio = run_resemble_enhance(
                audio, MODEL_SR, denoise_only=denoise_only,
                timeout_sec=r_timeout,
                progress=progress, base=cur_base, span=slice_span,
            )
        except Exception as e:
            # Optional stage: report and continue with the audio so far.
            progress(cur_base + slice_span * 0.99, f"Resemble skipped: {e}")
        cur_base += slice_span

    progress(cur_base, "Mixing & resampling")
    wet = float(np.clip(wet_percent / 100.0, 0.0, 1.0))
    # Resampling round trips can change the length by a few samples, which
    # would break the element-wise mix; align to the dry signal first.
    audio = _fit_length(np.asarray(audio, dtype=np.float32), len(wav_16k))
    mixed_16k = wet * audio + (1.0 - wet) * wav_16k

    out_sr = sr if output_sr_choice == "Original" else MODEL_SR
    out_wav = resample_np(mixed_16k, MODEL_SR, out_sr)

    peak = float(np.max(np.abs(out_wav)) + 1e-9)
    if peak > 0.999:
        out_wav = out_wav / peak * 0.999

    ensure_runtime(float(target_seconds), start_t, progress=progress, base=cur_base, span=min(0.2, 1.0 - cur_base))
    progress(0.99, "Saving")

    # mkstemp returns an open descriptor; close it so it is not leaked.
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(out_path, out_wav, out_sr, subtype="PCM_16")

    status = (
        f"Preset: Crisp & Fast (SAFE) | Model: {model_label} | Passes: {passes}+post({do_gate_times})"
        f"{' + Strict' if strictness>0 else ''}{' + VoiceEnh' if voice_enhance_on else ''}"
        f"{' + Resemble' if use_resemble else ''} | Input: {sr} Hz, {dur_sec:.1f}s | "
        f"Output: {out_sr} Hz | Wet: {wet_percent}% | Post: {post_strength}% | "
        f"Strict: {strictness}% | VoiceEnh: {voice_enhance_amount}% | Device: {DEVICE.upper()}"
    )
    progress(1.0, "Done")
    return out_path, file_path, status
|
|
|
|
|
|
|
|
def apply_preset(preset: str):
    """Map a preset label to the control values it implies.

    Presets are matched by prefix ("Crisp & Fast", "Broadcast Clean"; any
    other label gets the heaviest settings), and several values depend on
    whether ``DEVICE`` is the CPU.

    Returns a 12-tuple in this order: processing mode, chunk seconds, overlap
    seconds, wet %, post strength, strictness, voice enhance on, voice enhance
    amount, use Resemble, Resemble mode, output sample-rate choice, target
    runtime seconds.
    """
    on_cpu = DEVICE == "cpu"

    def pick(cpu_value, gpu_value):
        return cpu_value if on_cpu else gpu_value

    standard = "Standard (1 pass + post-filter)"
    pro = "Pro (dual-pass + post-filter)"
    ultra = "Ultra (triple-pass + extra post)"
    resemble_default = "Enhance (denoise+enhance)"

    if preset.startswith("Crisp & Fast"):
        return (standard, pick(8.0, 10.0), pick(0.2, 0.3), 100, 55, 65, True, 55,
                False, resemble_default, "Original", pick(16, 18))

    if preset.startswith("Broadcast Clean"):
        return (pick(standard, pro), pick(10.0, 12.0), pick(0.25, 0.4), 100, 58, 70, True, 60,
                False, resemble_default, "Original", pick(20, 22))

    return (pick(pro, ultra), 10.0, 0.5, 100, 62, 80, True, 60,
            not on_cpu, resemble_default, "16 kHz", pick(26, 30))
|
|
|
|
|
|
|
|
# Page heading and blurb shown at the top of the UI.
TITLE = "AI Background Noise Remover — SepFormer (DNS4) • Safe & Crisp Defaults"
DESC = (
    "Safe defaults to avoid harsh highs. SepFormer (DNS4) + smart post-filter + Voice Enhance. "
    "Use presets to trade speed vs. aggressiveness."
)

# Option lists for the selection widgets below.
_PRESET_CHOICES = ["Crisp & Fast (default)", "Broadcast Clean (balanced)", "Max Clean (slow)"]
_MODEL_CHOICES = ["SepFormer (DNS4, 16 kHz)"]
_MODE_CHOICES = [
    "Standard (1 pass + post-filter)",
    "Pro (dual-pass + post-filter)",
    "Ultra (triple-pass + extra post)",
]
_RESEMBLE_MODE_CHOICES = ["Enhance (denoise+enhance)", "Denoise only"]

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as app:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)

    # Input audio.
    with gr.Row():
        audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy Audio")

    # Preset, model and processing mode.
    with gr.Row():
        preset = gr.Dropdown(choices=_PRESET_CHOICES, value=_PRESET_CHOICES[0], label="Preset")
        model_dd = gr.Dropdown(choices=_MODEL_CHOICES, value=_MODEL_CHOICES[0], label="Enhancement Model")
        mode_dd = gr.Radio(choices=_MODE_CHOICES, value=_MODE_CHOICES[0], label="Processing Mode")

    # Chunking and output rate.
    with gr.Row():
        chunk = gr.Slider(minimum=5, maximum=30, value=8, step=1, label="Chunk Length (s)")
        overlap = gr.Slider(minimum=0.0, maximum=3.0, value=0.2, step=0.05, label="Overlap Crossfade (s)")
        out_sr = gr.Radio(choices=["Original", "16 kHz"], value="Original", label="Output Sample Rate")

    # Mix and filter strengths.
    with gr.Row():
        wet = gr.Slider(minimum=0, maximum=100, value=100, step=1, label="Enhancement Mix (Wet %)")
        post_strength = gr.Slider(minimum=0, maximum=100, value=55, step=1, label="Post Polish Strength")
        strictness = gr.Slider(minimum=0, maximum=100, value=65, step=1, label="Strictness (extra background kill)")

    # Voice enhancement.
    with gr.Row():
        voice_enhance_on = gr.Checkbox(value=True, label="Voice Enhance (clarity/loudness)")
        voice_enhance_amount = gr.Slider(minimum=0, maximum=100, value=55, step=1, label="Voice Enhance Amount")

    # Optional Resemble stage and runtime target.
    with gr.Row():
        use_resemble = gr.Checkbox(value=False, label="Resemble Enhance (optional, slower)")
        resemble_mode = gr.Radio(
            choices=_RESEMBLE_MODE_CHOICES,
            value=_RESEMBLE_MODE_CHOICES[0],
            label="Resemble Mode",
        )
        target_secs = gr.Slider(minimum=5, maximum=60, value=18, step=1, label="Target Runtime (s)")

    # Choosing a preset rewrites the twelve controls it governs; the order
    # must match the tuple returned by apply_preset.
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[
            mode_dd, chunk, overlap, wet, post_strength, strictness,
            voice_enhance_on, voice_enhance_amount,
            use_resemble, resemble_mode, out_sr, target_secs,
        ],
        api_name="preset",
    )

    run_btn = gr.Button(value="Enhance Audio", variant="primary")

    with gr.Row():
        enhanced_out = gr.Audio(type="filepath", label="Enhanced (Download)")
        original_out = gr.Audio(type="filepath", label="Original")
    status = gr.Markdown(value="Ready.")

    # Inputs are passed positionally, so their order follows denoise()'s
    # parameters (file path first, target runtime last).
    run_btn.click(
        fn=denoise,
        inputs=[
            audio_in, model_dd, mode_dd, chunk, overlap, wet, post_strength, strictness,
            voice_enhance_on, voice_enhance_amount, use_resemble, resemble_mode, out_sr, target_secs,
        ],
        outputs=[enhanced_out, original_out, status],
        api_name="enhance",
    )
|
|
|
|
|
# Run one job at a time with up to 16 queued requests.  The file already uses
# the newer ``gr.Audio(sources=...)`` API, so try the newer queue signature
# first and fall back to the legacy ``concurrency_count`` form.
# NOTE(review): some Gradio releases are understood to reject
# ``concurrency_count`` by raising DeprecationWarning rather than TypeError,
# hence both are caught below -- confirm against the installed version.
try:
    app.queue(max_size=16, status_update_rate=1, default_concurrency_limit=1)
except TypeError:
    try:
        app.queue(concurrency_count=1, max_size=16, status_update_rate=1)
    except (TypeError, DeprecationWarning):
        app.queue()

if __name__ == "__main__":
    app.launch()
|
|
|