# Authentica / detector.py
# Source: MAS-AI-0000's Hugging Face Space (commit 7830aec, verified)
"""Core embedding-based detector.
Loads the DETree KNN database and exposes ``detect_embedding``, which accepts
a single pre-computed, L2-normalised embedding vector and returns a prediction.
All modality-specific logic (text, image) lives in separate embedder modules:
- text_embedder.py β†’ str β†’ np.ndarray
- image_embedder.py β†’ PIL.Image β†’ np.ndarray
Usage::
from Apps.detector import detect_embedding
from Apps.text_embedder import get_text_embedding
from Apps.image_embedder import get_image_embedding
emb = get_text_embedding("Some text here")
result = detect_embedding(emb)
# {"predicted_class": "Human"|"Ai", "confidence": 0.95}
emb = get_image_embedding(pil_image)
result = detect_embedding(emb, mode="image")
# {"predicted_class": "Real"|"AI", "confidence": 0.88}
"""
from __future__ import annotations
import logging
import os
import sys
from typing import Optional
import numpy as np
import torch
from huggingface_hub import hf_hub_download
log = logging.getLogger("detector")
logging.basicConfig(level=logging.INFO, format="%(levelname)s [%(name)s] %(message)s")
# ---------------------------------------------------------------------------
# Make the local 'detree' package importable
# ---------------------------------------------------------------------------
# The detree package ships alongside this file; prepending the file's own
# directory to sys.path lets `from detree...` resolve without installation.
_current_dir = os.path.dirname(os.path.abspath(__file__))
if _current_dir not in sys.path:
    sys.path.append(_current_dir)
try:
    from detree.utils.index import Indexer
    log.info("Indexer imported successfully.")
except ImportError as _e:
    log.error(f"Could not import detree Indexer: {_e} β€” detection will return fallback responses.")
    # Sentinel: _init() and detect_embedding() test for None and degrade
    # gracefully instead of crashing at import time.
    Indexer = None
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
REPO_ID = "MAS-AI-0000/Authentica"
# Downloaded at import time; hf_hub_download caches locally and returns the
# cached path on later runs. NOTE(review): this performs network I/O on first
# import -- confirm that is acceptable for every entry point importing us.
_DB_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="Lib/Models/MultiModal/merged_multimodal.pt",
)
log.info(f"[paths] _DB_PATH = {_DB_PATH!r} exists={os.path.exists(_DB_PATH)}")
# ---------------------------------------------------------------------------
# Hyperparameters (match values used during database construction)
# ---------------------------------------------------------------------------
TOP_K = 10        # number of nearest neighbours consulted per query
THRESHOLD = 0.97  # prob_human >= THRESHOLD => classified as human/real
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _load_database(path: str):
    """Load the DETree KNN database from *path*.

    Args:
        path: Filesystem path to the ``merged_multimodal.pt`` checkpoint.

    Returns:
        Tuple ``(embeddings, labels, ids, classes)`` where ``embeddings`` is a
        dict keyed by layer index.

    Raises:
        ValueError: If the stored ``embeddings`` entry is not a dict.
    """
    log.info(f"_load_database: loading from {path!r} ...")
    # weights_only=False: the database pickles plain Python objects (lists,
    # dicts of tensors) alongside tensors. torch >= 2.6 changed the default
    # to weights_only=True, which would reject this file, so pin the legacy
    # behaviour explicitly. The file comes from our own HF repo (trusted).
    data = torch.load(path, map_location="cpu", weights_only=False)
    embeddings = data["embeddings"]
    labels = data["labels"]
    ids = data["ids"]
    classes = data["classes"]
    log.info(f"_load_database: classes={list(classes)} "
             f"embedding keys={list(embeddings.keys()) if isinstance(embeddings, dict) else type(embeddings).__name__}")
    if not isinstance(embeddings, dict):
        raise ValueError("Expected embeddings to be a dict keyed by layer index.")
    return embeddings, labels, ids, classes
