# bgremover / app.py
# Author: Diggz10
# Last commit: c582b0c "Update app.py" (verified)
import contextlib
import importlib.util
import os
import shutil
import subprocess
import sys
import tempfile
import time

import numpy as np
import soundfile as sf
import torch
import torchaudio
import gradio as gr
# ---------------- Runtime + defaults ----------------
# Opt into hf_transfer downloads only when the package is installed:
# huggingface_hub raises on download if HF_HUB_ENABLE_HF_TRANSFER=1 but
# `hf_transfer` cannot be imported.
# NOTE(review): huggingface_hub may read this variable when it is first
# imported, and gradio (imported above) may already have imported it —
# confirm the setting actually takes effect.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# Optional CPU thread cap from the environment; a malformed value is ignored.
if "TORCH_NUM_THREADS" in os.environ:
    try:
        torch.set_num_threads(max(1, int(os.environ["TORCH_NUM_THREADS"])))
    except (ValueError, RuntimeError):
        pass

# Allow faster reduced-precision float32 matmuls; older torch builds lack this API.
try:
    torch.set_float32_matmul_precision("high")
except (AttributeError, RuntimeError):
    pass

if torch.cuda.is_available():
    # Let cuDNN autotune kernels; chunked inference mostly reuses one input shape.
    torch.backends.cudnn.benchmark = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SR = 16000  # rate the SepFormer DNS4 pipeline runs at
RESEMBLE_SR = 44100  # rate audio is converted to before calling the Resemble Enhance CLI
# Wall-clock limits in seconds, overridable through environment variables.
CPU_TIME_BUDGET = float(os.environ.get("CPU_TIME_BUDGET", "160"))
GPU_TIME_BUDGET = float(os.environ.get("GPU_TIME_BUDGET", "70"))
RESEMBLE_TIMEOUT = float(os.environ.get("RESEMBLE_TIMEOUT", "90"))
# ---------------- SpeechBrain SepFormer DNS4 ----------------
# SpeechBrain moved the pretrained interfaces in 1.x; fall back to the 0.5.x
# location only when the new module path does not exist.
try:
    from speechbrain.inference.separation import SepformerSeparation  # SB >= 1.x
except ImportError:
    from speechbrain.pretrained import SepformerSeparation  # SB 0.5.x
# Lazily-created, process-wide model instance (see load_dns4_model).
_DNS4_MODEL = None


def load_dns4_model():
    """Return the shared SepFormer DNS4 enhancement model, creating it on first use.

    The model is built with ``SepformerSeparation.from_hparams`` from the
    ``speechbrain/sepformer-dns4-16k-enhancement`` source, cached under
    ``pretrained_models/`` and placed on ``DEVICE``. Later calls return the
    same instance.
    """
    global _DNS4_MODEL
    if _DNS4_MODEL is None:
        _DNS4_MODEL = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir="pretrained_models/sepformer-dns4-16k-enhancement",
            run_opts={"device": DEVICE},
        )
    return _DNS4_MODEL
# ---------------- Helpers ----------------
def to_mono(wav: np.ndarray) -> np.ndarray:
    """Return ``wav`` as float32 mono.

    A 2-D array is averaged across axis 1 (the channel axis of
    ``soundfile.read`` output); any other array is only cast to float32.
    """
    if wav.ndim != 2:
        return wav.astype(np.float32)
    return wav.mean(axis=1).astype(np.float32)
