# Source path: engine/identity_manager.py
# =========================
# File: identity_manager.py
# =========================
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    from chromadb.api.models.Collection import Collection
class IdentityManager:
    """
    Encapsulates the identity-assignment logic (faces + voices) and its
    projection onto frames, clips and SRT timelines.

    Both ChromaDB collections are optional: when a collection is absent,
    the corresponding assignment method is a no-op that returns its input
    unchanged, so the manager degrades gracefully without a database.
    """

    def __init__(
        self,
        face_collection: Optional["Collection"] = None,
        voice_collection: Optional["Collection"] = None,
    ):
        """
        Parameters
        ----------
        face_collection:  ChromaDB collection of known face embeddings, or None.
        voice_collection: ChromaDB collection of known voice embeddings, or None.
        """
        self.face_collection = face_collection
        self.voice_collection = voice_collection

    # --------------------------- Internal helpers -------------------------
    @staticmethod
    def _has_embedding(emb: Any) -> bool:
        """
        True when `emb` is a usable (non-empty) embedding.

        Deliberately avoids `bool(emb)`: numpy arrays — a common embedding
        type — raise ValueError on ambiguous truth-value checks, whereas
        `len()` is safe for both lists and arrays.
        """
        if emb is None:
            return False
        try:
            return len(emb) > 0
        except TypeError:
            # Unsized object: fall back to plain truthiness.
            return bool(emb)

    @staticmethod
    def _query_nearest(
        collection: "Collection",
        emb: Any,
        n_results: int,
    ) -> Tuple[List[Any], List[Any]]:
        """
        Query `collection` for the nearest neighbours of a single embedding.

        Returns `(metadatas, distances)` for the first (only) query vector;
        either list may be empty when the collection has no matches.
        """
        q = collection.query(
            query_embeddings=[emb],
            n_results=n_results,
            include=["metadatas", "distances"],
        )  # type: ignore
        return q.get("metadatas", [[]])[0], q.get("distances", [[]])[0]

    # --------------------------- Faces / Frames ---------------------------
    def assign_faces_to_frames(
        self,
        frames: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Enrich each frame's faces with the nearest known identity.

        `frames` is a list of dicts with at least:
        {image_path, start, end, faces: [{embedding?, bbox?}]}.

        Returns new frame dicts whose `faces` carry `identity` and
        `distance` when a face DB is available; the input list object is
        returned untouched when there is no collection. Input dicts are
        never mutated.
        """
        if self.face_collection is None:
            return frames
        out = []
        for fr in frames:
            faces = fr.get("faces") or []
            enr: List[Dict[str, Any]] = []
            for f in faces:
                # Prefer "embedding", fall back to "vector"; both checked
                # without truthiness so numpy arrays are handled safely.
                emb = f.get("embedding")
                if not self._has_embedding(emb):
                    emb = f.get("vector")
                if not self._has_embedding(emb):
                    enr.append(f)
                    continue
                try:
                    metas, dists = self._query_nearest(self.face_collection, emb, 1)
                    if metas:
                        md = metas[0] or {}
                        f = dict(f)  # copy: never mutate the caller's dict
                        f["identity"] = md.get("identity") or md.get("name")
                        if dists:
                            f["distance"] = float(dists[0])
                except Exception:
                    # Best-effort enrichment: a failed lookup leaves this
                    # face unannotated instead of aborting the whole batch.
                    pass
                enr.append(f)
            fr2 = dict(fr)
            fr2["faces"] = enr
            out.append(fr2)
        return out

    # --------------------------- Voices / Segments ------------------------
    def assign_voices_to_segments(
        self,
        audio_segments: List[Dict[str, Any]],
        distance_threshold: Optional[float] = None,
    ) -> List[Dict[str, Any]]:
        """
        Add `voice_vecinos` (nearest neighbours) and `voice_identity` to
        each segment when a voice collection is available.

        Parameters
        ----------
        audio_segments:     segment dicts; only those carrying a usable
                            `voice_embedding` are annotated.
        distance_threshold: when given, `voice_identity` is only set if the
                            best neighbour's distance is <= the threshold;
                            `voice_vecinos` is populated regardless.
        """
        if self.voice_collection is None:
            return audio_segments
        out = []
        for a in audio_segments:
            emb = a.get("voice_embedding")
            if not self._has_embedding(emb):
                out.append(a)
                continue
            try:
                metas, dists = self._query_nearest(self.voice_collection, emb, 3)
                vecinos = []
                top_id: Optional[str] = None
                top_dist: Optional[float] = None
                for m, d in zip(metas, dists):
                    name = (m or {}).get("identity") or (m or {}).get("name")
                    vecinos.append({"identity": name, "distance": float(d)})
                    # First *named* neighbour wins: an unnamed nearest match
                    # falls through to the next neighbour that has a name.
                    if top_id is None:
                        top_id, top_dist = name, float(d)
                a2 = dict(a)
                a2["voice_vecinos"] = vecinos
                if top_id is not None:
                    if distance_threshold is None or (top_dist is not None and top_dist <= distance_threshold):
                        a2["voice_identity"] = top_id
                out.append(a2)
            except Exception:
                # Best-effort: keep the original segment when lookup fails.
                out.append(a)
        return out

    # --------------------------- Map to SRT/Timelines ---------------------
    @staticmethod
    def map_identities_over_ranges(
        per_second_frames: List[Dict[str, Any]],
        ranges: List[Dict[str, Any]],
        key: str = "faces",
        out_key: str = "persona",
    ) -> List[Dict[str, Any]]:
        """
        For each time range (keyframes, audio segments, ...) collect who
        appears according to the per-second frames.

        Two intervals overlap when neither ends before the other starts
        (half-open comparison). Identities are appended in first-seen
        order, de-duplicated, into `out_key` on a copy of each range dict.
        """
        out: List[Dict[str, Any]] = []
        for rng in ranges:
            s, e = float(rng.get("start", 0.0)), float(rng.get("end", 0.0))
            present = []
            for fr in per_second_frames:
                fs, fe = float(fr.get("start", 0.0)), float(fr.get("end", 0.0))
                if fe <= s or fs >= e:
                    continue  # no temporal overlap with this range
                for f in fr.get(key) or []:
                    ident = f.get("identity")
                    if ident and ident not in present:
                        present.append(ident)
            r2 = dict(rng)
            r2[out_key] = present
            out.append(r2)
        return out