karesaeedff committed
Commit 621b172 · verified · 1 Parent(s): 3cf7df4

Update app.py

Files changed (1)
  1. app.py +53 -32
app.py CHANGED
@@ -2,49 +2,63 @@ import gradio as gr
 import librosa
 import numpy as np
 import torch
-from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, AutoProcessor
+from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
 import tempfile
 import soundfile as sf
 import json

+# === Parameters ===
 SAMPLE_RATE = 16000
-CHUNK_SIZE = 60
-STEP = 10
+CHUNK_SIZE = 60  # the model expects 60-second inputs
+STEP = 10  # sliding-window step (seconds)
 MUSIC_THRESHOLD = 0.5
 VOICE_THRESHOLD = 0.3
-MIN_SEG_DURATION = 8
+MIN_SEG_DURATION = 8  # minimum singing-segment length (seconds)

-# === Revised version ===
+# === Model loading ===
+print("Loading models...")
+
+# 🎵 Music-detection model (AST architecture)
 music_model_id = "AI-Music-Detection/ai_music_detection_large_60s"
 music_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
 music_model = AutoModelForAudioClassification.from_pretrained(music_model_id)

+# 🗣️ Voice-activity detection model (HuBERT)
 voice_model_id = "superb/hubert-large-superb-sid"
-voice_processor = AutoProcessor.from_pretrained(voice_model_id)
+voice_extractor = AutoFeatureExtractor.from_pretrained(voice_model_id)
 voice_model = AutoModelForAudioClassification.from_pretrained(voice_model_id)

+print("✅ Models loaded successfully.")
+
+
+# === Inference functions ===
 def predict_music_score(wav):
+    """Predict the probability that a chunk is music."""
     wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
-    inputs = music_processor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
+    inputs = music_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
     with torch.no_grad():
         outputs = music_model(**inputs)
-    scores = torch.softmax(outputs.logits, dim=-1).squeeze()
-    music_score = float(scores[1]) if scores.numel() > 1 else float(scores[0])
-    return music_score
+    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
+    score = float(probs[-1]) if probs.numel() > 1 else float(probs[0])
+    return score
+

 def predict_voice_score(wav):
+    """Predict the probability that a chunk contains voice."""
     wav = librosa.util.fix_length(wav, size=SAMPLE_RATE * CHUNK_SIZE)
-    inputs = voice_processor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
+    inputs = voice_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
     with torch.no_grad():
         outputs = voice_model(**inputs)
-    scores = torch.softmax(outputs.logits, dim=-1).squeeze()
-    voice_score = float(scores.mean())  # simple average
-    return voice_score
+    probs = torch.softmax(outputs.logits, dim=-1).squeeze()
+    score = float(probs.mean())  # average over class probabilities
+    return score
+

 def detect_singing(audio_path):
-    wav, sr = librosa.load(audio_path, sr=SAMPLE_RATE)
+    """Detect singing segments."""
+    wav, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
     duration = len(wav) / SAMPLE_RATE
-    results = []
+    raw_segments = []

     for start in np.arange(0, max(0, duration - CHUNK_SIZE), STEP):
         end = start + CHUNK_SIZE
@@ -54,11 +68,11 @@ def detect_singing(audio_path):
         voice_score = predict_voice_score(snippet)

         if music_score > MUSIC_THRESHOLD and voice_score > VOICE_THRESHOLD:
-            results.append((float(start), float(end)))
+            raw_segments.append((float(start), float(end)))

-    # merge consecutive windows
+    # === Merge consecutive windows ===
     merged = []
-    for seg in results:
+    for seg in raw_segments:
         if not merged or seg[0] > merged[-1][1]:
             merged.append(list(seg))
         else:
@@ -67,12 +81,13 @@
     return merged


-def analyze_audio(file):
-    if file is None:
-        return "Please upload an audio file", None
+# === Main inference function ===
+def analyze_audio(file_path):
+    if file_path is None:
+        return "⚠️ Please upload an audio file", None

     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-        data, sr = librosa.load(file, sr=SAMPLE_RATE)
+        data, sr = librosa.load(file_path, sr=SAMPLE_RATE)
         sf.write(tmp.name, data, sr)
         segments = detect_singing(tmp.name)

@@ -86,16 +101,22 @@
     return f"Detected {len(segments)} singing segments", json_output


-with gr.Blocks(title="🎵 Singing Segment Detector (Plan A)") as demo:
+# === Gradio UI ===
+with gr.Blocks(title="🎵 Singing Segment Detector (Final)") as demo:
     gr.Markdown(
-        "# 🎤 High-accuracy singing-segment detection\n"
-        "Uses the `AI-Music-Detection/ai_music_detection_large_60s` model.\n"
-        "Splits the video's audio into chunks for analysis (60 s inputs) and outputs singing timestamps as JSON."
+        """
+        # 🎤 Automatic Singing-Segment Detector (AI-Music + HuBERT)
+        - Automatically detects singing time ranges in a video
+        - Fuses two models: `AI-Music-Detection/ai_music_detection_large_60s` + `HuBERT`
+        - Outputs each segment's start time, end time, and duration
+        """
     )
-    audio_in = gr.Audio(type="filepath", label="Upload an audio file (extracted from video)")
-    btn = gr.Button("Start analysis")
-    status = gr.Textbox(label="Analysis status", interactive=False)
-    json_out = gr.Code(label="Singing-segment timestamps (JSON)", language="json")
-    btn.click(fn=analyze_audio, inputs=[audio_in], outputs=[status, json_out])
+
+    audio_input = gr.Audio(type="filepath", label="Upload audio (extracted from video)")
+    run_btn = gr.Button("🚀 Start analysis")
+    status_box = gr.Textbox(label="Analysis status", interactive=False)
+    json_output = gr.Code(label="Singing-segment timestamps (JSON)", language="json")
+
+    run_btn.click(fn=analyze_audio, inputs=[audio_input], outputs=[status_box, json_output])

 demo.launch()

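Review note: predict_voice_score gates on the mean of a softmax, but the mean of any softmax distribution is exactly 1/num_labels regardless of the input. superb/hubert-large-superb-sid is a speaker-identification head (SUPERB SID, VoxCeleb1, 1,251 speakers), so voice_score is a constant of roughly 0.0008 and can never clear VOICE_THRESHOLD = 0.3; as committed, no window passes the joint test. A minimal repair sketch that keeps the same checkpoint and uses the top-class confidence as a rough voice-presence proxy (the proxy is an assumption, not what the model was trained for; app.py's voice_extractor, voice_model, and SAMPLE_RATE globals are assumed):

import torch

def predict_voice_score(wav):
    # Confidence of the most likely speaker as a crude "voice present" signal.
    inputs = voice_extractor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = voice_model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).squeeze()
    return float(probs.max())

A dedicated voice-activity or singing-voice model would fit better than a speaker-ID head, but the one-line change above at least makes the threshold meaningful.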
 
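Review note: predict_music_score now reads probs[-1] (previously probs[1]), which hard-codes the assumption that the positive class sits at the last index of this checkpoint. The index is knowable from the config, so it is safer to resolve it once at startup. A sketch, assuming app.py's music_model global; the "music" substring match is a guess, so print id2label once to confirm the real label names:

# Resolve the positive-class index from the checkpoint instead of guessing.
id2label = music_model.config.id2label
print(id2label)  # check the actual label names once
music_index = next(
    (i for i, name in id2label.items() if "music" in name.lower()),
    len(id2label) - 1,  # fall back to the commit's assumption: last index
)

predict_music_score would then return float(probs[music_index]).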
 
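Review note: the window generator np.arange(0, max(0, duration - CHUNK_SIZE), STEP) yields no start positions at all for clips of CHUNK_SIZE (60 s) or shorter, and because the stop value is exclusive, the last window of longer files ends before the clip does, leaving the final seconds unanalyzed. A sketch that always emits at least one window and covers the tail; librosa.util.fix_length inside the predict functions already zero-pads short snippets, and app.py's constants are assumed:

import numpy as np

# One extra STEP past (duration - CHUNK_SIZE) guarantees a window at 0 for
# short clips and a final window that reaches the end of longer ones.
starts = np.arange(0, max(duration - CHUNK_SIZE, 0) + STEP, STEP)
for start in starts:
    end = min(start + CHUNK_SIZE, duration)
    snippet = wav[int(start * SAMPLE_RATE):int(end * SAMPLE_RATE)]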
 
 
+ audio_input = gr.Audio(type="filepath", label="上传音频(从视频提取)")
116
+ run_btn = gr.Button("🚀 开始分析")
117
+ status_box = gr.Textbox(label="分析状态", interactive=False)
118
+ json_output = gr.Code(label="唱歌片段时间戳(JSON)", language="json")
119
+
120
+ run_btn.click(fn=analyze_audio, inputs=[audio_input], outputs=[status_box, json_output])
121
 
122
  demo.launch()
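
Review note: both models stay on CPU as loaded, and every 60-second AST window is a fairly heavy forward pass, so long uploads will be slow. When a GPU is available, moving the models once at startup (and the inputs per call) is a cheap win. A sketch, assuming app.py's globals:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
music_model.to(device)
voice_model.to(device)

# Inside each predict function, after building `inputs`:
# inputs = {k: v.to(device) for k, v in inputs.items()}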