def resample_np(wav: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Resample a 1-D signal from ``sr_in`` to ``sr_out`` with torchaudio.

    When the rates already match, the input array itself is returned (no copy).
    Otherwise the result is a new float32 NumPy array.
    """
    if sr_in == sr_out:
        return wav
    batched = torch.from_numpy(wav).float()[None, :]  # torchaudio wants (channels, time)
    resampled = torchaudio.functional.resample(batched, sr_in, sr_out)
    return resampled[0].cpu().numpy()
def crossfade_add(dst: np.ndarray, src: np.ndarray, start: int, fade: int):
    """Write ``src`` into ``dst`` at ``start``, linearly blending the first ``fade`` samples.

    Over the blend region the existing ``dst`` content fades out while ``src``
    fades in. After it, ``src`` simply overwrites ``dst``. If ``src`` runs past
    the end of ``dst``, ``dst`` is grown in place with ``ndarray.resize``, so
    ``dst`` must own its data. ``fade <= 0`` means a plain overwrite.
    Returns None; ``dst`` is modified in place.
    """
    stop = start + len(src)
    overflow = stop - len(dst)
    if overflow > 0:
        dst.resize(len(dst) + overflow, refcheck=False)
    if fade <= 0:
        dst[start:stop] = src
        return
    n = min(fade, len(src), len(dst) - start)
    ramp_up = np.linspace(0.0, 1.0, n, dtype=np.float32)
    ramp_down = 1.0 - ramp_up
    head = slice(start, start + n)
    dst[head] = dst[head] * ramp_down + src[:n] * ramp_up
    if start + n < stop:
        dst[start + n:stop] = src[n:]
def trim_silence(wav: np.ndarray, sr: int, thresh_db: float = -55.0, look_sec: float = 0.15) -> np.ndarray:
    """Cut leading and trailing quiet audio, keeping ``look_sec`` of margin on each side.

    Frame RMS is measured over ~50 ms windows every ~20 ms (the input is
    zero-padded so windows are roughly centred). A frame is "active" when its
    level exceeds ``thresh_db`` dBFS. Empty input, or input with no active
    frame, is returned unchanged. Otherwise a slice of ``wav`` is returned.
    """
    if len(wav) == 0:
        return wav
    hop = max(1, int(0.02 * sr))
    win = max(hop, int(0.05 * sr))
    half = (win - hop) // 2
    padded = np.pad(wav, (half, half))
    starts = range(0, len(padded) - win + 1, hop)
    rms = np.array(
        [np.sqrt(np.mean(padded[s:s + win] ** 2) + 1e-9) for s in starts],
        dtype=np.float32,
    )
    loud = 20.0 * np.log10(rms + 1e-9) > thresh_db
    if not loud.any():
        return wav
    first = int(np.argmax(loud))
    last = len(loud) - 1 - int(np.argmax(loud[::-1]))
    lo = int(max(0, first * hop - look_sec * sr))
    hi = int(min(len(wav), (last * hop + win) + look_sec * sr))
    return wav[lo:hi]
# ---- Band utilities
def bandpass_torch(x: torch.Tensor, sr: int, low_hz: float, high_hz: float) -> torch.Tensor:
    """Band-limit ``x`` to [low_hz, high_hz] with a biquad high-pass followed by a biquad low-pass."""
    highpassed = torchaudio.functional.highpass_biquad(x, sr, cutoff_freq=float(low_hz))
    return torchaudio.functional.lowpass_biquad(highpassed, sr, cutoff_freq=float(high_hz))
def speechband_ratio_torch(x: torch.Tensor, sr: int) -> float:
    """Ratio of mean-square energy after a 300-3400 Hz band-pass to that of the input.

    A 1-D tensor is promoted to (1, T) before filtering. Both energies get a
    1e-12 floor so silent input does not divide by zero.
    """
    batched = x.unsqueeze(0) if x.dim() == 1 else x
    in_band = bandpass_torch(batched, sr, 300.0, 3400.0)
    band_energy = float((in_band ** 2).mean().item() + 1e-12)
    total_energy = float((batched ** 2).mean().item() + 1e-12)
    return band_energy / total_energy
# ---------------- Safety helpers for filter freqs ----------------
def _safe_cf(freq: float, sr: int, max_ratio: float = 0.45) -> float:
    """Clamp a filter centre/cutoff frequency into [20 Hz, max_ratio * sr].

    The upper bound keeps biquad frequencies safely below Nyquist (0.5 * sr).
    """
    ceiling = max_ratio * sr
    return float(min(max(freq, 20.0), ceiling))
def _choose_speech_source(est) -> torch.Tensor:
    """Return the separated source that scores highest on speech-band energy, as float32.

    ``est`` may be a list/tuple of per-source signals, or tensor-like data:
      * 3-D: only batch item 0 is used, and that 2-D slice is handled as below.
      * 2-D: the shorter axis is treated as the source axis, so both
        (n_src, time) and (time, n_src) layouts work.
      * 1-D: a single source. Any other rank is flattened into one source.

    With several candidates, each is scored by ``speechband_ratio_torch`` at
    MODEL_SR and the first highest-scoring one wins. A single candidate is
    returned without scoring.

    Raises:
        ValueError: if ``est`` yields no candidate sources.
    """
    if isinstance(est, (list, tuple)):
        cand = [torch.as_tensor(e) for e in est]
    else:
        y = torch.as_tensor(est)
        if y.dim() == 3:
            y = y[0]
        if y.dim() == 2:
            source_axis = 0 if y.shape[0] <= y.shape[1] else 1
            cand = list(y.unbind(source_axis))
        elif y.dim() == 1:
            cand = [y]
        else:
            cand = [y.reshape(-1)]

    if not cand:
        raise ValueError("separator returned no candidate sources")
    if len(cand) == 1:
        # Nothing to choose between: skip the (full-length) band-pass scoring.
        return cand[0].float()

    best_idx, best_score = 0, -1.0
    for i, c in enumerate(cand):
        score = speechband_ratio_torch(c.float().unsqueeze(0), MODEL_SR)
        if score > best_score:
            best_idx, best_score = i, score
    return cand[best_idx].float()
def enhance_chunk_tensor(model: "SepformerSeparation", chunk_tensor_1xT: torch.Tensor) -> torch.Tensor:
    """Enhance one (1, T) chunk with SepFormer and return the speech estimate as a (1, T') CPU tensor.

    The in-memory ``separate_batch`` API is tried first, under CUDA autocast
    when DEVICE is "cuda". If that raises, the chunk is written to a temporary
    16-bit WAV at MODEL_SR and ``separate_file`` is used instead. The
    temporary directory is removed even if that fallback call fails. The
    returned source is picked by ``_choose_speech_source``.
    """
    with torch.no_grad():
        try:
            # torch.cuda.amp.autocast is deprecated in recent PyTorch; use
            # torch.autocast on GPU and no autocast context at all on CPU.
            amp_ctx = (torch.autocast(device_type="cuda")
                       if DEVICE == "cuda" else contextlib.nullcontext())
            with amp_ctx:
                est = model.separate_batch(chunk_tensor_1xT.to(DEVICE))
        except Exception:
            # Fallback: the file-based API. Always clean up the temp dir.
            tmpd = tempfile.mkdtemp(prefix="sf_chunk_")
            try:
                inpath = os.path.join(tmpd, "in.wav")
                sf.write(inpath, chunk_tensor_1xT.squeeze(0).cpu().numpy(), MODEL_SR, subtype="PCM_16")
                est = model.separate_file(path=inpath)
            finally:
                shutil.rmtree(tmpd, ignore_errors=True)
        speech = _choose_speech_source(est).unsqueeze(0)
    return speech.cpu()
def process_in_chunks(model, wav_16k: np.ndarray, chunk_sec: float, overlap_sec: float,
                      time_budget: float = None, progress=None, base=0.0, span=0.5) -> np.ndarray:
    """Enhance a MODEL_SR signal chunk by chunk, crossfading the overlaps.

    Chunks of ``chunk_sec`` start every ``chunk_sec - overlap_sec`` seconds.
    A short final chunk is zero-padded to full length for the model and its
    output trimmed back. Each chunk after the first is blended in over the
    overlap with ``crossfade_add``. The loop stops once a chunk reaches the
    end of the signal.

    Args:
        model: separator passed through to ``enhance_chunk_tensor``.
        wav_16k: 1-D signal at MODEL_SR.
        chunk_sec, overlap_sec: chunk length and overlap in seconds. The
            chunk is at least one sample and the overlap is clamped to be
            shorter than a chunk.
        time_budget: seconds allowed for the chunk loop (checked before each
            chunk), or None for no limit. When exceeded, the unprocessed tail
            is filled with the input signal, crossfaded in after the last
            enhanced chunk.
        progress, base, span: optional progress callback and the sub-range of
            the overall progress bar this call reports into.

    Returns:
        float32 array the length of ``wav_16k``. In the single-chunk case the
        length is whatever ``enhance_chunk_tensor`` returns.
    """
    t_start = time.time()
    chunk = max(1, int(chunk_sec * MODEL_SR))
    overlap = min(max(0, int(overlap_sec * MODEL_SR)), chunk - 1)
    step = chunk - overlap
    n = len(wav_16k)
    if n == 0:
        return np.zeros(0, dtype=np.float32)
    if n <= chunk:
        ten = torch.from_numpy(wav_16k).float().unsqueeze(0)
        out = enhance_chunk_tensor(model, ten).squeeze(0).numpy()
        if progress:
            progress(base + span, "Enhancement (single chunk)")
        return out

    out = np.zeros(n, dtype=np.float32)
    # Chunk starts are 0, step, 2*step, ... up to the first chunk that reaches n.
    total_steps = 1 + (n - chunk + step - 1) // step
    pos = 0
    done_steps = 0
    while pos < n:
        if time_budget is not None and (time.time() - t_start) > time_budget:
            if progress:
                progress(base + span * 0.98, "Time budget reached — returning partial")
            break
        end = min(pos + chunk, n)
        piece = np.zeros(chunk, dtype=np.float32)  # zero-padded to a fixed model input size
        piece[:end - pos] = wav_16k[pos:end]
        ten = torch.from_numpy(piece).unsqueeze(0)
        enhanced = enhance_chunk_tensor(model, ten).squeeze(0).numpy()[:end - pos]
        if done_steps == 0:
            out[pos:end] = enhanced
        else:
            crossfade_add(out, enhanced, pos, overlap)
        done_steps += 1
        if progress:
            progress(base + span * (done_steps / total_steps), f"Enhancement {done_steps}/{total_steps}")
        if end >= n:
            # Fully covered; another start would only re-process the overlap tail.
            pos = n
            break
        pos += step

    if pos < n:
        # Budget hit before the end: use the unprocessed input for the rest.
        tail = np.asarray(wav_16k[pos:], dtype=np.float32)
        if done_steps == 0:
            out[pos:] = tail
        else:
            # out[pos:pos+overlap] still holds the last chunk's enhanced tail;
            # blend from it into the raw signal instead of a hard cut.
            crossfade_add(out, tail, pos, overlap)
    return out
# ---------------- Post filters ----------------
def spectral_gate_np(wav_16k: np.ndarray, strength: float, progress=None, base=0.5, span=0.2) -> np.ndarray:
    """Attenuate stationary background noise with a smoothed Wiener-style STFT mask.

    The noise floor of each frequency bin is the 20th percentile of its STFT
    magnitude over time (n_fft=1024, hop=256, Hann window). The mask is
    ``clamp((S^2 - beta*N^2) / S^2, 0, 1) ** p`` with
    ``beta = 0.8 + 3.2*strength`` and ``p = 1 + 2*strength``. It is smoothed
    with a 5x5 box filter over (frequency, time) before being applied.

    Args:
        wav_16k: 1-D signal.
        strength: suppression strength, expected in [0, 1].
        progress, base, span: optional progress callback and its sub-range.

    Returns:
        float32 array of the same length. Inputs of at most n_fft // 2 samples
        (too short for the centred STFT) are returned unchanged as a float32 copy.
    """
    import torch.nn.functional as F

    if progress:
        progress(base, "Post-filter: estimating noise")
    x = torch.from_numpy(wav_16k).float()
    if x.ndim == 1:
        x = x.unsqueeze(0)
    n_fft, hop = 1024, 256
    if x.shape[-1] <= n_fft // 2:
        # torch.stft(center=True) reflect-pads n_fft // 2 samples per side,
        # which fails on inputs that short.
        if progress:
            progress(base + span * 0.98, "Post-filter: finishing")
        return x.squeeze(0).cpu().numpy().copy()

    win = torch.hann_window(n_fft)
    X = torch.stft(x, n_fft=n_fft, hop_length=hop, window=win, return_complex=True)
    mag = X.abs()
    # np.quantile's default "linear" method matches torch.quantile's default,
    # but torch.quantile raises "input tensor is too large" above a fixed
    # element count, which (bins x frames) of a long recording can exceed.
    noise = torch.from_numpy(
        np.quantile(mag.numpy(), 0.2, axis=-1, keepdims=True).astype(np.float32)
    )
    S2, N2 = mag ** 2, noise ** 2
    beta = 0.8 + 3.2 * strength
    p = 1.0 + 2.0 * strength
    wiener = torch.clamp((S2 - beta * N2) / (S2 + 1e-8), 0.0, 1.0) ** p
    # 5x5 box smoothing of the mask over (frequency, time).
    wiener_s = F.avg_pool2d(wiener.unsqueeze(1), kernel_size=(5, 5), stride=1, padding=2).squeeze(1)
    if progress:
        progress(base + span * 0.5, "Post-filter: applying mask")
    Y = wiener_s * X
    y = torch.istft(Y, n_fft=n_fft, hop_length=hop, window=win, length=x.shape[-1])
    if progress:
        progress(base + span * 0.98, "Post-filter: finishing")
    return y.squeeze(0).cpu().numpy()
def strict_background_kill(wav_16k: np.ndarray, level: float, progress=None, base=0.7, span=0.2) -> np.ndarray:
    """Extra-aggressive background removal, scaled by ``level``.

    ``level <= 0`` returns the input object untouched. Otherwise the signal
    goes through these stages in order:
      1. an 80 Hz biquad high-pass;
      2. a biquad low-pass whose cutoff falls from 0.45*sr to 0.35*sr as the
         clipped level goes from 0 to 1;
      3. ``spectral_gate_np`` with strength ``0.55 + 0.45*level`` clipped to [0, 1].
    """
    if level <= 0.0:
        return wav_16k
    sr = MODEL_SR
    signal = torch.from_numpy(wav_16k).float().unsqueeze(0)

    if progress:
        progress(base + span * 0.05, "Strict: high-pass")
    signal = torchaudio.functional.highpass_biquad(signal, sr, cutoff_freq=_safe_cf(80.0, sr))

    clipped_level = float(np.clip(level, 0.0, 1.0))
    lp_cut = _safe_cf((0.45 - 0.10 * clipped_level) * sr, sr)
    if progress:
        progress(base + span * 0.15, f"Strict: low-pass ~{int(lp_cut)} Hz")
    signal = torchaudio.functional.lowpass_biquad(signal, sr, cutoff_freq=lp_cut)

    if progress:
        progress(base + span * 0.35, "Strict: spectral gate")
    gate_strength = float(np.clip(0.55 + 0.45 * level, 0.0, 1.0))
    return spectral_gate_np(
        signal.squeeze(0).cpu().numpy(),
        strength=gate_strength,
        progress=progress,
        base=base + span * 0.35,
        span=span * 0.6,
    )
# ---------------- Voice Enhance (clarity & loudness) ----------------
def _eq_chain(x: torch.Tensor, sr: int, amt: float) -> torch.Tensor:
    """Apply a three-band clarity EQ, with gains scaled by ``amt`` (0..1).

    Peaking bands, applied in this order:
      * body     ~180 Hz,   Q=0.8, +1.0..+2.5 dB
      * presence ~3.2 kHz,  Q=0.9, +2.0..+4.0 dB
      * air      ~0.40*sr,  Q=0.7, +0.5..+1.5 dB (6.4 kHz at 16 kHz)
    Centre frequencies go through ``_safe_cf`` (<= 0.45*sr). A 1-D input is
    promoted to (1, T), and the result keeps that 2-D shape.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    bands = (
        # (centre Hz, gain dB, Q)
        (_safe_cf(180.0, sr), 1.0 + 1.5 * amt, 0.8),
        (_safe_cf(3200.0, sr), 2.0 + 2.0 * amt, 0.9),
        (_safe_cf(0.40 * sr, sr), 0.5 + 1.0 * amt, 0.7),
    )
    shaped = x
    for centre, gain_db, q in bands:
        shaped = torchaudio.functional.equalizer_biquad(shaped, sr, center_freq=centre, gain=gain_db, Q=q)
    return shaped
def _soft_compressor_np(wav: np.ndarray, thresh_db=-20.0, ratio=2.0, attack_ms=8.0, release_ms=120.0, makeup_db=2.0, sr=16000) -> np.ndarray:
    """Feed-forward compressor with makeup gain and a final peak guard.

    A one-pole peak envelope follower (``attack_ms`` time constant while the
    level rises, ``release_ms`` while it falls) drives a static curve. At or
    below the threshold the gain is 1. Above it the output level is
    ``thr * (env / thr) ** (1 / ratio)``, i.e. a conventional ratio:1
    compressor. ``makeup_db`` is then applied, and if the peak exceeds 0.999
    the whole signal is scaled down to peak at 0.999.

    Args:
        wav: 1-D signal.
        thresh_db, ratio, attack_ms, release_ms, makeup_db: compressor settings.
        sr: sample rate of ``wav`` (used for the time constants).

    Returns:
        float32 array of the same length (empty input gives an empty array).
    """
    x = wav.astype(np.float32)
    if x.size == 0:
        return x
    atk = float(np.exp(-1.0 / (sr * (attack_ms / 1000.0))))
    rel = float(np.exp(-1.0 / (sr * (release_ms / 1000.0))))
    thr = 10 ** (thresh_db / 20.0)

    # The envelope recursion is inherently sequential. Iterating Python floats
    # (tolist) is far faster than iterating NumPy scalars.
    env_trace = np.empty(x.shape[0], dtype=np.float64)
    env = 0.0
    for i, a in enumerate(np.abs(x).tolist()):
        coef = atk if a > env else rel
        env = coef * env + (1.0 - coef) * a
        env_trace[i] = env

    # Static curve, vectorised. The gain must be (env/thr)**(1/ratio - 1). The
    # previous form, thr*(env/thr)**(1-1/ratio)/env = (env/thr)**(-1/ratio),
    # only agreed at ratio == 2 and compressed *less* as ratio grew.
    gain = np.ones(x.shape[0], dtype=np.float32)
    loud = (env_trace > 1e-9) & (env_trace > thr)
    gain[loud] = (env_trace[loud] / thr) ** (1.0 / ratio - 1.0)

    y = x * gain
    y = y * (10 ** (makeup_db / 20.0))
    peak = np.max(np.abs(y)) + 1e-9
    if peak > 0.999:
        y = y / peak * 0.999
    return y.astype(np.float32)
def voice_enhance(wav_16k: np.ndarray, amount: float, deess: float = 0.20) -> np.ndarray:
    """Clarity/loudness polish: clarity EQ, optional de-ess cut, then compression.

    Args:
        wav_16k: 1-D signal at MODEL_SR.
        amount: 0..1. It is clipped before scaling the ``_eq_chain`` gains, and
            used as given for the compressor threshold (-22 + 4*amount dB),
            ratio (1.6 + 0.7*amount) and makeup gain (1.5 + 1.5*amount dB).
        deess: 0..1. When > 0, applies a peaking cut of -12*deess dB (Q=1.2)
            centred at 0.35*MODEL_SR.

    Returns:
        The output of ``_soft_compressor_np`` for the equalised signal.
    """
    sr = MODEL_SR
    eq_amount = float(np.clip(amount, 0.0, 1.0))
    shaped = _eq_chain(torch.from_numpy(wav_16k).float().unsqueeze(0), sr, amt=eq_amount)
    if deess > 0:
        cut_db = -2.0 * float(deess) * 6.0  # reaches -12 dB at deess=1
        shaped = torchaudio.functional.equalizer_biquad(
            shaped, sr, center_freq=_safe_cf(0.35 * sr, sr), gain=cut_db, Q=1.2
        )
    return _soft_compressor_np(
        shaped.squeeze(0).cpu().numpy(),
        thresh_db=-22.0 + 4.0 * amount,
        ratio=1.6 + 0.7 * amount,
        makeup_db=1.5 + 1.5 * amount,
        sr=sr,
    )
# ---------------- Resemble Enhance (optional) ----------------
def _find_resemble_cmd():
    """Return the argv prefix for the Resemble Enhance CLI, or None if it is unavailable.

    A ``resemble_enhance`` / ``resemble-enhance`` console script on PATH is
    preferred. Otherwise the module is run with the current interpreter, but
    only if the ``resemble_enhance`` package is importable. Previously that
    fallback was returned unconditionally, so a missing install surfaced as
    an opaque non-zero exit instead of "not found".
    """
    for script in ("resemble_enhance", "resemble-enhance"):
        if shutil.which(script) is not None:
            return [script]
    # NOTE(review): confirm "resemble_enhance.enhance" is the package's runnable module path.
    if importlib.util.find_spec("resemble_enhance") is not None:
        return [sys.executable, "-m", "resemble_enhance.enhance"]
    return None
def run_resemble_enhance(audio: np.ndarray, in_sr: int, denoise_only: bool,
                         timeout_sec: float, progress=None, base=0.0, span=0.2) -> np.ndarray:
    """Run the Resemble Enhance CLI on ``audio`` and return the mono result at ``in_sr``.

    Steps:
      1. Resample the audio to RESEMBLE_SR.
      2. Write it as 16-bit ``in.wav`` into a temp input directory.
      3. Run the CLI on that directory into a temp output directory
         (``--denoise_only`` when requested).
      4. Read the output back as mono and resample it to ``in_sr``.

    Both temp directories are always removed. The child process is killed if
    it is still running when this function exits for any reason.

    Raises:
        RuntimeError: CLI not found, non-zero exit status, or no audio produced.
        TimeoutError: the CLI ran longer than ``timeout_sec``.
    """
    cmd = _find_resemble_cmd()
    if cmd is None:
        raise RuntimeError("Resemble Enhance CLI not found. Is the package installed?")
    if progress:
        progress(base, "Resemble: resampling to 44.1 kHz")
    audio44 = resample_np(audio, in_sr, RESEMBLE_SR)
    in_dir = tempfile.mkdtemp(prefix="resem_in_")
    out_dir = None
    try:
        out_dir = tempfile.mkdtemp(prefix="resem_out_")
        in_path = os.path.join(in_dir, "in.wav")
        sf.write(in_path, audio44, RESEMBLE_SR, subtype="PCM_16")
        full_cmd = cmd + [in_dir, out_dir] + (["--denoise_only"] if denoise_only else [])
        if progress:
            progress(base + span * 0.05, "Resemble: enhancing…")
        # Discard the CLI's output. Reading it line by line blocked the
        # timeout check, and a slowly drained pipe can stall the child.
        proc = subprocess.Popen(full_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        t0 = time.time()
        try:
            while True:
                try:
                    proc.wait(timeout=0.1)
                    break
                except subprocess.TimeoutExpired:
                    pass
                elapsed = time.time() - t0
                # Cosmetic progress: a sawtooth that restarts every 5 s.
                frac = min(0.95, (elapsed % 5.0) / 5.0)
                if progress:
                    progress(base + span * (0.05 + 0.9 * frac), "Resemble: enhancing…")
                if elapsed > timeout_sec:
                    raise TimeoutError(f"Resemble Enhance exceeded {timeout_sec}s")
        finally:
            if proc.poll() is None:
                proc.kill()
                proc.wait()  # reap the child
        if proc.returncode != 0:
            raise RuntimeError(f"Resemble Enhance exited with code {proc.returncode}")

        out_path = os.path.join(out_dir, "in.wav")
        if not os.path.exists(out_path):
            cand = sorted(
                f for f in os.listdir(out_dir)
                if f.lower().endswith((".wav", ".flac", ".mp3", ".m4a", ".ogg"))
            )
            if not cand:
                raise RuntimeError("Resemble Enhance did not produce an audio file.")
            out_path = os.path.join(out_dir, cand[0])
        if progress:
            progress(base + span * 0.98, "Resemble: reading result")
        out_audio, out_sr = sf.read(out_path, dtype="float32", always_2d=False)
        return resample_np(to_mono(out_audio), out_sr, in_sr)
    finally:
        shutil.rmtree(in_dir, ignore_errors=True)
        if out_dir is not None:
            shutil.rmtree(out_dir, ignore_errors=True)
def ensure_runtime(target_seconds: float, start_time: float, progress=None, base=0.8, span=0.2):
    """Sleep until ``target_seconds`` have passed since ``start_time``, ticking progress.

    Does nothing when no progress callback is given or the target has already
    been reached. Otherwise the remaining time is split into at least 5 equal
    sleeps (about one per 0.25 s). After each sleep the progress advances
    through ``[base, base + span]`` with the label "Final polish".
    """
    remaining = max(0.0, target_seconds - (time.time() - start_time))
    if progress is None or remaining <= 0:
        return
    steps = max(5, int(remaining / 0.25))
    pause = remaining / steps
    for step in range(1, steps + 1):
        time.sleep(pause)
        progress(base + span * step / steps, "Final polish")
# ---------------- Main pipeline ----------------
def _fit_length(x: np.ndarray, n: int) -> np.ndarray:
    """Return 1-D ``x`` trimmed or zero-padded to exactly ``n`` samples."""
    if len(x) >= n:
        return x[:n]
    return np.pad(x, (0, n - len(x)))


def denoise(
    file_path: str,
    model_label: str = "SepFormer (DNS4, 16 kHz)",
    processing_mode: str = "Standard (1 pass + post-filter)",
    chunk_seconds: float = 8.0,
    overlap_seconds: float = 0.2,
    wet_percent: int = 100,
    post_strength: int = 55,
    strictness: int = 65,
    voice_enhance_on: bool = True,
    voice_enhance_amount: int = 55,  # gentler default
    use_resemble: bool = False,
    resemble_mode: str = "Enhance (denoise+enhance)",
    output_sr_choice: str = "Original",
    target_seconds: int = 18,
    progress=gr.Progress(track_tqdm=False),
):
    """Gradio handler: denoise and polish one audio file.

    Pipeline:
      1. Load the file, convert to mono and trim leading/trailing silence.
      2. Resample to MODEL_SR and run the SepFormer passes (on CPU these are
         capped by CPU_TIME_BUDGET).
      3. Apply the spectral-gate post-filter(s), then optional strict
         background kill, Voice Enhance and Resemble Enhance.
      4. Wet/dry mix against the trimmed input and resample to the chosen
         output rate with a peak guard.
      5. Pad the runtime up to ``target_seconds`` and write a 16-bit WAV to a
         temp file.

    Percent-style arguments (wet, post strength, strictness, voice enhance
    amount) are on a 0..100 scale and clipped to [0, 1] internally.

    Returns:
        ``(enhanced_wav_path, file_path, status_text)``, or
        ``(None, None, message)`` when no file is given or it has no samples.
    """
    start_t = time.time()
    budget = GPU_TIME_BUDGET if DEVICE == "cuda" else CPU_TIME_BUDGET
    progress(0.02, "Loading file")
    if not file_path:
        return None, None, "No file provided."
    wav, sr = sf.read(file_path, dtype="float32", always_2d=False)
    wav = to_mono(wav)
    if len(wav) == 0:
        return None, None, "Input audio is empty."
    dur_sec = len(wav) / float(sr)
    progress(0.05, "Trimming silence")
    wav = trim_silence(wav, sr)
    progress(0.07, "Resampling")
    wav_16k = resample_np(wav, sr, MODEL_SR)

    # UI mode -> (SepFormer passes, spectral-gate repetitions); CPU gets fewer passes.
    mode = processing_mode.lower()
    cpu = (DEVICE == "cpu")
    if mode.startswith("standard"):
        passes, do_gate_times = 1, 1
    elif mode.startswith("pro"):
        passes, do_gate_times = (1 if cpu else 2), 1
    else:
        passes, do_gate_times = (2 if cpu else 3), 2
    if (not cpu) and strictness >= 80 and passes < 3:
        passes += 1

    # Split the progress range [0.08, 0.80] evenly across the enabled stages.
    extra_strict_stage = 1 if strictness > 0 else 0
    enhance_stage = 1 if voice_enhance_on else 0
    resemble_stage = 1 if use_resemble else 0
    total_slices = passes + do_gate_times + extra_strict_stage + enhance_stage + resemble_stage + 1
    slice_span = 0.72 / max(1, total_slices)
    cur_base = 0.08

    audio = wav_16k
    model = load_dns4_model()
    for idx in range(passes):
        progress(cur_base, f"SepFormer pass {idx+1}/{passes}")
        pass_budget = None
        if cpu:
            # Share what is left of the CPU budget across this pass, the later
            # passes and Resemble (if enabled). Previously this was computed
            # as None on CPU and dropped on GPU, so no budget ever applied.
            heavy_left = (passes - idx) + (1 if use_resemble else 0)
            remain_budget = max(0.0, budget - (time.time() - start_t))
            pass_budget = remain_budget / (heavy_left + 0.5)
        audio = process_in_chunks(model, audio, float(chunk_seconds), float(overlap_seconds),
                                  time_budget=pass_budget,
                                  progress=progress, base=cur_base, span=slice_span)
        cur_base += slice_span

    strength = float(np.clip(post_strength / 100.0, 0.0, 1.0))
    for j in range(do_gate_times):
        progress(cur_base, f"Post-filter {j+1}/{do_gate_times}")
        audio = spectral_gate_np(audio, strength=strength, progress=progress, base=cur_base, span=slice_span)
        cur_base += slice_span

    if strictness > 0:
        progress(cur_base, "Strict background kill")
        audio = strict_background_kill(audio, level=float(np.clip(strictness / 100.0, 0.0, 1.0)),
                                       progress=progress, base=cur_base, span=slice_span)
        cur_base += slice_span

    if voice_enhance_on:
        progress(cur_base, "Voice Enhance (clarity)")
        ve_amt = float(np.clip(voice_enhance_amount / 100.0, 0.0, 1.0))
        audio = voice_enhance(audio, amount=ve_amt, deess=0.20)
        cur_base += slice_span

    if use_resemble:
        denoise_only = resemble_mode.lower().startswith("denoise")
        progress(cur_base, "Resemble Enhance: preparing")
        try:
            remain_budget = max(0.0, budget - (time.time() - start_t))
            r_timeout = min(RESEMBLE_TIMEOUT, remain_budget * 0.8) if cpu else RESEMBLE_TIMEOUT
            audio = run_resemble_enhance(audio, MODEL_SR, denoise_only=denoise_only,
                                         timeout_sec=r_timeout,
                                         progress=progress, base=cur_base, span=slice_span)
        except Exception as e:
            # Resemble is optional: report why it was skipped and keep the current audio.
            progress(cur_base + slice_span * 0.99, f"Resemble skipped: {e}")
        cur_base += slice_span

    progress(cur_base, "Mixing & resampling")
    wet = float(np.clip(wet_percent / 100.0, 0.0, 1.0))
    # Resemble's 16k -> 44.1k -> 16k round trip (or its own output) can differ
    # in length by a few samples. Align with the dry signal so the mix broadcasts.
    audio = _fit_length(np.asarray(audio, dtype=np.float32), len(wav_16k))
    mixed_16k = wet * audio + (1.0 - wet) * wav_16k
    out_sr = sr if output_sr_choice == "Original" else MODEL_SR
    out_wav = resample_np(mixed_16k, MODEL_SR, out_sr)
    peak = float(np.max(np.abs(out_wav)) + 1e-9)
    if peak > 0.999:
        out_wav = out_wav / peak * 0.999

    ensure_runtime(float(target_seconds), start_t, progress=progress, base=cur_base, span=min(0.2, 1.0 - cur_base))
    progress(0.99, "Saving")
    # mkstemp returns an open descriptor; close it (sf.write reopens the path by name).
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(out_path, out_wav, out_sr, subtype="PCM_16")
    status = (
        f"Preset: Crisp & Fast (SAFE) | Model: {model_label} | Passes: {passes}+post({do_gate_times})"
        f"{' + Strict' if strictness>0 else ''}{' + VoiceEnh' if voice_enhance_on else ''}"
        f"{' + Resemble' if use_resemble else ''} | Input: {sr} Hz, {dur_sec:.1f}s | "
        f"Output: {out_sr} Hz | Wet: {wet_percent}% | Post: {post_strength}% | "
        f"Strict: {strictness}% | VoiceEnh: {voice_enhance_amount}% | Device: {DEVICE.upper()}"
    )
    progress(1.0, "Done")
    return out_path, file_path, status
# ---------------- Presets (auto crisp on load) ----------------
def apply_preset(preset: str):
    """Map a preset label to the 12 control values wired to ``preset.change``.

    The order is: processing mode, chunk s, overlap s, wet %, post strength,
    strictness, voice-enhance on, voice-enhance amount, use Resemble, Resemble
    mode, output SR, target runtime. CPU runs get lighter settings. Any label
    that is neither "Crisp & Fast…" nor "Broadcast Clean…" maps to Max Clean.
    """
    on_cpu = DEVICE == "cpu"

    def pick(cpu_value, gpu_value):
        return cpu_value if on_cpu else gpu_value

    resemble_default = "Enhance (denoise+enhance)"
    if preset.startswith("Crisp & Fast"):
        return ("Standard (1 pass + post-filter)", pick(8.0, 10.0), pick(0.2, 0.3),
                100, 55, 65, True, 55,  # voice amount: gentler default
                False, resemble_default, "Original", pick(16, 18))
    if preset.startswith("Broadcast Clean"):
        return (pick("Standard (1 pass + post-filter)", "Pro (dual-pass + post-filter)"),
                pick(10.0, 12.0), pick(0.25, 0.4),
                100, 58, 70, True, 60,
                False, resemble_default, "Original", pick(20, 22))
    # "Max Clean (slow)"
    return (pick("Pro (dual-pass + post-filter)", "Ultra (triple-pass + extra post)"),
            10.0, 0.5,
            100, 62, 80, True, 60,
            pick(False, True), resemble_default, "16 kHz", pick(26, 30))
# ---------------- Gradio UI ----------------
# Page heading and subtitle shown at the top of the Gradio app.
TITLE = "AI Background Noise Remover — SepFormer (DNS4) • Safe & Crisp Defaults"
DESC = " ".join((
    "Safe defaults to avoid harsh highs. SepFormer (DNS4) + smart post-filter + Voice Enhance.",
    "Use presets to trade speed vs. aggressiveness.",
))
# Gradio UI. Component variables are wired positionally: preset.change writes
# apply_preset's 12-tuple into its outputs, and run_btn.click passes its inputs
# to denoise in parameter order.
# NOTE(review): the source's indentation was lost; the row nesting below
# (notably target_secs and status sitting outside the rows) is reconstructed.
# It only affects layout, not behaviour.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as app:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)
    with gr.Row():
        # Upload or record; handed to denoise as a file path.
        audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Noisy Audio")
    with gr.Row():
        preset = gr.Dropdown(
            ["Crisp & Fast (default)", "Broadcast Clean (balanced)", "Max Clean (slow)"],
            value="Crisp & Fast (default)",
            label="Preset"
        )
        # Single model option; denoise only echoes this label into its status text.
        model_dd = gr.Dropdown(
            ["SepFormer (DNS4, 16 kHz)"],
            value="SepFormer (DNS4, 16 kHz)",
            label="Enhancement Model"
        )
        # denoise branches on the lower-cased prefix ("standard" / "pro" / anything else).
        mode_dd = gr.Radio(
            ["Standard (1 pass + post-filter)", "Pro (dual-pass + post-filter)", "Ultra (triple-pass + extra post)"],
            value="Standard (1 pass + post-filter)",
            label="Processing Mode"
        )
    with gr.Row():
        chunk = gr.Slider(5, 30, value=8, step=1, label="Chunk Length (s)")
        overlap = gr.Slider(0.0, 3.0, value=0.2, step=0.05, label="Overlap Crossfade (s)")
        out_sr = gr.Radio(["Original", "16 kHz"], value="Original", label="Output Sample Rate")
    with gr.Row():
        # Percent sliders; denoise divides these by 100 and clips to [0, 1].
        wet = gr.Slider(0, 100, value=100, step=1, label="Enhancement Mix (Wet %)")
        post_strength = gr.Slider(0, 100, value=55, step=1, label="Post Polish Strength")
        strictness = gr.Slider(0, 100, value=65, step=1, label="Strictness (extra background kill)")
    with gr.Row():
        voice_enhance_on = gr.Checkbox(True, label="Voice Enhance (clarity/loudness)")
        voice_enhance_amount = gr.Slider(0, 100, value=55, step=1, label="Voice Enhance Amount")
    with gr.Row():
        use_resemble = gr.Checkbox(False, label="Resemble Enhance (optional, slower)")
        resemble_mode = gr.Radio(
            ["Enhance (denoise+enhance)", "Denoise only"],
            value="Enhance (denoise+enhance)",
            label="Resemble Mode"
        )
    # Minimum wall-clock time per run; denoise sleeps (ensure_runtime) to reach it.
    target_secs = gr.Slider(5, 60, value=18, step=1, label="Target Runtime (s)")
    # The outputs order must match apply_preset's return tuple.
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[mode_dd, chunk, overlap, wet, post_strength, strictness, voice_enhance_on, voice_enhance_amount,
                 use_resemble, resemble_mode, out_sr, target_secs],
        api_name="preset"
    )
    run_btn = gr.Button("Enhance Audio", variant="primary")
    with gr.Row():
        enhanced_out = gr.Audio(type="filepath", label="Enhanced (Download)")
        original_out = gr.Audio(type="filepath", label="Original")
    status = gr.Markdown("Ready.")
    # The inputs are positional and must follow denoise's parameter order.
    run_btn.click(
        fn=denoise,
        inputs=[audio_in, model_dd, mode_dd, chunk, overlap, wet, post_strength, strictness,
                voice_enhance_on, voice_enhance_amount, use_resemble, resemble_mode, out_sr, target_secs],
        outputs=[enhanced_out, original_out, status],
        api_name="enhance",
    )
# Process one job at a time with a bounded queue. The keyword names differ
# across Gradio major versions. Gradio 4+ rejects `concurrency_count` with an
# error that is not a TypeError, so try the newer API first.
try:
    # Gradio 4+/5: per-event concurrency default.
    app.queue(max_size=16, default_concurrency_limit=1)
except TypeError:
    try:
        # Gradio 3.x signature.
        app.queue(concurrency_count=1, max_size=16, status_update_rate=1)
    except TypeError:
        app.queue()

if __name__ == "__main__":
    app.launch()