Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import shutil | |
| import tempfile | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
# =========================
# Config (via environment variables)
# =========================
# Where to fetch the model (.pb) from, in priority order:
#   1) MODEL_URL (http/https)
#   2) MODEL_REPO + MODEL_FILE (+ MODEL_REPO_TYPE: model|space)
#   3) Local path (MODEL_FILE) at the root of the Space
MODEL_URL = os.environ.get("MODEL_URL", "").strip()
MODEL_REPO = os.environ.get("MODEL_REPO", "").strip()  # e.g. "vcollos/raspagemTF" or "spaces/vcollos/raspagem_supra"
MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip()  # "model" or "space"
MODEL_FILE = os.environ.get("MODEL_FILE", "raspagem_2025_antes_depois.pb").strip()
LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip()
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))  # square edge the input image is resized to
TOPK = int(os.environ.get("TOPK", "0"))  # 0 = list every class
| # ========================= | |
| # Download/resolve SavedModel (.pb) e lazy init | |
| # ========================= | |
def _download_from_url(url: str) -> str:
    """Download *url* into a fresh temporary directory and return the local path."""
    import requests

    response = requests.get(url, timeout=60)
    response.raise_for_status()
    target_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
    filename = os.path.basename(url) or "saved_model.pb"
    target_path = os.path.join(target_dir, filename)
    with open(target_path, "wb") as fh:
        fh.write(response.content)
    return target_path
def _download_model() -> str:
    """Resolve the .pb model file and return its local path.

    Sources are tried in priority order: MODEL_URL, then the HF Hub repo
    (MODEL_REPO/MODEL_FILE), then a local MODEL_FILE at the Space root.
    Raises FileNotFoundError when no source yields a file.
    """
    if MODEL_URL:
        return _download_from_url(MODEL_URL)

    if MODEL_REPO:
        repo_type = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
        try:
            return hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILE,
                repo_type=repo_type,
            )
        except Exception as exc:
            # Best-effort: fall through to the local-file check below.
            print(f"[download] HF hub falhou: {exc}")

    if os.path.exists(MODEL_FILE):
        return MODEL_FILE

    raise FileNotFoundError(
        "Modelo não encontrado. Defina MODEL_URL OU (MODEL_REPO, MODEL_REPO_TYPE, MODEL_FILE) OU deixe o arquivo na raiz do Space."
    )
| def _prepare_saved_model_dir(pb_path: str) -> str: | |
| # SavedModel mínimo: diretório contendo 'saved_model.pb' | |
| tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_") | |
| shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb")) | |
| return tmp_dir | |
# Lazy-initialized state, populated by _init_once()
_SERVING_FN = None  # the SavedModel's 'serving_default' signature once loaded
_LABELS: List[str] = []  # class labels loaded from LABELS_FILE, if available
_LAST_INIT_ERROR: str | None = None  # last init failure message, surfaced by diagnostics
def _maybe_labels() -> List[str]:
    """Best-effort load of the labels file; returns [] on any failure."""
    try:
        if LABELS_FILE:
            if MODEL_REPO:
                repo_type = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
                labels_path = hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=LABELS_FILE,
                    repo_type=repo_type,
                )
            else:
                labels_path = LABELS_FILE
            with open(labels_path, "r", encoding="utf-8") as fh:
                return [line.strip() for line in fh if line.strip()]
    except Exception as exc:
        # Labels are optional; log and carry on without them.
        print(f"[labels] ignorando erro: {exc}")
    return []
def _init_once() -> Tuple[bool, str]:
    """Load the SavedModel exactly once (lazy init).

    Returns (True, "ok") on success, (False, "<error>") on failure; the
    failure message is also stored in _LAST_INIT_ERROR for diagnostics.
    """
    global _SERVING_FN, _LABELS, _LAST_INIT_ERROR

    if _SERVING_FN is not None:
        return True, "ok"

    try:
        saved_model_dir = _prepare_saved_model_dir(_download_model())
        loaded = tf.saved_model.load(saved_model_dir)
        signature = loaded.signatures.get("serving_default")
        if signature is None:
            raise RuntimeError("SavedModel sem assinatura 'serving_default'.")
        _SERVING_FN = signature
        _LABELS = _maybe_labels()
        _LAST_INIT_ERROR = None
        return True, "ok"
    except Exception as exc:
        _LAST_INIT_ERROR = f"{type(exc).__name__}: {exc}"
        return False, _LAST_INIT_ERROR
