VeuReu committed on
Commit 287f01b · verified · 1 Parent(s): 467c0ec

Upload 17 files

Files changed (7)
  1. README.md +3 -3
  2. api.py +98 -0
  3. audio_tools.py +468 -746
  4. background_descriptor.py +116 -10
  5. config.yaml +60 -64
  6. llm_router.py +5 -11
  7. scripts/remote_clients.py +78 -0
README.md CHANGED
@@ -1,14 +1,14 @@
  ---
- title: Veureu Engine (Docker)
  emoji: 🎧
  colorFrom: gray
  colorTo: blue
  sdk: docker
- app_file: main_api.py
  pinned: false
  ---

- # Veureu Engine API (FastAPI via Docker)

  Endpoints:
  - `POST /process_video`

  ---
+ title: veureu-engine
  emoji: 🎧
  colorFrom: gray
  colorTo: blue
  sdk: docker
+ app_file: api.py
  pinned: false
  ---

+ # veureu-engine

  Endpoints:
  - `POST /process_video`
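For illustration, a minimal client-side sketch of the `POST /process_video` endpoint listed above. The base URL and the sample file name are placeholders, and the form fields mirror the defaults declared in api.py below.

```python
# Hypothetical client call; BASE_URL and sample.mp4 are placeholders.
import requests

BASE_URL = "http://localhost:7860"  # assumed local run on the port used by uvicorn in api.py

with open("sample.mp4", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/process_video",
        files={"video_file": ("sample.mp4", f, "video/mp4")},
        data={"config_path": "config.yaml", "out_root": "results", "db_dir": "chroma_db"},
        timeout=3600,  # matches api.sync_timeout_seconds in config.yaml
    )
print(resp.json())
```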
api.py ADDED
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+ from fastapi import FastAPI, UploadFile, File, Form
3
+ from fastapi.responses import JSONResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pathlib import Path
6
+ import shutil
7
+ import uvicorn
8
+ import json
9
+
10
+ from video_processing import process_video_pipeline
11
+ from casting_loader import ensure_chroma, build_faces_index, build_voices_index
12
+ from narration_system import NarrationSystem
13
+ from llm_router import load_yaml, LLMRouter
14
+
15
+ app = FastAPI(title="Veureu Engine API", version="0.2.0")
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"],
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ ROOT = Path("/tmp/veureu")
25
+ ROOT.mkdir(parents=True, exist_ok=True)
26
+
27
+ @app.get("/")
28
+ def root():
29
+ return {"ok": True, "service": "veureu-engine"}
30
+
31
+ @app.post("/process_video")
32
+ async def process_video(
33
+ video_file: UploadFile = File(...),
34
+ config_path: str = Form("config.yaml"),
35
+ out_root: str = Form("results"),
36
+ db_dir: str = Form("chroma_db"),
37
+ ):
38
+ tmp_video = ROOT / video_file.filename
39
+ with tmp_video.open("wb") as f:
40
+ shutil.copyfileobj(video_file.file, f)
41
+ result = process_video_pipeline(str(tmp_video), config_path=config_path, out_root=out_root, db_dir=db_dir)
42
+ return JSONResponse(result)
43
+
44
+ @app.post("/load_casting")
45
+ async def load_casting(
46
+ faces_dir: str = Form("identities/faces"),
47
+ voices_dir: str = Form("identities/voices"),
48
+ db_dir: str = Form("chroma_db"),
49
+ drop_collections: bool = Form(False),
50
+ ):
51
+ client = ensure_chroma(Path(db_dir))
52
+ n_faces = build_faces_index(Path(faces_dir), client, collection_name="index_faces", drop=drop_collections)
53
+ n_voices = build_voices_index(Path(voices_dir), client, collection_name="index_voices", drop=drop_collections)
54
+ return {"ok": True, "faces": n_faces, "voices": n_voices}
55
+
56
+ @app.post("/refine_narration")
57
+ async def refine_narration(
58
+ dialogues_srt: str = Form(...),
59
+ frame_descriptions_json: str = Form("[]"),
60
+ config_path: str = Form("config.yaml"),
61
+ ):
62
+ cfg = load_yaml(config_path)
63
+ frames = json.loads(frame_descriptions_json)
64
+ model_name = cfg.get("narration", {}).get("model", "salamandra-instruct")
65
+ use_remote = model_name in (cfg.get("models", {}).get("routing", {}).get("use_remote_for", []))
66
+
67
+ if use_remote:
68
+ router = LLMRouter(cfg)
69
+ system_msg = (
70
+ "Eres un sistema de audiodescripción que cumple UNE-153010. "
71
+ "Fusiona diálogos del SRT con descripciones concisas en los huecos, evitando redundancias. "
72
+ "Devuelve JSON con {narrative_text, srt_text}."
73
+ )
74
+ prompt = json.dumps({"dialogues_srt": dialogues_srt, "frames": frames, "rules": cfg.get("narration", {})}, ensure_ascii=False)
75
+ try:
76
+ txt = router.instruct(prompt=prompt, system=system_msg, model=model_name)
77
+ out = {}
78
+ try:
79
+ out = json.loads(txt)
80
+ except Exception:
81
+ out = {"narrative_text": txt, "srt_text": ""}
82
+ return {
83
+ "narrative_text": out.get("narrative_text", ""),
84
+ "srt_text": out.get("srt_text", ""),
85
+ "approved": True,
86
+ "critic_feedback": "",
87
+ }
88
+ except Exception:
89
+ ns = NarrationSystem(model_url=None, une_guidelines_path=cfg.get("narration", {}).get("une_guidelines_path", "UNE_153010.txt"))
90
+ res = ns.run(dialogues_srt, frames)
91
+ return {"narrative_text": res.narrative_text, "srt_text": res.srt_text, "approved": res.approved, "critic_feedback": res.critic_feedback}
92
+
93
+ ns = NarrationSystem(model_url=None, une_guidelines_path=cfg.get("narration", {}).get("une_guidelines_path", "UNE_153010.txt"))
94
+ out = ns.run(dialogues_srt, frames)
95
+ return {"narrative_text": out.narrative_text, "srt_text": out.srt_text, "approved": out.approved, "critic_feedback": out.critic_feedback}
96
+
97
+ if __name__ == "__main__":
98
+ uvicorn.run(app, host="0.0.0.0", port=7860)
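Along the same lines, a hedged sketch of posting to `/refine_narration`; the SRT cue and frame description are made up, and the response keys follow the handler above.

```python
# Hypothetical call; the SRT cue and frame description are illustrative only.
import json
import requests

dialogues_srt = "1\n00:00:01,000 --> 00:00:03,000\n[SPEAKER_00]: Hola.\n\n"
frames = [{"start": 4.0, "end": 6.5, "description": "A person enters the room."}]

resp = requests.post(
    "http://localhost:7860/refine_narration",
    data={
        "dialogues_srt": dialogues_srt,
        "frame_descriptions_json": json.dumps(frames, ensure_ascii=False),
        "config_path": "config.yaml",
    },
)
out = resp.json()
print(out["approved"], out["narrative_text"][:200])
```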
audio_tools.py CHANGED
@@ -1,746 +1,468 @@
1
- # audio_tools.py
2
- # -----------------------------------------------------------------------------
3
- # Veureu — AUDIO utilities (self-contained)
4
- # - FFmpeg extraction (WAV)
5
- # - Diarization (pyannote)
6
- # - ASR:
7
- # * Catalan ("ca") -> AINA Whisper
8
- # * Other languages -> Lightweight generic Whisper
9
- # - Integrated Language ID (Whisper via faster-whisper)
10
- # - Voice embeddings (SpeechBrain ECAPA)
11
- # - Speaker identification (KMeans + optional ChromaDB collection)
12
- # - SRT generation
13
- # - Orchestrator: process_audio_for_video(...)
14
- # - ADDED: ASR of full audio and LLM-based SRT correction
15
- # -----------------------------------------------------------------------------
16
- from __future__ import annotations
17
-
18
- import numpy as np
19
- import json
20
- import logging
21
- import math
22
- import os
23
- import shlex
24
- import subprocess
25
- from pathlib import Path
26
- from typing import List, Dict, Any, Tuple, Optional
27
- from dataclasses import dataclass
28
-
29
- # al principio de audio_tools.py
30
- try:
31
- import torchaudio as ta
32
- HAS_TORCHAUDIO = True
33
- except ImportError:
34
- ta = None
35
- HAS_TORCHAUDIO = False
36
- import soundfile as sf
37
-
38
- import torch
39
- import torchaudio.transforms as T
40
- from pydub import AudioSegment
41
- from pyannote.audio import Pipeline
42
- from speechbrain.pretrained import SpeakerRecognition
43
- from sklearn.cluster import KMeans
44
- from sklearn.metrics import silhouette_score
45
- from transformers import WhisperForConditionalGeneration, WhisperProcessor
46
- from openai import OpenAI as OpenAIClient
47
- import noisereduce as nr
48
-
49
- # -------------------------------- Logging ------------------------------------
50
- log = logging.getLogger("audio_tools")
51
- if not log.handlers:
52
- _h = logging.StreamHandler()
53
- _h.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
54
- log.addHandler(_h)
55
- log.setLevel(logging.INFO)
56
-
57
- # ------------------------------- Utilities -----------------------------------
58
-
59
- def load_wav(path, sr=16000):
60
- if HAS_TORCHAUDIO:
61
- wav, in_sr = ta.load(path)
62
- if in_sr != sr:
63
- wav = ta.functional.resample(wav, in_sr, sr)
64
- return wav.squeeze(0).numpy(), sr
65
- # fallback con soundfile + resample con librosa
66
- import librosa
67
- y, in_sr = sf.read(path, dtype="float32", always_2d=False)
68
- if in_sr != sr:
69
- y = librosa.resample(y, orig_sr=in_sr, target_sr=sr)
70
- return y.astype(np.float32), sr
71
-
72
- def save_wav(path, y, sr=16000):
73
- if HAS_TORCHAUDIO:
74
- ta.save(path, torch.from_numpy(y).unsqueeze(0), sr) # si usas torch
75
- else:
76
- sf.write(path, y, sr)
77
-
78
- def _pick_device_auto(dev_cfg: str) -> str:
79
- """Resolve 'auto' device to cuda/cpu."""
80
- if dev_cfg == "auto":
81
- return "cuda" if torch.cuda.is_available() else "cpu"
82
- return dev_cfg
83
-
84
- def load_config(path: str = "configs/config_veureu.yaml") -> Dict[str, Any]:
85
- p = Path(path)
86
- if not p.exists():
87
- log.warning("Config file not found: %s (using defaults)", path)
88
- return {}
89
- try:
90
- import yaml
91
- cfg = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
92
- cfg["__path__"] = str(p)
93
- return cfg
94
- except Exception as e:
95
- log.error("Failed to read YAML config: %s", e)
96
- return {}
97
-
98
- # ------------------------------- Extraction ----------------------------------
99
-
100
- def extract_audio_ffmpeg(
101
- video_path: str,
102
- audio_out: Path,
103
- sr: int = 16000,
104
- mono: bool = True,
105
- ) -> str:
106
- """Extract audio from video to WAV using ffmpeg."""
107
- audio_out.parent.mkdir(parents=True, exist_ok=True)
108
- cmd = f'ffmpeg -y -i "{video_path}" -vn {"-ac 1" if mono else ""} -ar {sr} -f wav "{audio_out}"'
109
- subprocess.run(
110
- shlex.split(cmd),
111
- check=True,
112
- stdout=subprocess.DEVNULL,
113
- stderr=subprocess.DEVNULL,
114
- )
115
- return str(audio_out)
116
-
117
- # ----------------------------------- ASR -------------------------------------
118
-
119
- @dataclass
120
- class AinaASR:
121
- """ASR for Catalan using the AINA Whisper model."""
122
- model_name: str = "projecte-aina/whisper-large-v3-ca-3catparla"
123
- device: str = "cuda"
124
-
125
- def __post_init__(self):
126
- dev = self.device
127
- if dev == "cuda" and not torch.cuda.is_available():
128
- dev = "cpu"
129
- self.processor = WhisperProcessor.from_pretrained(self.model_name)
130
- self.model = WhisperForConditionalGeneration.from_pretrained(self.model_name).to(dev)
131
- self.device = dev
132
- log.info(f"ASR AINA loaded on {self.device}: {self.model_name}")
133
-
134
- def transcribe_wav(self, wav_path: str) -> str:
135
- waveform, sr = torchaudio.load(wav_path)
136
- inputs = self.processor(
137
- waveform.numpy(), sampling_rate=sr, return_tensors="pt"
138
- ).input_features.to(self.model.device)
139
- with torch.no_grad():
140
- ids = self.model.generate(inputs, max_new_tokens=440)[0]
141
- txt = self.processor.decode(ids)
142
- norm = getattr(self.processor.tokenizer, "_normalize", None)
143
- return norm(txt) if callable(norm) else txt
144
-
145
- def transcribe_long_audio(
146
- self,
147
- wav_path: str,
148
- chunk_length_s: int = 20,
149
- overlap_s: int = 2,
150
- ) -> str:
151
- waveform, sr = torchaudio.load(wav_path)
152
- total_samples = waveform.shape[1]
153
- chunk_size = chunk_length_s * sr
154
- overlap_size = overlap_s * sr
155
-
156
- transcriptions = []
157
- start = 0
158
-
159
- while start < total_samples:
160
- end = min(start + chunk_size, total_samples)
161
- chunk = waveform[:, start:end]
162
-
163
- input_features = self.processor(
164
- chunk.numpy(),
165
- sampling_rate=sr,
166
- return_tensors="pt"
167
- ).input_features.to(self.model.device)
168
-
169
- with torch.no_grad():
170
- predicted_ids = self.model.generate(
171
- input_features,
172
- max_new_tokens=440,
173
- num_beams=1, # puedes probar beam search
174
- )[0]
175
-
176
- text = self.processor.decode(predicted_ids, skip_special_tokens=True)
177
- transcriptions.append(text.strip())
178
-
179
- # avanzar con solapamiento
180
- start += chunk_size - overlap_size
181
-
182
- return " ".join(transcriptions).strip()
183
-
184
- @dataclass
185
- class WhisperASR:
186
- """Lightweight generic ASR based on Whisper for non-Catalan languages."""
187
- model_name: str = "openai/whisper-small" # change to 'base' for an even lighter model
188
- device: str = "cuda"
189
- language: Optional[str] = None # force language, e.g. "es", "en", etc.
190
-
191
- def __post_init__(self):
192
- dev = self.device
193
- if dev == "cuda" and not torch.cuda.is_available():
194
- dev = "cpu"
195
- self.processor = WhisperProcessor.from_pretrained(self.model_name)
196
- self.model = WhisperForConditionalGeneration.from_pretrained(self.model_name).to(dev)
197
- self.device = dev
198
- log.info(f"ASR Whisper loaded on {self.device}: {self.model_name} (lang hint: {self.language})")
199
-
200
- def transcribe_wav(self, wav_path: str) -> str:
201
- waveform, sr = torchaudio.load(wav_path)
202
- inputs = self.processor(
203
- waveform.numpy(), sampling_rate=sr, return_tensors="pt"
204
- ).input_features.to(self.model.device)
205
-
206
- gen_kwargs: Dict[str, Any] = dict(max_new_tokens=444)
207
- if self.language and self.language != "auto":
208
- try:
209
- forced_ids = self.processor.get_decoder_prompt_ids(
210
- language=self.language, task="transcribe"
211
- )
212
- gen_kwargs["forced_decoder_ids"] = forced_ids
213
- except Exception:
214
- # If the model/processor does not support forced ids, continue without forcing
215
- pass
216
-
217
- with torch.no_grad():
218
- ids = self.model.generate(inputs, **gen_kwargs)[0]
219
- txt = self.processor.decode(ids)
220
- norm = getattr(self.processor.tokenizer, "_normalize", None)
221
- return norm(txt) if callable(norm) else txt
222
-
223
- # ------------------------------ Language ID ----------------------------------
224
-
225
- @dataclass
226
- class WhisperLIDConfig:
227
- """Configuration for language detection with faster-whisper."""
228
- model_name: str = "Systran/faster-whisper-small"
229
- device: str = "auto"
230
- compute_type: str = "float32" # "int8" | "float16" | "float32"
231
- beam_size: int = 1
232
- chunk_seconds: float = 30.0
233
- prob_threshold: float = 0.5
234
- fallback_lang: str = "auto"
235
-
236
- def detect_language_with_whisper(
237
- wav_path: str,
238
- cfg: Dict[str, Any],
239
- ) -> Tuple[str, float]:
240
- """
241
- Detects language using faster-whisper (WhisperModel). Returns (lang_iso, prob).
242
- In case of failure, returns (fallback_lang, 0.0).
243
- """
244
- lid_cfg_d = (cfg.get("asr", {})
245
- .get("language_detection", {})
246
- .get("whisper_lid", {}))
247
- lid_cfg = WhisperLIDConfig(
248
- model_name=lid_cfg_d.get("model_name", "Systran/faster-whisper-small",),
249
- device=_pick_device_auto(lid_cfg_d.get("device", "auto")),
250
- compute_type=lid_cfg_d.get("compute_type", "float32"),
251
- beam_size=int(lid_cfg_d.get("beam_size", 1)),
252
- chunk_seconds=float(lid_cfg_d.get("chunk_seconds", 30.0)),
253
- prob_threshold=float(lid_cfg_d.get("prob_threshold", 0.5)),
254
- fallback_lang=lid_cfg_d.get("fallback_lang", "auto"),
255
- )
256
-
257
- try:
258
- from faster_whisper import WhisperModel # type: ignore
259
- except Exception as e:
260
- log.warning(f"LID: faster-whisper not available ({e}). Fallback='{lid_cfg.fallback_lang}'")
261
- return lid_cfg.fallback_lang, 0.0
262
-
263
- try:
264
- model = WhisperModel(lid_cfg.model_name, device=lid_cfg.device, compute_type=lid_cfg.compute_type)
265
- except Exception as e:
266
- log.warning(f"LID: failed to load '{lid_cfg.model_name}': {e}. Fallback='{lid_cfg.fallback_lang}'")
267
- return lid_cfg.fallback_lang, 0.0
268
-
269
- try:
270
- segments, info = model.transcribe(
271
- wav_path,
272
- beam_size=lid_cfg.beam_size,
273
- vad_filter=True,
274
- without_timestamps=True,
275
- language=None
276
- )
277
- lang = info.language or lid_cfg.fallback_lang
278
- prob = float(info.language_probability or 0.0)
279
- if prob < lid_cfg.prob_threshold:
280
- return lid_cfg.fallback_lang, prob
281
- return lang, prob
282
- except Exception as e:
283
- log.warning(f"LID: error in transcription/detection: {e}. Fallback='{lid_cfg.fallback_lang}'")
284
- return lid_cfg.fallback_lang, 0.0
285
-
286
-
287
- def _build_asr_backend_for_language(lang_iso: str, cfg: Dict[str, Any]):
288
- """
289
- Selects ASR backend based on language:
290
- - 'ca' -> AINA
291
- - other -> Generic Whisper
292
- """
293
- asr_cfg = cfg.get("asr", {})
294
- device_pref = _pick_device_auto(asr_cfg.get("device", "auto"))
295
- if lang_iso and lang_iso.lower() == "ca":
296
- return AinaASR(
297
- model_name=asr_cfg.get("model_name", "projecte-aina/whisper-large-v3-ca-3catparla"),
298
- device=device_pref,
299
- )
300
- else:
301
- return WhisperASR(
302
- model_name=asr_cfg.get("whisper_model_name", "openai/whisper-small"),
303
- device=device_pref,
304
- language=None,
305
- )
306
-
307
- # -------------------------------- Diarization --------------------------------
308
-
309
- def diarize_audio(
310
- wav_path: str,
311
- base_dir: Path,
312
- clips_folder: str = "clips",
313
- min_segment_duration: float = 20,
314
- max_segment_duration: float = 50.0,
315
- hf_token_env: str | None = None,
316
- ) -> Tuple[List[str], List[Dict[str, Any]]]:
317
- """Diarization with pyannote and clip export with pydub.
318
- Returns clip paths and segments [{'start','end','speaker'}].
319
- """
320
- audio = AudioSegment.from_wav(wav_path)
321
- duration = len(audio) / 1000.0
322
-
323
- pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=(hf_token_env
324
- or os.getenv("HF_TOKEN")))
325
- diarization = pipeline(wav_path)
326
-
327
- clips_dir = (base_dir / clips_folder)
328
- clips_dir.mkdir(parents=True, exist_ok=True)
329
- clip_paths: List[str] = []
330
- segments: List[Dict[str, Any]] = []
331
- spk_map: Dict[str, int] = {}
332
-
333
- prev_end = 0.0 # reference to the end of the last exported segment
334
-
335
- for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
336
- start, end = max(0.0, float(turn.start)), min(duration, float(turn.end))
337
-
338
- if start < prev_end:
339
- start = prev_end
340
- if end <= start:
341
- continue
342
-
343
- seg_dur = end - start
344
- if seg_dur < min_segment_duration:
345
- continue
346
-
347
- if seg_dur > max_segment_duration:
348
- # split long segments
349
- n = int(math.ceil(seg_dur / max_segment_duration))
350
- sub_d = seg_dur / n
351
- for j in range(n):
352
- s = start + j * sub_d
353
- e = min(end, start + (j + 1) * sub_d)
354
- if e <= s:
355
- continue
356
- clip = audio[int(s * 1000):int(e * 1000)]
357
- cp = clips_dir / f"segment_{i:03d}_{j:02d}.wav"
358
- clip.export(cp, format="wav")
359
- if speaker not in spk_map:
360
- spk_map[speaker] = len(spk_map)
361
- segments.append({
362
- "start": s,
363
- "end": e,
364
- "speaker": f"SPEAKER_{spk_map[speaker]:02d}"
365
- })
366
- clip_paths.append(str(cp))
367
- prev_end = e
368
- else:
369
- clip = audio[int(start * 1000):int(end * 1000)]
370
- cp = clips_dir / f"segment_{i:03d}.wav"
371
- clip.export(cp, format="wav")
372
- if speaker not in spk_map:
373
- spk_map[speaker] = len(spk_map)
374
- segments.append({
375
- "start": start,
376
- "end": end,
377
- "speaker": f"SPEAKER_{spk_map[speaker]:02d}"
378
- })
379
- clip_paths.append(str(cp))
380
- prev_end = end # update reference
381
-
382
- if not segments:
383
- # fallback single clip
384
- cp = clips_dir / "segment_000.wav"
385
- audio.export(cp, format="wav")
386
- return [str(cp)], [{"start": 0.0, "end": duration, "speaker": "SPEAKER_00"}]
387
-
388
- # sort by start time
389
- pairs = sorted(zip(clip_paths, segments), key=lambda x: x[1]["start"])
390
- clip_paths, segments = [p[0] for p in pairs], [p[1] for p in pairs]
391
- return clip_paths, segments
392
-
393
- # ------------------------------ Voice embeddings -----------------------------
394
-
395
- class VoiceEmbedder:
396
- def __init__(self):
397
- self.model = SpeakerRecognition.from_hparams(
398
- source="speechbrain/spkrec-ecapa-voxceleb",
399
- savedir="pretrained_models/spkrec-ecapa-voxceleb",
400
- )
401
- self.model.eval()
402
-
403
- def embed(self, wav_path: str) -> List[float]:
404
- waveform, sr = torchaudio.load(wav_path)
405
- target_sr = 16000
406
- if sr != target_sr:
407
- waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
408
- if waveform.shape[0] > 1:
409
- waveform = waveform.mean(dim=0, keepdim=True)
410
- # ensure minimum length (~0.2s) for stability
411
- min_samples = int(0.2 * target_sr)
412
- if waveform.shape[1] < min_samples:
413
- pad = min_samples - waveform.shape[1]
414
- waveform = torch.cat([waveform, torch.zeros((1, pad))], dim=1)
415
- with torch.no_grad():
416
- emb = (
417
- self.model.encode_batch(waveform)
418
- .squeeze()
419
- .cpu()
420
- .numpy()
421
- .astype(float)
422
- )
423
- return emb.tolist()
424
-
425
-
426
- def embed_voice_segments(clip_paths: List[str]) -> List[List[float]]:
427
- ve = VoiceEmbedder()
428
- out: List[List[float]] = []
429
- for cp in clip_paths:
430
- try:
431
- out.append(ve.embed(cp))
432
- except Exception as e:
433
- log.warning(f"Embedding error in {cp}: {e}")
434
- out.append([])
435
- return out
436
-
437
- # --------------------------- Speaker identification --------------------------
438
-
439
- def identify_speakers(
440
- embeddings: List[List[float]],
441
- voice_collection, # ChromaDB collection with .query or None
442
- cfg: Dict[str, Any],
443
- ) -> List[str]:
444
- voice_cfg = cfg.get("voice_processing", {}).get("speaker_identification", {})
445
- if not embeddings or sum(1 for e in embeddings if e) < 2:
446
- return ["SPEAKER_00" for _ in embeddings]
447
-
448
- valid = [e for e in embeddings if e and len(e) > 0]
449
- if len(valid) < 2:
450
- return ["SPEAKER_00" for _ in embeddings]
451
-
452
- min_clusters = max(1, voice_cfg.get("min_speakers", 1))
453
- max_clusters = min(voice_cfg.get("max_speakers", 5), len(valid) - 1)
454
-
455
- # find the optimal k using silhouette_score
456
- if voice_cfg.get("find_optimal_clusters", True) and len(valid) > 2:
457
- best_score = -1.0
458
- best_k = min_clusters
459
- for k in range(min_clusters, max_clusters + 1):
460
- if k >= len(valid):
461
- break
462
- km = KMeans(n_clusters=k, random_state=42, n_init="auto")
463
- labels = km.fit_predict(valid)
464
- if len(set(labels)) > 1:
465
- score = silhouette_score(valid, labels)
466
- if score > best_score:
467
- best_score, best_k = score, k
468
- else:
469
- best_k = min(max_clusters, max(min_clusters, voice_cfg.get("num_speakers", 2)))
470
- best_k = max(1, min(best_k, len(valid) - 1))
471
-
472
- # final clustering
473
- km = KMeans(n_clusters=best_k, random_state=42, n_init="auto", init="k-means++")
474
- labels = km.fit_predict(np.array(valid))
475
- centers = km.cluster_centers_
476
-
477
- cluster_to_name: Dict[int, str] = {}
478
- unknown_counter = 0
479
- for cid in range(best_k):
480
- center = centers[cid].tolist()
481
- name = f"SPEAKER_{cid:02d}"
482
-
483
- if voice_collection is not None:
484
- try:
485
- q = voice_collection.query(query_embeddings=[center], n_results=1)
486
- metas = q.get("metadatas", [[]])[0]
487
- dists = q.get("distances", [[]])[0]
488
- thr = voice_cfg.get("distance_threshold")
489
-
490
- if dists and thr is not None and dists[0] > thr:
491
- # new speaker → mark as UNKNOWN and store it in the collection
492
- name = f"UNKNOWN_{unknown_counter}"
493
- unknown_counter += 1
494
- voice_collection.add(
495
- embeddings=[center],
496
- metadatas=[{"name": name}],
497
- ids=[f"unk_{cid}_{unknown_counter}"]
498
- )
499
- else:
500
- # acceptable match → use the existing name
501
- if metas and isinstance(metas[0], dict):
502
- name = metas[0].get("nombre") or metas[0].get("name") \
503
- or metas[0].get("speaker") or metas[0].get("identity") \
504
- or name
505
- except Exception as e:
506
- log.warning(f"Voice KNN query failed: {e}")
507
-
508
- cluster_to_name[cid] = name
509
-
510
- # map each embedding to its speaker
511
- personas: List[str] = []
512
- vi = 0
513
- for emb in embeddings:
514
- if not emb:
515
- personas.append("UNKNOWN")
516
- else:
517
- label = int(labels[vi])
518
- personas.append(cluster_to_name.get(label, f"SPEAKER_{label:02d}"))
519
- vi += 1
520
-
521
- return personas
522
-
523
- # ----------------------------------- SRT -------------------------------------
524
-
525
- def _fmt_srt_time(seconds: float) -> str:
526
- h = int(seconds // 3600)
527
- m = int((seconds % 3600) // 60)
528
- s = int(seconds % 60)
529
- ms = int(round((seconds - int(seconds)) * 1000))
530
- return f"{h:02}:{m:02}:{s:02},{ms:03}"
531
-
532
-
533
- def generate_srt_from_diarization(
534
- diarization_segments: List[Dict[str, Any]],
535
- transcriptions: List[str],
536
- speakers_per_segment: List[str],
537
- output_srt_path: str,
538
- cfg: Dict[str, Any],
539
- ) -> None:
540
- subs = cfg.get("subtitles", {})
541
- max_cpl = int(subs.get("max_chars_per_line", 42))
542
- max_lines = int(subs.get("max_lines_per_cue", 10))
543
- speaker_display = subs.get("speaker_display", "brackets")
544
-
545
- items: List[Dict[str, Any]] = []
546
- n = min(len(diarization_segments), len(transcriptions), len(speakers_per_segment))
547
- for i in range(n):
548
- seg = diarization_segments[i]
549
- text = (transcriptions[i] or "").strip()
550
- spk = speakers_per_segment[i]
551
- items.append(
552
- {
553
- "start": float(seg.get("start", 0.0)),
554
- "end": float(seg.get("end", 0.0)),
555
- "text": text,
556
- "speaker": spk,
557
- }
558
- )
559
-
560
- out = Path(output_srt_path)
561
- out.parent.mkdir(parents=True, exist_ok=True)
562
- with out.open("w", encoding="utf-8-sig") as f:
563
- for i, it in enumerate(items, 1):
564
- text = it["text"]
565
- spk = it["speaker"]
566
- if speaker_display == "brackets" and spk:
567
- text = f"[{spk}]: {text}" # Adjusted format to match new script's style
568
- elif speaker_display == "prefix" and spk:
569
- text = f"{spk}: {text}"
570
-
571
- # simple word-wrap
572
- words = text.split()
573
- lines: List[str] = []
574
- cur = ""
575
- for w in words:
576
- if len(cur) + len(w) + (1 if cur else 0) <= max_cpl:
577
- cur = (cur + " " + w) if cur else w
578
- else:
579
- lines.append(cur)
580
- cur = w
581
- if len(lines) >= max_lines - 1:
582
- break
583
- if cur and len(lines) < max_lines:
584
- lines.append(cur)
585
- f.write(f"{i}\n{_fmt_srt_time(it['start'])} --> {_fmt_srt_time(it['end'])}\n")
586
- f.write("\n".join(lines) + "\n\n")
587
-
588
- # ------------------------------ Orchestrator ---------------------------------
589
-
590
- def process_audio_for_video(
591
- video_path: str,
592
- out_dir: Path,
593
- cfg: Dict[str, Any],
594
- voice_collection=None,
595
- ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
596
- """
597
- Audio pipeline: FFmpeg -> diarization -> LID -> ASR -> embeddings -> speaker-ID -> SRT.
598
- Returns (audio_segments, srt_path or None).
599
- """
600
- # 1) Audio extraction
601
- audio_cfg = cfg.get("audio_processing", {})
602
- sr = int(audio_cfg.get("sample_rate", 16000))
603
- fmt = audio_cfg.get("format", "wav")
604
- wav_path = extract_audio_ffmpeg(
605
- video_path, out_dir / f"{Path(video_path).stem}.{fmt}", sr=sr
606
- )
607
- log.info("Audio extraído")
608
-
609
- # 2) Diarization
610
- diar_cfg = audio_cfg.get("diarization", {})
611
- min_dur = float(diar_cfg.get("min_segment_duration", 0.5))
612
- max_dur = float(diar_cfg.get("max_segment_duration", 10.0))
613
- clip_paths, diar_segs = diarize_audio(
614
- wav_path, out_dir, "clips", min_dur, max_dur
615
- )
616
- log.info("Clips de audio generados.")
617
-
618
- # 3) Language detection (optional) + ASR backend selection
619
- asr_cfg = cfg.get("asr", {})
620
- lid_enabled = bool(asr_cfg.get("language_detection", {}).get("enabled", True))
621
-
622
- device_pref = _pick_device_auto(asr_cfg.get("device", "auto"))
623
-
624
- aina_asr = AinaASR(model_name=asr_cfg.get("model_name", "projecte-aina/whisper-large-v3-ca-3catparla"),
625
- device=device_pref)
626
-
627
- whisper_asr = WhisperASR(model_name=asr_cfg.get("whisper_model_name", "openai/whisper-small"),
628
- device=device_pref,
629
- language=None)
630
-
631
- full_transcription = ""
632
- if asr_cfg.get("enable_full_transcription", True):
633
- log.info("Iniciando transcripción del audio completo")
634
- # Assume Catalan model for full transcription, or add logic to check language
635
- full_transcription = aina_asr.transcribe_long_audio(wav_path, chunk_length_s=30)
636
- log.info("Transcripción completa del audio finalizada.")
637
- print(full_transcription)
638
-
639
- # Transcribe each segment
640
- log.info("Comenzamos con la transcripción de cada clip.")
641
- trans: List[str] = []
642
- detected_langs: List[str] = []
643
- detected_probs: List[float] = []
644
- for path in clip_paths:
645
- if not lid_enabled:
646
- txt = aina_asr.transcribe_wav(path)
647
- else:
648
- detected_lang, detected_prob = detect_language_with_whisper(path, cfg)
649
- log.info(f"LID: detected={detected_lang} (p={detected_prob:.2f})")
650
-
651
- if detected_lang.lower() in ["ca", "catalan"]:
652
- txt = aina_asr.transcribe_wav(path)
653
- else:
654
- txt = whisper_asr.transcribe_wav(path)
655
- trans.append(txt)
656
-
657
- log.info("Se han transcrito todos los clips.")
658
-
659
- # 5) Embeddings + speaker identification
660
- if audio_cfg.get("enable_voice_embeddings", True):
661
- embeddings = embed_voice_segments(clip_paths)
662
- log.info("Embeddings creados de manera correcta para cada clip.")
663
- else:
664
- embeddings = [[] for _ in clip_paths]
665
-
666
- if cfg.get("voice_processing", {}).get("speaker_identification", {}).get("enabled", True):
667
- speakers = identify_speakers(embeddings, voice_collection, cfg)
668
- log.info("Speakers identificados de manera correcta.")
669
- else:
670
- speakers = [seg.get("speaker", f"SPEAKER_{i:02d}") for i, seg in enumerate(diar_segs)]
671
-
672
- # 6) Build the segment table
673
- audio_segments: List[Dict[str, Any]] = []
674
- for i, seg in enumerate(diar_segs):
675
- audio_segments.append(
676
- {
677
- "segment": i,
678
- "start": float(seg.get("start", 0.0)),
679
- "end": float(seg.get("end", 0.0)),
680
- "speaker": speakers[i] if i < len(speakers) else seg.get("speaker", f"SPEAKER_{i:02d}"),
681
- "text": trans[i] if i < len(trans) else "",
682
- "voice_embedding": embeddings[i],
683
- "clip_path": str(out_dir / "clips" / f"segment_{i:03d}.wav"),
684
- "lang": detected_langs[i] if i < len(detected_langs) else "auto",
685
- "lang_prob": detected_probs[i] if i < len(detected_probs) else 0.0,
686
- }
687
- )
688
-
689
- # 7) SRT
690
- srt_base_path = out_dir / f"transcripcion_diarizada_{Path(video_path).stem}"
691
- srt_unmodified_path = str(srt_base_path) + "_unmodified.srt"
692
-
693
- # Generate initial SRT
694
- try:
695
- generate_srt_from_diarization(
696
- diar_segs,
697
- [a["text"] for a in audio_segments],
698
- [a["speaker"] for a in audio_segments],
699
- srt_unmodified_path,
700
- cfg,
701
- )
702
-
703
- except Exception as e:
704
- log.warning(f"SRT generation failed: {e}")
705
- srt_unmodified_path = None
706
-
707
- return audio_segments, srt_unmodified_path, full_transcription
708
-
709
- # ----------------------------------- CLI -------------------------------------
710
- if __name__ == "__main__":
711
- import argparse
712
- import yaml
713
-
714
- ap = argparse.ArgumentParser(description="Veureu — Audio tools (self-contained)")
715
- ap.add_argument("--video", required=True)
716
- ap.add_argument("--out", default="results")
717
- ap.add_argument("--config", default="configs/config_veureu.yaml")
718
- args = ap.parse_args()
719
-
720
- cfg: Dict[str, Any] = {}
721
- p = Path(args.config)
722
- if p.exists():
723
- try:
724
- cfg = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
725
- except Exception as e:
726
- log.warning(f"No se pudo leer el YAML de config: {e}")
727
-
728
- out_dir = Path(args.out) / Path(args.video).stem
729
- out_dir.mkdir(parents=True, exist_ok=True)
730
-
731
- # Add an OpenAI API key to your config file or here
732
- # Example: cfg["api_keys"] = {"openai": "sk-your-openai-api-key"}
733
- # Make sure you do not commit the key to git!
734
-
735
- segs, srt = process_audio_for_video(args.video, out_dir, cfg, voice_collection=None)
736
-
737
- print(json.dumps(
738
- {
739
- "segments": len(segs),
740
- "srt": srt,
741
- "detected_lang": (segs[0].get("lang") if segs else "auto"),
742
- "detected_prob": (segs[0].get("lang_prob") if segs else 0.0),
743
- },
744
- indent=2,
745
- ensure_ascii=False,
746
- ))
 
1
+ # audio_tools.py (ASR delegated to remote HF Space "veureu/asr")
2
+ # -----------------------------------------------------------------------------
3
+ # Veureu — AUDIO utilities (orchestrator w/ remote ASR)
4
+ # - FFmpeg extraction (WAV)
5
+ # - Diarization (pyannote) [local]
6
+ # - Voice embeddings (SpeechBrain ECAPA) [local]
7
+ # - Speaker identification (KMeans + ChromaDB optional) [local]
8
+ # - ASR: delegated to HF Space `veureu/asr` (faster-whisper-large-v3-ca-3catparla)
9
+ # - SRT generation
10
+ # - Orchestrator: process_audio_for_video(...)
11
+ # -----------------------------------------------------------------------------
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import logging
16
+ import math
17
+ import os
18
+ import shlex
19
+ import subprocess
20
+ from pathlib import Path
21
+ from typing import List, Dict, Any, Tuple, Optional
22
+
23
+ import numpy as np
24
+
25
+ # Optional torchaudio for I/O and resampling (fallback to soundfile+librosa otherwise)
26
+ try:
27
+ import torch
28
+ import torchaudio as ta
29
+ import torchaudio.transforms as T
30
+ HAS_TORCHAUDIO = True
31
+ try:
32
+ ta.set_audio_backend("soundfile")
33
+ except Exception:
34
+ pass
35
+ except Exception:
36
+ HAS_TORCHAUDIO = False
37
+ ta = None # type: ignore
38
+
39
+ import soundfile as sf
40
+
41
+ # Pyannote for diarization (local)
42
+ from pyannote.audio import Pipeline
43
+
44
+ # Speaker embeddings (local)
45
+ from speechbrain.inference import SpeakerRecognition # v1.0+
46
+
47
+ # Clustering
48
+ from sklearn.cluster import KMeans
49
+ from sklearn.metrics import silhouette_score
50
+
51
+ # Router to remote Spaces (asr)
52
+ from llm_router import load_yaml, LLMRouter
53
+
54
+ # -------------------------------- Logging ------------------------------------
55
+ log = logging.getLogger("audio_tools")
56
+ if not log.handlers:
57
+ _h = logging.StreamHandler()
58
+ _h.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
59
+ log.addHandler(_h)
60
+ log.setLevel(logging.INFO)
61
+
62
+ # ------------------------------- Utilities -----------------------------------
63
+
64
+ def load_wav(path: str | Path, sr: int = 16000):
65
+ """Load audio as mono float32 at the requested sample rate."""
66
+ if HAS_TORCHAUDIO:
67
+ wav, in_sr = ta.load(str(path))
68
+ if in_sr != sr:
69
+ wav = ta.functional.resample(wav, in_sr, sr)
70
+ if wav.dim() > 1:
71
+ wav = wav.mean(dim=0, keepdim=True)
72
+ return wav.squeeze(0).numpy(), sr
73
+ import librosa
74
+ y, in_sr = sf.read(str(path), dtype="float32", always_2d=False)
75
+ if y.ndim > 1:
76
+ y = y.mean(axis=1)
77
+ if in_sr != sr:
78
+ y = librosa.resample(y, orig_sr=in_sr, target_sr=sr)
79
+ return y.astype(np.float32), sr
80
+
81
+ def save_wav(path: str | Path, y, sr: int = 16000):
82
+ """Save mono float32 wav."""
83
+ if HAS_TORCHAUDIO:
84
+ import torch
85
+ wav = torch.from_numpy(np.asarray(y, dtype=np.float32)).unsqueeze(0)
86
+ ta.save(str(path), wav, sr)
87
+ else:
88
+ sf.write(str(path), np.asarray(y, dtype=np.float32), sr)
89
+
90
+ def extract_audio_ffmpeg(
91
+ video_path: str,
92
+ audio_out: Path,
93
+ sr: int = 16000,
94
+ mono: bool = True,
95
+ ) -> str:
96
+ """Extract audio from video to WAV using ffmpeg."""
97
+ audio_out.parent.mkdir(parents=True, exist_ok=True)
98
+ cmd = f'ffmpeg -y -i "{video_path}" -vn {"-ac 1" if mono else ""} -ar {sr} -f wav "{audio_out}"'
99
+ subprocess.run(
100
+ shlex.split(cmd),
101
+ check=True,
102
+ stdout=subprocess.DEVNULL,
103
+ stderr=subprocess.DEVNULL,
104
+ )
105
+ return str(audio_out)
106
+
107
+ # ----------------------------------- ASR (REMOTE) -------------------------------------
108
+
109
+ def transcribe_audio_remote(audio_path: str | Path, cfg: Dict[str, Any]) -> Dict[str, Any]:
110
+ """
111
+ Send the audio file to the remote ASR Space `veureu/asr` (Gradio or HTTP).
112
+ The remote model is 'faster-whisper-large-v3-ca-3catparla' (Aina).
113
+ Returns a standardized dict: {'text': str, 'segments': list (may be empty)}
114
+ """
115
+ if not cfg:
116
+ cfg = load_yaml("config.yaml")
117
+ router = LLMRouter(cfg)
118
+ model_name = (cfg.get("models", {}).get("asr") or "whisper-catalan")
119
+ params = {
120
+ "language": "ca",
121
+ "model": "faster-whisper-large-v3-ca-3catparla",
122
+ "timestamps": True,
123
+ "diarization": False, # diarization stays local
124
+ }
125
+ result = router.asr_transcribe(str(audio_path), model=model_name, **params)
126
+
127
+ if isinstance(result, str):
128
+ return {"text": result, "segments": []}
129
+ if isinstance(result, dict):
130
+ if "text" not in result and "transcription" in result:
131
+ result["text"] = result["transcription"]
132
+ result.setdefault("segments", [])
133
+ return result
134
+ return {"text": str(result), "segments": []}
135
+
136
+ # -------------------------------- Diarization --------------------------------
137
+
138
+ def diarize_audio(
139
+ wav_path: str,
140
+ base_dir: Path,
141
+ clips_folder: str = "clips",
142
+ min_segment_duration: float = 20.0,
143
+ max_segment_duration: float = 50.0,
144
+ hf_token_env: str | None = None,
145
+ ) -> Tuple[List[str], List[Dict[str, Any]]]:
146
+ """Diarization with pyannote and clip export with pydub."""
147
+ from pydub import AudioSegment
148
+ audio = AudioSegment.from_wav(wav_path)
149
+ duration = len(audio) / 1000.0
150
+
151
+ pipeline = Pipeline.from_pretrained(
152
+ "pyannote/speaker-diarization-3.1",
153
+ use_auth_token=(hf_token_env or os.getenv("HF_TOKEN"))
154
+ )
155
+ diarization = pipeline(wav_path)
156
+
157
+ clips_dir = (base_dir / clips_folder)
158
+ clips_dir.mkdir(parents=True, exist_ok=True)
159
+ clip_paths: List[str] = []
160
+ segments: List[Dict[str, Any]] = []
161
+ spk_map: Dict[str, int] = {}
162
+ prev_end = 0.0
163
+
164
+ for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
165
+ start, end = max(0.0, float(turn.start)), min(duration, float(turn.end))
166
+ if start < prev_end:
167
+ start = prev_end
168
+ if end <= start:
169
+ continue
170
+
171
+ seg_dur = end - start
172
+ if seg_dur < min_segment_duration:
173
+ continue
174
+
175
+ if seg_dur > max_segment_duration:
176
+ n = int(math.ceil(seg_dur / max_segment_duration))
177
+ sub_d = seg_dur / n
178
+ for j in range(n):
179
+ s = start + j * sub_d
180
+ e = min(end, start + (j + 1) * sub_d)
181
+ if e <= s:
182
+ continue
183
+ clip = audio[int(s * 1000):int(e * 1000)]
184
+ cp = clips_dir / f"segment_{i:03d}_{j:02d}.wav"
185
+ clip.export(cp, format="wav")
186
+ if speaker not in spk_map:
187
+ spk_map[speaker] = len(spk_map)
188
+ segments.append({"start": s, "end": e, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
189
+ clip_paths.append(str(cp))
190
+ prev_end = e
191
+ else:
192
+ clip = audio[int(start * 1000):int(end * 1000)]
193
+ cp = clips_dir / f"segment_{i:03d}.wav"
194
+ clip.export(cp, format="wav")
195
+ if speaker not in spk_map:
196
+ spk_map[speaker] = len(spk_map)
197
+ segments.append({"start": start, "end": end, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
198
+ clip_paths.append(str(cp))
199
+ prev_end = end
200
+
201
+ if not segments:
202
+ cp = clips_dir / "segment_000.wav"
203
+ audio.export(cp, format="wav")
204
+ return [str(cp)], [{"start": 0.0, "end": duration, "speaker": "SPEAKER_00"}]
205
+
206
+ pairs = sorted(zip(clip_paths, segments), key=lambda x: x[1]["start"])
207
+ clip_paths, segments = [p[0] for p in pairs], [p[1] for p in pairs]
208
+ return clip_paths, segments
209
+
210
+ # ------------------------------ Voice embeddings -----------------------------
211
+
212
+ class VoiceEmbedder:
213
+ def __init__(self):
214
+ self.model = SpeakerRecognition.from_hparams(
215
+ source="speechbrain/spkrec-ecapa-voxceleb",
216
+ savedir="pretrained_models/spkrec-ecapa-voxceleb",
217
+ )
218
+ self.model.eval()
219
+
220
+ def embed(self, wav_path: str) -> List[float]:
221
+ if HAS_TORCHAUDIO:
222
+ waveform, sr = ta.load(wav_path)
223
+ target_sr = 16000
224
+ if sr != target_sr:
225
+ waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
226
+ if waveform.shape[0] > 1:
227
+ waveform = waveform.mean(dim=0, keepdim=True)
228
+ min_samples = int(0.2 * target_sr)
229
+ if waveform.shape[1] < min_samples:
230
+ pad = min_samples - waveform.shape[1]
231
+ import torch
232
+ waveform = torch.cat([waveform, torch.zeros((1, pad))], dim=1)
233
+ with torch.no_grad(): # type: ignore
234
+ emb = self.model.encode_batch(waveform).squeeze().cpu().numpy().astype(float)
235
+ return emb.tolist()
236
+ else:
237
+ y, sr = load_wav(wav_path, sr=16000)
238
+ min_len = int(0.2 * 16000)
239
+ if len(y) < min_len:
240
+ y = np.pad(y, (0, min_len - len(y)))
241
+ import torch
242
+ w = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
243
+ with torch.no_grad(): # type: ignore
244
+ emb = self.model.encode_batch(w).squeeze().cpu().numpy().astype(float)
245
+ return emb.tolist()
246
+
247
+
248
+ def embed_voice_segments(clip_paths: List[str]) -> List[List[float]]:
249
+ ve = VoiceEmbedder()
250
+ out: List[List[float]] = []
251
+ for cp in clip_paths:
252
+ try:
253
+ out.append(ve.embed(cp))
254
+ except Exception as e:
255
+ log.warning(f"Embedding error in {cp}: {e}")
256
+ out.append([])
257
+ return out
258
+
259
+ # --------------------------- Speaker identification --------------------------
260
+
261
+ def identify_speakers(
262
+ embeddings: List[List[float]],
263
+ voice_collection,
264
+ cfg: Dict[str, Any],
265
+ ) -> List[str]:
266
+ voice_cfg = cfg.get("voice_processing", {}).get("speaker_identification", {})
267
+ if not embeddings or sum(1 for e in embeddings if e) < 2:
268
+ return ["SPEAKER_00" for _ in embeddings]
269
+
270
+ valid = [e for e in embeddings if e and len(e) > 0]
271
+ if len(valid) < 2:
272
+ return ["SPEAKER_00" for _ in embeddings]
273
+
274
+ min_clusters = max(1, int(voice_cfg.get("min_speakers", 1)))
275
+ max_clusters = min(int(voice_cfg.get("max_speakers", 5)), len(valid) - 1)
276
+
277
+ if voice_cfg.get("find_optimal_clusters", True) and len(valid) > 2:
278
+ best_score, best_k = -1.0, min_clusters
279
+ for k in range(min_clusters, max_clusters + 1):
280
+ if k >= len(valid):
281
+ break
282
+ km = KMeans(n_clusters=k, random_state=42, n_init="auto")
283
+ labels = km.fit_predict(valid)
284
+ if len(set(labels)) > 1:
285
+ score = silhouette_score(valid, labels)
286
+ if score > best_score:
287
+ best_score, best_k = score, k
288
+ else:
289
+ best_k = min(max_clusters, max(min_clusters, int(voice_cfg.get("num_speakers", 2))))
290
+ best_k = max(1, min(best_k, len(valid) - 1))
291
+
292
+ km = KMeans(n_clusters=best_k, random_state=42, n_init="auto", init="k-means++")
293
+ labels = km.fit_predict(np.array(valid))
294
+ centers = km.cluster_centers_
295
+
296
+ cluster_to_name: Dict[int, str] = {}
297
+ unknown_counter = 0
298
+ for cid in range(best_k):
299
+ center = centers[cid].tolist()
300
+ name = f"SPEAKER_{cid:02d}"
301
+ if voice_collection is not None:
302
+ try:
303
+ q = voice_collection.query(query_embeddings=[center], n_results=1)
304
+ metas = q.get("metadatas", [[]])[0]
305
+ dists = q.get("distances", [[]])[0]
306
+ thr = voice_cfg.get("distance_threshold")
307
+ if dists and thr is not None and dists[0] > thr:
308
+ name = f"UNKNOWN_{unknown_counter}"
309
+ unknown_counter += 1
310
+ voice_collection.add(
311
+ embeddings=[center],
312
+ metadatas=[{"name": name}],
313
+ ids=[f"unk_{cid}_{unknown_counter}"],
314
+ )
315
+ else:
316
+ if metas and isinstance(metas[0], dict):
317
+ name = metas[0].get("nombre") or metas[0].get("name") \
318
+ or metas[0].get("speaker") or metas[0].get("identity") or name
319
+ except Exception as e:
320
+ log.warning(f"Voice KNN query failed: {e}")
321
+ cluster_to_name[cid] = name
322
+
323
+ personas: List[str] = []
324
+ vi = 0
325
+ for emb in embeddings:
326
+ if not emb:
327
+ personas.append("UNKNOWN")
328
+ else:
329
+ label = int(labels[vi])
330
+ personas.append(cluster_to_name.get(label, f"SPEAKER_{label:02d}"))
331
+ vi += 1
332
+ return personas
333
+
334
+ # ----------------------------------- SRT -------------------------------------
335
+
336
+ def _fmt_srt_time(seconds: float) -> str:
337
+ h = int(seconds // 3600)
338
+ m = int((seconds % 3600) // 60)
339
+ s = int(seconds % 60)
340
+ ms = int(round((seconds - int(seconds)) * 1000))
341
+ return f"{h:02}:{m:02}:{s:02},{ms:03}"
342
+
343
+ def generate_srt_from_diarization(
344
+ diarization_segments: List[Dict[str, Any]],
345
+ transcriptions: List[str],
346
+ speakers_per_segment: List[str],
347
+ output_srt_path: str,
348
+ cfg: Dict[str, Any],
349
+ ) -> None:
350
+ subs = cfg.get("subtitles", {})
351
+ max_cpl = int(subs.get("max_chars_per_line", 42))
352
+ max_lines = int(subs.get("max_lines_per_cue", 10))
353
+ speaker_display = subs.get("speaker_display", "brackets")
354
+
355
+ items: List[Dict[str, Any]] = []
356
+ n = min(len(diarization_segments), len(transcriptions), len(speakers_per_segment))
357
+ for i in range(n):
358
+ seg = diarization_segments[i]
359
+ text = (transcriptions[i] or "").strip()
360
+ spk = speakers_per_segment[i]
361
+ items.append({"start": float(seg.get("start", 0.0)), "end": float(seg.get("end", 0.0)), "text": text, "speaker": spk})
362
+
363
+ out = Path(output_srt_path)
364
+ out.parent.mkdir(parents=True, exist_ok=True)
365
+ with out.open("w", encoding="utf-8-sig") as f:
366
+ for i, it in enumerate(items, 1):
367
+ text = it["text"]
368
+ spk = it["speaker"]
369
+ if speaker_display == "brackets" and spk:
370
+ text = f"[{spk}]: {text}"
371
+ elif speaker_display == "prefix" and spk:
372
+ text = f"{spk}: {text}"
373
+ words = text.split()
374
+ lines: List[str] = []
375
+ cur = ""
376
+ for w in words:
377
+ if len(cur) + len(w) + (1 if cur else 0) <= max_cpl:
378
+ cur = (cur + " " + w) if cur else w
379
+ else:
380
+ lines.append(cur)
381
+ cur = w
382
+ if len(lines) >= max_lines - 1:
383
+ break
384
+ if cur and len(lines) < max_lines:
385
+ lines.append(cur)
386
+ f.write(f"{i}\n{_fmt_srt_time(it['start'])} --> {_fmt_srt_time(it['end'])}\n")
387
+ f.write("\n".join(lines) + "\n\n")
388
+
389
+ # ------------------------------ Orchestrator ---------------------------------
390
+
391
+ def process_audio_for_video(
392
+ video_path: str,
393
+ out_dir: Path,
394
+ cfg: Dict[str, Any],
395
+ voice_collection=None,
396
+ ) -> Tuple[List[Dict[str, Any]], Optional[str], str]:
397
+ """
398
+ Audio pipeline: FFmpeg -> diarization -> remote ASR (full + clips) -> embeddings -> speaker-ID -> SRT.
399
+ Returns (audio_segments, srt_path or None, full_transcription_text).
400
+ """
401
+ audio_cfg = cfg.get("audio_processing", {})
402
+ sr = int(audio_cfg.get("sample_rate", 16000))
403
+ fmt = audio_cfg.get("format", "wav")
404
+ wav_path = extract_audio_ffmpeg(video_path, out_dir / f"{Path(video_path).stem}.{fmt}", sr=sr)
405
+ log.info("Audio extraído")
406
+
407
+ diar_cfg = audio_cfg.get("diarization", {})
408
+ min_dur = float(diar_cfg.get("min_segment_duration", 20.0))
409
+ max_dur = float(diar_cfg.get("max_segment_duration", 50.0))
410
+ clip_paths, diar_segs = diarize_audio(wav_path, out_dir, "clips", min_dur, max_dur)
411
+ log.info("Clips de audio generados.")
412
+
413
+ full_transcription = ""
414
+ asr_section = cfg.get("asr", {})
415
+ if asr_section.get("enable_full_transcription", True):
416
+ log.info("Transcripción completa (remota, Space 'asr')...")
417
+ full_res = transcribe_audio_remote(wav_path, cfg)
418
+ full_transcription = full_res.get("text", "") or ""
419
+ log.info("Transcripción completa finalizada.")
420
+
421
+ log.info("Transcripción por clip (remota, Space 'asr')...")
422
+ trans: List[str] = []
423
+ for cp in clip_paths:
424
+ res = transcribe_audio_remote(cp, cfg)
425
+ trans.append(res.get("text", ""))
426
+
427
+ log.info("Se han transcrito todos los clips.")
428
+
429
+ embeddings = embed_voice_segments(clip_paths) if audio_cfg.get("enable_voice_embeddings", True) else [[] for _ in clip_paths]
430
+
431
+ if cfg.get("voice_processing", {}).get("speaker_identification", {}).get("enabled", True):
432
+ speakers = identify_speakers(embeddings, voice_collection, cfg)
433
+ log.info("Speakers identificados correctamente.")
434
+ else:
435
+ speakers = [seg.get("speaker", f"SPEAKER_{i:02d}") for i, seg in enumerate(diar_segs)]
436
+
437
+ audio_segments: List[Dict[str, Any]] = []
438
+ for i, seg in enumerate(diar_segs):
439
+ audio_segments.append(
440
+ {
441
+ "segment": i,
442
+ "start": float(seg.get("start", 0.0)),
443
+ "end": float(seg.get("end", 0.0)),
444
+ "speaker": speakers[i] if i < len(speakers) else seg.get("speaker", f"SPEAKER_{i:02d}"),
445
+ "text": trans[i] if i < len(trans) else "",
446
+ "voice_embedding": embeddings[i],
447
+ "clip_path": str(out_dir / "clips" / f"segment_{i:03d}.wav"),
448
+ "lang": "ca",
449
+ "lang_prob": 1.0,
450
+ }
451
+ )
452
+
453
+ srt_base_path = out_dir / f"transcripcion_diarizada_{Path(video_path).stem}"
454
+ srt_unmodified_path = str(srt_base_path) + "_unmodified.srt"
455
+
456
+ try:
457
+ generate_srt_from_diarization(
458
+ diar_segs,
459
+ [a["text"] for a in audio_segments],
460
+ [a["speaker"] for a in audio_segments],
461
+ srt_unmodified_path,
462
+ cfg,
463
+ )
464
+ except Exception as e:
465
+ log.warning(f"SRT generation failed: {e}")
466
+ srt_unmodified_path = None
467
+
468
+ return audio_segments, srt_unmodified_path, full_transcription
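A minimal driver for the orchestrator above, assuming `config.yaml` is present, ffmpeg is installed, and `HF_TOKEN` is set for the pyannote pipeline; the video path is a placeholder.

```python
# Illustrative sketch only; "sample_video.mp4" is a placeholder path.
from pathlib import Path

from audio_tools import process_audio_for_video
from llm_router import load_yaml

cfg = load_yaml("config.yaml")
out_dir = Path("results") / "sample_video"
out_dir.mkdir(parents=True, exist_ok=True)

segments, srt_path, full_text = process_audio_for_video(
    "sample_video.mp4", out_dir, cfg, voice_collection=None
)
print(f"{len(segments)} segments, SRT at {srt_path}")
```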
background_descriptor.py CHANGED
@@ -1,10 +1,118 @@
1
- # ================================
2
- # PATCH: background_descriptor.py (describe_keyframes_with_llm)
3
- # ================================
4
- # We replace the direct call to describe_montage_sequence with router.vision_describe
5
- # and keep the existing function as a fallback.
6
 
7
- # (replace the function in this file with this version)
 
8
 
9
  def describe_keyframes_with_llm(
10
  keyframes: List[Dict[str, Any]],
@@ -12,7 +120,6 @@ def describe_keyframes_with_llm(
12
  face_identities: Optional[set] = None,
13
  config_path: str | None = None,
14
  ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
15
- from llm_router import load_yaml, LLMRouter
16
  cfg = load_yaml(config_path or "config.yaml")
17
  model_name = (cfg.get("background_descriptor", {}).get("description", {}) or {}).get("model", "salamandra-vision")
18
 
@@ -29,15 +136,14 @@ def describe_keyframes_with_llm(
29
  router = LLMRouter(cfg)
30
  descs = router.vision_describe(frame_paths, context=context, model=model_name)
31
  except Exception:
32
- # Fall back to the existing local implementation if the remote call fails
33
  descs = describe_montage_sequence(
34
  montage_path=str(montage_path),
35
  n=len(frame_paths),
36
  informacion=keyframes,
37
  face_identities=face_identities or set(),
38
- config_path=config_path or "config_veureu.yaml",
39
  )
40
  for i, fr in enumerate(keyframes):
41
  if i < len(descs):
42
  fr["description"] = descs[i]
43
- return keyframes, str(montage_path) if montage_path else None
 
1
+ from __future__ import annotations
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+ from pathlib import Path
 
 
4
 
5
+ from sentence_transformers import SentenceTransformer
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+
8
+ from vision_tools import (
9
+ keyframe_conditional_extraction_ana,
10
+ keyframe_every_second,
11
+ process_frames,
12
+ FaceOfImageEmbedding,
13
+ generar_montage,
14
+ describe_montage_sequence, # local fallback
15
+ )
16
+
17
+ from llm_router import load_yaml, LLMRouter
18
+
19
+ def cluster_ocr_sequential(ocr_list: List[Dict[str, Any]], threshold: float = 0.6) -> List[Dict[str, Any]]:
20
+ if not ocr_list:
21
+ return []
22
+ ocr_text = [item.get("ocr") for item in ocr_list if item and isinstance(item.get("ocr"), str)]
23
+ if not ocr_text:
24
+ return []
25
+ model = SentenceTransformer("all-MiniLM-L6-v2")
26
+ embeddings = model.encode(ocr_text, normalize_embeddings=True)
27
+
28
+ clusters_repr = []
29
+ prev_emb = embeddings[0]
30
+ start_time = ocr_list[0]["start"]
31
+ for i, emb in enumerate(embeddings[1:], 1):
32
+ sim = cosine_similarity([prev_emb], [emb])[0][0]
33
+ if sim < threshold:
34
+ clusters_repr.append({"index": i - 1, "start_time": start_time})
35
+ prev_emb = emb
36
+ start_time = ocr_list[i]["start"]
37
+ clusters_repr.append({"index": len(embeddings) - 1, "start_time": start_time})
38
+
39
+ ocr_final = []
40
+ for cluster in clusters_repr:
41
+ idx = cluster["index"]
42
+ if idx < len(ocr_list) and ocr_list[idx].get("ocr"):
43
+ it = ocr_list[idx]
44
+ ocr_final.append({
45
+ "ocr": it.get("ocr"),
46
+ "image_path": it.get("image_path"),
47
+ "start": cluster["start_time"],
48
+ "end": it.get("end"),
49
+ "faces": it.get("faces"),
50
+ })
51
+ return ocr_final
52
+
53
+ def build_keyframes_and_per_second(
54
+ video_path: str,
55
+ out_dir: Path,
56
+ cfg: Dict[str, Any],
57
+ face_collection=None,
58
+ ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], float]:
59
+ kf_dir = out_dir / "keyframes"
60
+ ps_dir = out_dir / "frames_per_second"
61
+
62
+ keyframes = keyframe_conditional_extraction_ana(video_path=video_path, output_dir=str(kf_dir))
63
+ per_second = keyframe_every_second(video_path=video_path, output_dir=str(ps_dir))
64
+
65
+ embedder = FaceOfImageEmbedding(deepface_model="Facenet512")
66
+ kf_proc = process_frames(frames=keyframes, config=cfg, face_col=face_collection, embedding_model=embedder)
67
+ ps_proc = process_frames(frames=per_second, config=cfg, face_col=face_collection, embedding_model=embedder)
68
+
69
+ ocr_list = [{
70
+ "ocr": fr.get("ocr"),
71
+ "image_path": fr.get("image_path"),
72
+ "start": fr.get("start"),
73
+ "end": fr.get("end"),
74
+ "faces": fr.get("faces"),
75
+ } for fr in ps_proc]
76
+ ocr_final = cluster_ocr_sequential(ocr_list, threshold=float(cfg.get("video_processing", {}).get("ocr_clustering", {}).get("similarity_threshold", 0.6)))
77
+
78
+ kf_mod: List[Dict[str, Any]] = []
79
+ idx = 1
80
+ for k in kf_proc:
81
+ ks, ke = k["start"], k["end"]
82
+ inicio = True
83
+ sustituido = False
84
+ for f in ocr_final:
85
+ if f["start"] >= ks and f["end"] <= ke and inicio:
86
+ kf_mod.append({
87
+ "id": idx,
88
+ "start": k["start"],
89
+ "end": None,
90
+ "image_path": f["image_path"],
91
+ "faces": f["faces"],
92
+ "ocr": f.get("ocr"),
93
+ "description": None,
94
+ })
95
+ idx += 1
96
+ sustituido = True
97
+ inicio = False
98
+ elif f["start"] >= ks and f["end"] <= ke and not inicio:
99
+ kf_mod.append({
100
+ "id": idx,
101
+ "start": f["start"],
102
+ "end": None,
103
+ "image_path": f["image_path"],
104
+ "faces": f["faces"],
105
+ "ocr": f.get("ocr"),
106
+ "description": None,
107
+ })
108
+ idx += 1
109
+ if not sustituido:
110
+ k2 = dict(k)
111
+ k2["id"] = idx
112
+ kf_mod.append(k2)
113
+ idx += 1
114
+
115
+ return kf_mod, ps_proc, 0.0
116
 
117
  def describe_keyframes_with_llm(
118
  keyframes: List[Dict[str, Any]],
 
120
  face_identities: Optional[set] = None,
121
  config_path: str | None = None,
122
  ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
 
123
  cfg = load_yaml(config_path or "config.yaml")
124
  model_name = (cfg.get("background_descriptor", {}).get("description", {}) or {}).get("model", "salamandra-vision")
125
 
 
136
  router = LLMRouter(cfg)
137
  descs = router.vision_describe(frame_paths, context=context, model=model_name)
138
  except Exception:
 
139
  descs = describe_montage_sequence(
140
  montage_path=str(montage_path),
141
  n=len(frame_paths),
142
  informacion=keyframes,
143
  face_identities=face_identities or set(),
144
+ config_path=config_path or "config.yaml",
145
  )
146
  for i, fr in enumerate(keyframes):
147
  if i < len(descs):
148
  fr["description"] = descs[i]
149
+ return keyframes, str(montage_path) if montage_path else None
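To illustrate the input shape `cluster_ocr_sequential` expects, a small sketch with made-up OCR strings and timestamps; the actual cluster boundaries depend on the MiniLM embeddings and the configured threshold.

```python
# Made-up OCR entries; each dict mirrors the fields built in build_keyframes_and_per_second.
from background_descriptor import cluster_ocr_sequential

ocr_list = [
    {"ocr": "Benvinguts al programa", "image_path": "fps/0001.jpg", "start": 0.0, "end": 1.0, "faces": []},
    {"ocr": "Benvinguts al programa!", "image_path": "fps/0002.jpg", "start": 1.0, "end": 2.0, "faces": []},
    {"ocr": "Segona part", "image_path": "fps/0003.jpg", "start": 2.0, "end": 3.0, "faces": []},
]

# Consecutive frames whose OCR embeddings stay above the threshold collapse into
# one cluster; each surviving entry keeps the start time of its cluster.
clusters = cluster_ocr_sequential(ocr_list, threshold=0.6)
print([(c["ocr"], c["start"]) for c in clusters])
```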
config.yaml CHANGED
@@ -3,93 +3,90 @@
3
  # ===========================
4
 
5
  engine:
6
- # Artifact output
7
  output_root: "results"
8
- # Vector index persistence
9
- database:
10
- enabled: true
11
- persist_directory: "chroma_db"
12
- enable_face_recognition: true
13
- enable_voice_recognition: true
14
- face_collection: "index_faces"
15
- voice_collection: "index_voices"
16
-
17
- # Jobs asíncronos (si implementas el patrón de cola)
18
- jobs:
19
- enabled: true
20
- max_workers: 1 # Ajusta según recursos del Space
21
- result_ttl_seconds: 86400 # 1 día
22
 
23
  api:
24
  cors_allow_origins: ["*"]
25
- # Tiempo máximo (segundos) de una petición síncrona (si usas el endpoint sync)
26
  sync_timeout_seconds: 3600
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  video_processing:
29
- # Metadatos de extracción de frames
30
  keyframes:
31
- # Si tu extractor condicional usa umbrales, puedes incluirlos aquí:
32
  conditional_extraction:
33
  enable: true
34
- # ejemplos de parámetros (ajústalos a tu extractor real)
35
  min_scene_length_seconds: 1.5
36
  difference_threshold: 28.0
37
 
38
  frames_per_second:
39
  enable: true
40
- fps: 1.0 # "frecuencia de frames de análisis" (1 frame/segundo por defecto)
41
 
42
  ocr:
43
- engine: "tesseract" # o "easyocr"
44
- language_hint: "spa" # si aplica
45
- # Solo si usas pytesseract:
46
- tesseract_cmd: "" # ruta binaria si no está en PATH
47
 
48
  faces:
49
- detector_model: "mtcnn" # ejemplar; ajústalo a tu vision_tools
50
- embedding_model: "Facenet512" # usado en background_descriptor.FaceOfImageEmbedding
51
  min_face_size: 32
52
  detection_confidence: 0.85
53
 
54
  ocr_clustering:
55
  method: "sequential_similarity"
56
  sentence_transformer: "all-MiniLM-L6-v2"
57
- similarity_threshold: 0.60 # "número de clusters" implícito por umbral (más alto ⇒ menos clusters)
58
 
59
  audio_processing:
60
- # El ASR principal será remoto si seleccionas whisper catalan (ver models/routing)
 
 
61
  diarization:
62
  enabled: true
63
- # ejemplo de parámetros de diarización si los usas en audio_tools
64
- min_speaker_duration: 0.8
65
- max_speakers: 8
66
 
 
67
  speaker_embedding:
68
  enabled: true
69
- # umbral para asig. de identidad en voice_collection
 
 
70
  speaker_identification:
 
 
 
 
71
  distance_threshold: 0.40
72
 
73
- # Si mantienes transcripción local para otros idiomas/modelos:
74
- local_asr:
75
- enabled: false # usarás remoto para whisper catalan
76
- model: "" # (vacío porque lo gestionas vía Spaces)
77
 
78
  background_descriptor:
79
- # Parámetros del montaje y descripción con LLM
80
  montage:
81
  enable: true
82
- max_frames: 12 # tope de frames en el collage/descripcion
83
- grid: "auto" # o 3x4, etc.
84
 
85
  description:
86
- model: "salamandra-vision" # puede ser "salamandra-vision" o "gpt-4o-mini"
87
  max_tokens: 512
88
  temperature: 0.2
89
- # Si hay identidades detectadas, se pasan como hints (ya lo hace tu pipeline)
90
 
91
  identity:
92
- # Reglas de mapeo temporal y enriquecimiento
93
  timeline_mapping:
94
  per_second_frames_source: "frames_per_second"
95
  attach_faces_to:
@@ -100,23 +97,26 @@ identity:
100
  narration:
101
  model: "salamandra-instruct" # "salamandra-instruct" | "gpt-4o-mini"
102
  une_guidelines_path: "UNE_153010.txt"
103
- # Restricciones temporales (para UNE-153010)
104
  timing:
105
- max_ad_duration_ratio: 0.60 # proporción del hueco disponible que puede ocupar la AD
106
  min_gap_seconds: 1.20
107
  min_ad_seconds: 0.80
108
  llm:
109
  max_tokens: 1024
110
  temperature: 0.2
111
 
 
 
 
 
 
112
  models:
113
- # Selección de modelos de alto nivel por tarea
114
- instruct: "salamandra-instruct" # para NarrationSystem y otros textos
115
- vision: "salamandra-vision" # para describir frames/montajes
116
- tools: "salamandra-tools" # si necesitas funciones con tool-calling
117
- asr: "whisper-catalan" # ASR catalán
118
 
119
- # Enrutado: qué modelos se ejecutan REMOTO (vía otros Spaces)
120
  routing:
121
  use_remote_for:
122
  - "salamandra-instruct"
@@ -125,47 +125,43 @@ models:
125
  - "whisper-catalan"
126
 
127
  remote_spaces:
128
- # Dónde llamar cuando models.routing decide “remoto”
129
  user: "veureu"
130
 
131
  endpoints:
132
- # Nota: rellena las URLs reales cuando publiques los Spaces.
133
  salamandra-instruct:
134
  space: "schat"
135
  base_url: "https://veureu-schat.hf.space"
136
- client: "gradio" # "gradio" o "http"
137
- predict_route: "/run/predict" # si usas gradio_client no necesitas ruta
138
 
139
  salamandra-vision:
140
  space: "svision"
141
  base_url: "https://veureu-svision.hf.space"
142
  client: "gradio"
143
- predict_route: "/run/predict"
144
 
145
  salamandra-tools:
146
  space: "stools"
147
  base_url: "https://veureu-stools.hf.space"
148
  client: "gradio"
149
- predict_route: "/run/predict"
150
 
151
  whisper-catalan:
152
- space: "ars"
153
- base_url: "https://veureu-ars.hf.space"
154
  client: "gradio"
155
- predict_route: "/run/predict"
156
 
157
- # Parámetros de red y robustez
158
  http:
159
- timeout_seconds: 120
160
  retries: 3
161
  backoff_seconds: 2.0
162
 
163
  security:
164
- # Si necesitas pasar tokens (p. ej., tokens del Hub o auth propia)
165
  use_hf_token: true
166
- hf_token_env: "HF_TOKEN" # nombre de la variable de entorno para el token
167
  allow_insecure_tls: false
168
 
169
  logging:
170
- level: "INFO" # DEBUG | INFO | WARNING | ERROR
171
  json: false
 
3
  # ===========================
4
 
5
  engine:
 
6
  output_root: "results"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  api:
9
  cors_allow_origins: ["*"]
 
10
  sync_timeout_seconds: 3600
11
 
12
+ database:
13
+ enabled: true
14
+ persist_directory: "chroma_db"
15
+ enable_face_recognition: true
16
+ enable_voice_recognition: true
17
+ face_collection: "index_faces"
18
+ voice_collection: "index_voices"
19
+
20
+ jobs:
21
+ enabled: false # si activas cola async, cámbialo a true y añade JobManager en main_api.py
22
+ max_workers: 1
23
+ result_ttl_seconds: 86400
24
+
25
  video_processing:
 
26
  keyframes:
 
27
  conditional_extraction:
28
  enable: true
 
29
  min_scene_length_seconds: 1.5
30
  difference_threshold: 28.0
31
 
32
  frames_per_second:
33
  enable: true
34
+ fps: 1.0 # Frecuencia de frames de análisis
35
 
36
  ocr:
37
+ engine: "tesseract" # "tesseract" | "easyocr"
38
+ language_hint: "spa"
39
+ tesseract_cmd: "" # si no está en PATH, deja la ruta
 
40
 
41
  faces:
42
+ detector_model: "mtcnn" # ajusta a tu vision_tools
43
+ embedding_model: "Facenet512" # usado por FaceOfImageEmbedding
44
  min_face_size: 32
45
  detection_confidence: 0.85
46
 
47
  ocr_clustering:
48
  method: "sequential_similarity"
49
  sentence_transformer: "all-MiniLM-L6-v2"
50
+ similarity_threshold: 0.60 # mayor ⇒ menos clusters
51
 
52
  audio_processing:
53
+ sample_rate: 16000
54
+ format: "wav"
55
+
56
  diarization:
57
  enabled: true
58
+ min_segment_duration: 20.0 # en segundos (post-procesado de turnos)
59
+ max_segment_duration: 50.0
 
60
 
61
+ enable_voice_embeddings: true # SpeechBrain ECAPA
62
  speaker_embedding:
63
  enabled: true
64
+
65
+ # Identificación de hablantes (clustering + Chroma)
66
+ voice_processing:
67
  speaker_identification:
68
+ enabled: true
69
+ find_optimal_clusters: true
70
+ min_speakers: 1
71
+ max_speakers: 5
72
  distance_threshold: 0.40
73
 
74
+ asr:
75
+ # Controla la transcripción del audio completo además de los clips (útil para contexto global)
76
+ enable_full_transcription: true
 
77
 
78
  background_descriptor:
 
79
  montage:
80
  enable: true
81
+ max_frames: 12
82
+ grid: "auto"
83
 
84
  description:
85
+ model: "salamandra-vision" # o "gpt-4o-mini"
86
  max_tokens: 512
87
  temperature: 0.2
 
88
 
89
  identity:
 
90
  timeline_mapping:
91
  per_second_frames_source: "frames_per_second"
92
  attach_faces_to:
 
97
  narration:
98
  model: "salamandra-instruct" # "salamandra-instruct" | "gpt-4o-mini"
99
  une_guidelines_path: "UNE_153010.txt"
 
100
  timing:
101
+ max_ad_duration_ratio: 0.60
102
  min_gap_seconds: 1.20
103
  min_ad_seconds: 0.80
104
  llm:
105
  max_tokens: 1024
106
  temperature: 0.2
107
 
108
+ subtitles:
109
+ max_chars_per_line: 42
110
+ max_lines_per_cue: 10
111
+ speaker_display: "brackets" # "brackets" | "prefix" | "none"
112
+
113
  models:
114
+ # alias de tarea modelo
115
+ instruct: "salamandra-instruct"
116
+ vision: "salamandra-vision"
117
+ tools: "salamandra-tools"
118
+ asr: "whisper-catalan" # apunta al Space veureu/asr (Aina: faster-whisper-large-v3-ca-3catparla)
119
 
 
120
  routing:
121
  use_remote_for:
122
  - "salamandra-instruct"
 
125
  - "whisper-catalan"
126
 
127
  remote_spaces:
 
128
  user: "veureu"
129
 
130
  endpoints:
 
131
  salamandra-instruct:
132
  space: "schat"
133
  base_url: "https://veureu-schat.hf.space"
134
+ client: "gradio"
135
+ predict_route: "/predict"
136
 
137
  salamandra-vision:
138
  space: "svision"
139
  base_url: "https://veureu-svision.hf.space"
140
  client: "gradio"
141
+ predict_route: "/predict"
142
 
143
  salamandra-tools:
144
  space: "stools"
145
  base_url: "https://veureu-stools.hf.space"
146
  client: "gradio"
147
+ predict_route: "/predict"
148
 
149
  whisper-catalan:
150
+ space: "asr"
151
+ base_url: "https://veureu-asr.hf.space"
152
  client: "gradio"
153
+ predict_route: "/predict"
154
 
 
155
  http:
156
+ timeout_seconds: 180
157
  retries: 3
158
  backoff_seconds: 2.0
159
 
160
  security:
 
161
  use_hf_token: true
162
+ hf_token_env: "HF_TOKEN"
163
  allow_insecure_tls: false
164
 
165
  logging:
166
+ level: "INFO"
167
  json: false
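
The `models.routing` and `remote_spaces` blocks above decide which task goes to which Space. A small sketch of how those keys resolve to a concrete endpoint, mirroring the lookups in `llm_router.py` and the keys defined above (the print format is purely illustrative):

```python
# Sketch: resolve the ASR model to its remote endpoint from config.yaml.
from pathlib import Path

import yaml

cfg = yaml.safe_load(Path("config.yaml").read_text(encoding="utf-8"))

remote = set(cfg.get("models", {}).get("routing", {}).get("use_remote_for", []))
spaces = cfg.get("remote_spaces", {})
endpoints = spaces.get("endpoints", {})
user = spaces.get("user", "veureu")
timeout = int(spaces.get("http", {}).get("timeout_seconds", 180))

model = cfg.get("models", {}).get("asr", "whisper-catalan")
if model in remote:
    info = endpoints.get(model, {})
    # An explicit base_url wins; otherwise the URL is derived as <user>-<space>.hf.space.
    base_url = info.get("base_url") or f"https://{user}-{info.get('space')}.hf.space"
    print(f"{model} -> {base_url} via {info.get('client', 'gradio')} (timeout {timeout}s)")
else:
    print(f"{model} would need a local backend")
```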
llm_router.py CHANGED
@@ -1,22 +1,18 @@
-import os
-# ============================
-# File: llm_router.py
-# ============================
+# llm_router.py: routes calls to remote Spaces according to config.yaml
 from __future__ import annotations
 from typing import Any, Dict, List, Optional
 from pathlib import Path
+import os
 import yaml
 
 from remote_clients import InstructClient, VisionClient, ToolsClient, ASRClient
 
-
 def load_yaml(path: str) -> Dict[str, Any]:
     p = Path(path)
     if not p.exists():
         return {}
     return yaml.safe_load(p.read_text(encoding="utf-8")) or {}
 
-
 class LLMRouter:
     def __init__(self, cfg: Dict[str, Any]):
         self.cfg = cfg
@@ -29,9 +25,9 @@ class LLMRouter:
         def mk(endpoint_key: str, cls):
             info = eps.get(endpoint_key, {})
             base_url = info.get("base_url") or f"https://{base_user}-{info.get('space')}.hf.space"
-            client = cls(base_url=base_url, use_gradio=(info.get("client", "gradio") == "gradio"), hf_token=hf_token,
-                         timeout=int(cfg.get("remote_spaces", {}).get("http", {}).get("timeout_seconds", 120)))
-            return client
+            use_gradio = (info.get("client", "gradio") == "gradio")
+            timeout = int(cfg.get("remote_spaces", {}).get("http", {}).get("timeout_seconds", 180))
+            return cls(base_url=base_url, use_gradio=use_gradio, hf_token=hf_token, timeout=timeout)
 
         self.clients = {
             "salamandra-instruct": mk("salamandra-instruct", InstructClient),
@@ -44,8 +40,6 @@ class LLMRouter:
     def instruct(self, prompt: str, system: Optional[str] = None, model: str = "salamandra-instruct", **kwargs) -> str:
         if model in self.rem:
             return self.clients[model].generate(prompt, system=system, **kwargs)  # type: ignore
-        # local fallback (e.g., gpt-4o-mini or gpt-oss via your own local API, if available)
-        # you could plug in an OpenAI-compatible API here if you have one
         raise RuntimeError(f"Modelo local no implementado para: {model}")
 
     # ---- VISION ----
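
A possible call site for the router above; the prompt, system message, and model name are placeholders, and only behavior visible in this diff is relied on:

```python
# Sketch: use LLMRouter with the remote routing configured in config.yaml.
from llm_router import load_yaml, LLMRouter

cfg = load_yaml("config.yaml")
router = LLMRouter(cfg)

# instruct() raises RuntimeError for models that are not routed to a remote Space.
summary = router.instruct(
    "Summarize the scene in one sentence.",
    system="You are a concise audio-description writer.",
    model="salamandra-instruct",
)
print(summary)

# The vision path is called with the same shape as in background_descriptor.py:
# descs = router.vision_describe(frame_paths, context=context, model="salamandra-vision")
```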
scripts/remote_clients.py ADDED
@@ -0,0 +1,78 @@
+# remote_clients.py: clients for remote Spaces (Gradio/HTTP)
+from __future__ import annotations
+from typing import Any, Dict, List, Optional
+import os, json, requests
+from tenacity import retry, stop_after_attempt, wait_exponential
+
+try:
+    from gradio_client import Client as GradioClient
+except Exception:
+    GradioClient = None  # type: ignore
+
+class BaseRemoteClient:
+    def __init__(self, base_url: str, use_gradio: bool = True, hf_token: Optional[str] = None, timeout: int = 180):
+        self.base_url = base_url.rstrip("/")
+        self.use_gradio = use_gradio and GradioClient is not None
+        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+        self.timeout = timeout
+        self._client = None
+        if self.use_gradio:
+            headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else None
+            self._client = GradioClient(self.base_url, hf_token=self.hf_token, headers=headers)
+
+    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=8))
+    def _post_json(self, route: str, payload: Dict[str, Any]) -> Dict[str, Any]:
+        url = f"{self.base_url}{route}"
+        headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else {}
+        r = requests.post(url, json=payload, headers=headers, timeout=self.timeout)
+        r.raise_for_status()
+        return r.json()
+
+class InstructClient(BaseRemoteClient):
+    def generate(self, prompt: str, system: Optional[str] = None, **kwargs) -> str:
+        if self.use_gradio and self._client:
+            out = self._client.predict(prompt, api_name="/predict")
+            return str(out)
+        data = {"prompt": prompt, "system": system, **kwargs}
+        res = self._post_json("/generate", data)
+        return res.get("text", "")
+
+class VisionClient(BaseRemoteClient):
+    def describe(self, image_paths: List[str], context: Optional[Dict[str, Any]] = None, **kwargs) -> List[str]:
+        if self.use_gradio and self._client:
+            out = self._client.predict(image_paths, json.dumps(context or {}), api_name="/predict")
+            if isinstance(out, str):
+                try:
+                    return json.loads(out)
+                except Exception:
+                    return [out]
+            return list(out)
+        data = {"images": image_paths, "context": context or {}, **kwargs}
+        res = self._post_json("/describe", data)
+        return res.get("descriptions", [])
+
+class ToolsClient(BaseRemoteClient):
+    def chat(self, messages: List[Dict[str, str]], tools: Optional[List[Dict[str, Any]]] = None, **kwargs) -> Dict[str, Any]:
+        if self.use_gradio and self._client:
+            out = self._client.predict(json.dumps(messages), json.dumps(tools or []), api_name="/predict")
+            if isinstance(out, str):
+                try:
+                    return json.loads(out)
+                except Exception:
+                    return {"text": out}
+            return out
+        data = {"messages": messages, "tools": tools or [], **kwargs}
+        return self._post_json("/chat", data)
+
+class ASRClient(BaseRemoteClient):
+    def transcribe(self, audio_path: str, **kwargs) -> Dict[str, Any]:
+        if self.use_gradio and self._client:
+            out = self._client.predict(audio_path, api_name="/predict")
+            if isinstance(out, str):
+                return {"text": out}
+            return out
+        files = {"file": open(audio_path, "rb")}
+        headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else {}
+        r = requests.post(f"{self.base_url}/transcribe", files=files, data=kwargs, headers=headers, timeout=self.timeout)
+        r.raise_for_status()
+        return r.json()
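
To exercise one of these clients directly (outside `LLMRouter`), something like the following should work; the audio path is a placeholder, and the import assumes `remote_clients` is importable as a top-level module, which is how `llm_router.py` imports it even though the file lives under `scripts/`:

```python
# Direct smoke test for ASRClient; the Space URL matches config.yaml,
# HF_TOKEN is only needed if the Space is private.
import os

from remote_clients import ASRClient

client = ASRClient(
    base_url="https://veureu-asr.hf.space",
    use_gradio=True,                 # falls back to plain HTTP if gradio_client is not installed
    hf_token=os.getenv("HF_TOKEN"),
    timeout=180,
)
result = client.transcribe("clips/segment_001.wav")  # placeholder audio file
print(result.get("text", ""))
```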