def _to_numpy(value) -> np.ndarray:
if isinstance(value, np.ndarray):
return value
if torch.is_tensor(value):
return value.detach().cpu().numpy()
return np.asarray(value)
# ---------------------------------------------------------------------------
# Module-level initialisation
# ---------------------------------------------------------------------------
# Populated by _init() at import time. Left at these fallback values when the
# database or Indexer is unavailable; detect_embedding() then returns its
# fallback response instead of raising.
_index: Optional[object] = None     # detree Indexer instance, or None if disabled
_human_index: Optional[int] = None  # position of "human" within _classes
_classes: list = []                 # class names stored in the database
_embedding_dim: int = 0             # dimensionality of the indexed embeddings
def _init() -> None:
    """Build the module-level KNN index from the downloaded database.

    Mutates the module globals (_index, _human_index, _classes,
    _embedding_dim). Every failure is logged and swallowed so that importing
    this module never raises; detect_embedding() then returns fallbacks.
    """
    global _index, _human_index, _classes, _embedding_dim
    log.info("_init: starting Detector initialisation.")
    # Guard 1: the detree package failed to import (see top of file).
    if Indexer is None:
        log.error("_init: Indexer is None β€” check import error above. Detection disabled.")
        return
    # Guard 2: hf_hub_download returned a path that no longer exists.
    if not os.path.exists(_DB_PATH):
        log.error(f"_init: database not found at {_DB_PATH!r} β€” detection disabled.")
        return
    try:
        embeddings, labels, ids, classes = _load_database(_DB_PATH)
        _classes = list(classes)
        log.info(f"_init: available classes={_classes}")
        if "human" not in _classes:
            raise ValueError("Database must include a 'human' class entry.")
        _human_index = _classes.index("human")
        log.info(f"_init: human_index={_human_index}")
        # Layer embeddings keyed by int layer index
        layer_embeddings = {int(k): v.float() for k, v in embeddings.items()}
        available_layers = sorted(layer_embeddings.keys())
        active_layer = available_layers[-1] # last layer by default
        log.info(f"_init: available layers={available_layers} using active_layer={active_layer}")
        # Resolve per-layer or shared label / id tensors.
        # NOTE(review): when labels/ids are dicts, their keys are assumed to
        # be ints like active_layer (embedding keys were cast via int(k));
        # str keys would raise KeyError here -- confirm the stored format.
        if isinstance(labels, dict):
            layer_labels = _to_numpy(labels[active_layer]).astype(np.int64)
        else:
            layer_labels = _to_numpy(labels).astype(np.int64)
        if isinstance(ids, dict):
            layer_ids = _to_numpy(ids[active_layer]).astype(np.int64)
        else:
            layer_ids = _to_numpy(ids).astype(np.int64)
        train_embs = _to_numpy(layer_embeddings[active_layer]).astype(np.float32)
        _embedding_dim = train_embs.shape[-1]
        log.info(f"_init: train_embs shape={train_embs.shape} embedding_dim={_embedding_dim}")
        log.info(f"_init: label distribution β€” "
                 f"human={int((layer_labels == _human_index).sum())} "
                 f"ai={int((layer_labels != _human_index).sum())}")
        # Collapse the multi-class labels to binary: 1 = human, 0 = any AI class.
        label_dict = {
            int(idx): (1 if int(lbl) == int(_human_index) else 0)
            for idx, lbl in zip(layer_ids.tolist(), layer_labels.tolist())
        }
        _index = Indexer(_embedding_dim)
        _index.label_dict = label_dict
        _index.index_data(layer_ids.tolist(), train_embs)
        log.info(f"_init: Indexer built β€” layer={active_layer} dim={_embedding_dim} "
                 f"entries={len(layer_ids)}")
    except Exception as exc:
        # Deliberate broad catch: initialisation is best-effort; failures are
        # logged and detection degrades to fallback responses.
        log.exception(f"_init: error initialising database: {exc}")


# Build the index eagerly at import time so the first request is fast.
_init()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def detect_embedding(
    embedding: np.ndarray,
    *,
    top_k: int = TOP_K,
    threshold: float = THRESHOLD,
) -> dict:
    """Classify a single pre-computed, L2-normalised embedding via KNN.

    Args:
        embedding: 1-D or (1, dim) float32 numpy array already projected into
            the DETree embedding space (L2-normalised).
        top_k: Number of nearest neighbours to consider.
        threshold: Probability at or above which the sample is labelled
            human/real (class 1).

    Returns:
        ``{"predicted_class": int, "confidence": float}`` where class 1 means
        human/real and class 0 means AI. On any failure (index unavailable,
        search error) class 0 with confidence 0.0 is returned.
    """
    fallback_class = 0  # class 0 (AI) is the conservative fallback
    if _index is None:
        log.error("detect_embedding: _index is None β€” returning fallback. Check _init logs.")
        return {"predicted_class": fallback_class, "confidence": 0.0}
    # Accept 1-D or (1, dim); the indexer expects a 2-D batch of one query.
    emb = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
    log.info(f"detect_embedding: query embedding shape={emb.shape} norm={float(np.linalg.norm(emb)):.4f} "
             f"top_k={top_k} threshold={threshold}")
    try:
        results = _index.search_knn(
            emb,
            top_k,
            index_batch_size=max(1, min(top_k, 128)),
        )
        # results holds one entry per query row; we sent exactly one query.
        _ids, scores, labels_knn = results[0]
        log.info(f"detect_embedding: neighbour ids={_ids}")
        log.info(f"detect_embedding: neighbour scores={[round(float(s), 4) for s in scores]}")
        log.info(f"detect_embedding: neighbour labels={labels_knn} "
                 f"(1=human, 0=ai)")
        # Softmax over similarity scores weights each neighbour; the weighted
        # sum of the binary labels is the probability of "human".
        # dtype=np.float32 is required: torch.dot raises on mixed dtypes, and
        # python-float scores would otherwise arrive as float64 while label_t
        # is float32.
        scores_t = torch.from_numpy(np.asarray(scores, dtype=np.float32))
        weights = torch.softmax(scores_t, dim=0)
        label_t = torch.tensor(labels_knn, dtype=torch.float32)
        prob_human = float(torch.clamp(torch.dot(weights, label_t), 0.0, 1.0).item())
        prob_ai = float(max(0.0, min(1.0, 1.0 - prob_human)))
        predicted_class = 1 if prob_human >= threshold else 0
        # Confidence is the probability of whichever class was chosen.
        confidence = prob_human if predicted_class == 1 else prob_ai
        log.info(f"detect_embedding: prob_human={prob_human:.4f} prob_ai={prob_ai:.4f} "
                 f"predicted_class={predicted_class} confidence={confidence:.4f}")
    except Exception as exc:
        # Best-effort API: any indexing/search error degrades to the fallback.
        log.exception(f"detect_embedding: failed during KNN search: {exc}")
        return {"predicted_class": fallback_class, "confidence": 0.0}
    return {"predicted_class": predicted_class, "confidence": confidence}