| # ========================= | |
| # Pré/Pós-processamento | |
| # ========================= | |
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Convert to RGB, resize to IMG_SIZE x IMG_SIZE and return the JPEG bytes."""
    resized = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    buffer = io.BytesIO()
    resized.save(buffer, format="JPEG")
    return buffer.getvalue()
| def _pretty_label(raw: str) -> str: | |
| s = (raw or "").strip().lower() | |
| m = { | |
| "necessario": "Necessário", | |
| "necessário": "Necessário", | |
| "nao_necessario": "Não necessário", | |
| "não_necessário": "Não necessário", | |
| "s1": "S1", | |
| "s2": "S2", | |
| "s3": "S3", | |
| } | |
| # remove acentos/espacos no inicio se vier com variações | |
| key = s.replace(" ", "").replace("ã", "a").replace("á", "a").replace("é", "e").replace("í", "i").replace("ó", "o").replace("ç", "c") | |
| return m.get(key, raw.strip().capitalize()) | |
def _format_bars(labels: List[str], scores: np.ndarray, topk: int) -> str:
    """Render one "Label: pct% ███" line per class, sorted by descending score.

    topk limits how many classes are shown; 0 (or less) shows all of them.
    Bars are scaled to 20 columns with a minimum of one block.
    """
    order = np.argsort(scores)[::-1]
    if topk and topk > 0:
        order = order[:topk]

    rows = []
    for idx in order:
        pct = float(scores[idx]) * 100.0
        bar = "█" * max(1, int(scores[idx] * 20))
        # Label priority: caller-provided list, then module-level _LABELS,
        # then a synthetic "class_<i>" placeholder.
        if idx < len(labels) and labels[idx]:
            raw_label = labels[idx]
        elif idx < len(_LABELS):
            raw_label = _LABELS[idx]
        else:
            raw_label = f"class_{idx}"
        rows.append(f"{_pretty_label(raw_label)}: {pct:.1f}% {bar}")
    return "\n".join(rows)
| # ========================= | |
| # UI functions | |
| # ========================= | |
def _signature_info() -> Dict[str, object]:
    """Describe the serving signature's inputs and outputs for the UI.

    Returns {"inputs": {...}, "outputs": {...}} on success, otherwise
    {"init_error": "<message>"}. The annotation is Dict[str, object]
    (not Dict[str, Dict[str, str]]) because the error branch maps to a
    plain string.
    """
    ok, err = _init_once()
    if not ok:
        return {"init_error": err}
    inputs = {k: str(v) for k, v in _SERVING_FN.structured_input_signature[1].items()}
    outputs = {k: str(v) for k, v in _SERVING_FN.structured_outputs.items()}
    return {"inputs": inputs, "outputs": outputs}
def _diagnostics() -> Dict[str, object]:
    """Report init status plus the effective configuration, for the UI's JSON panel."""
    ok, err = _init_once()
    env = {
        "MODEL_URL": MODEL_URL or None,
        "MODEL_REPO": MODEL_REPO or None,
        "MODEL_REPO_TYPE": MODEL_REPO_TYPE,
        "MODEL_FILE": MODEL_FILE,
        "IMG_SIZE": IMG_SIZE,
        "TOPK": TOPK,
    }
    return {
        "ok": ok,
        "error": None if ok else err,
        "env": env,
    }
def infer(image: Image.Image):
    """Run the SavedModel on a PIL image and return formatted score bars.

    Raises ValueError when no image was provided, RuntimeError when the
    model failed to initialize, and KeyError when the signature has no
    'scores' output.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")

    ok, err = _init_once()
    if not ok:
        raise RuntimeError(f"Modelo não inicializado: {err}")

    jpeg_bytes = _preprocess_image_to_bytes(image)
    # The signature expects a batch of encoded images plus an opaque key.
    outputs = _SERVING_FN(
        image_bytes=tf.convert_to_tensor([jpeg_bytes]),
        key=tf.convert_to_tensor(["0"]),
    )

    scores_t = outputs.get("scores")
    if scores_t is None:
        raise KeyError("Saída 'scores' não encontrada na assinatura do modelo.")
    scores = scores_t.numpy()[0]

    labels_t = outputs.get("labels")
    labels: List[str] = []
    if labels_t is not None:
        labels = [raw.decode("utf-8") for raw in labels_t.numpy()[0]]

    return _format_bars(labels, scores, TOPK)
# =========================
# Gradio UI
# =========================
demo = gr.Blocks(title="RaspagemTF - SavedModel (.pb)")

with demo:
    gr.Markdown("## RaspagemTF — Inferência (SavedModel .pb)")

    # Main flow: image in, formatted score bars out.
    with gr.Row():
        img = gr.Image(type="pil", label="Imagem")
        res = gr.Textbox(label="Resultados", lines=8)
    btn = gr.Button("Rodar inferência")
    btn.click(fn=infer, inputs=img, outputs=res)

    with gr.Accordion("Diagnóstico", open=False):
        d_btn = gr.Button("Rodar diagnóstico")
        d_out = gr.JSON()
        d_btn.click(fn=_diagnostics, inputs=None, outputs=d_out)

        # Quick check of the TF runtime version and any visible GPUs.
        def _gpu_diag():
            return {
                "tf_version": tf.__version__,
                "gpus_detected": [str(g) for g in tf.config.list_physical_devices('GPU')]
            }

        g_btn = gr.Button("Checar GPU")
        g_out = gr.JSON()
        g_btn.click(fn=_gpu_diag, inputs=None, outputs=g_out)

    with gr.Accordion("Assinaturas do modelo", open=False):
        s_btn = gr.Button("Mostrar assinatura")
        s_out = gr.JSON()
        s_btn.click(fn=_signature_info, inputs=None, outputs=s_out)

if __name__ == "__main__":
    # queue() serializes requests so concurrent clicks don't race the lazy init.
    demo.queue()
    demo.launch()