import io
import os
import shutil
import tempfile
import unicodedata
from typing import Dict, List, Sequence, Tuple
from urllib.parse import urlparse

import numpy as np
import gradio as gr
from PIL import Image
import tensorflow as tf
from huggingface_hub import hf_hub_download
import spaces

# =========================
# MODEL VARIABLES (read from the environment)
# =========================
MODEL_URL = os.environ.get("MODEL_URL", "").strip()
MODEL_REPO = os.environ.get("MODEL_REPO", "").strip()
MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip()
MODEL_V2_FILE = os.environ.get("MODEL_FILE", "raspagem_2025_antes_depois.pb").strip()
MODEL_V1_FILE = os.environ.get("MODELO_V1", "").strip()
LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip()
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
# Number of classes reported per model: values <= 1 report only the best class.
TOPK = int(os.environ.get("TOPK", "0"))

# Repository types accepted by huggingface_hub.hf_hub_download.
_HUB_REPO_TYPES = {"model", "dataset", "space"}

# =========================
# Lazy state (filled on first inference)
# =========================
_SERVING_V1 = None
_SERVING_V2 = None
_LABELS_V1: List[str] = []
_LABELS_V2: List[str] = []
_LAST_INIT_ERROR_V1 = None
_LAST_INIT_ERROR_V2 = None
# Strong references to the loaded SavedModel objects. Only their signatures are
# used, but the signature functions may not keep the loaded object's resources
# alive by themselves, so the objects are pinned here for the process lifetime.
_LOADED_MODELS: List[object] = []


# =========================
# Utilities
# =========================
def _hub_repo_type() -> str:
    """Return MODEL_REPO_TYPE when it is a valid Hub repo type, otherwise "model"."""
    return MODEL_REPO_TYPE if MODEL_REPO_TYPE in _HUB_REPO_TYPES else "model"


def _download_from_url(url: str) -> str:
    """Download ``url`` into a fresh temporary directory and return the local file path.

    The body is streamed to disk in chunks so large model files are never held
    entirely in memory. Raises ``requests.HTTPError`` on a non-2xx response.
    """
    import requests

    with requests.get(url, timeout=60, stream=True) as resp:
        resp.raise_for_status()
        tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
        # Only the URL path is used so query strings (e.g. "?download=true")
        # do not end up in the file name.
        name = os.path.basename(urlparse(url).path) or "saved_model.pb"
        local = os.path.join(tmp_dir, name)
        with open(local, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)
    return local


def _prepare_saved_model_dir(pb_path: str) -> str:
    """Copy a bare ``.pb`` file into a new temp dir as ``saved_model.pb`` and return the dir.

    ``tf.saved_model.load`` expects a directory, not a single file.
    """
    tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_")
    shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb"))
    return tmp_dir


def _load_model_from_file(pb_file: str) -> tf.types.experimental.ConcreteFunction:
    """Load a SavedModel from a single ``.pb`` file and return its ``serving_default`` signature.

    Raises:
        RuntimeError: if the model has no ``serving_default`` signature.
    """
    sm_dir = _prepare_saved_model_dir(pb_file)
    model = tf.saved_model.load(sm_dir)
    serving = model.signatures.get("serving_default")
    if serving is None:
        raise RuntimeError(f"Modelo {pb_file} sem assinatura 'serving_default'.")
    # Keep the loaded object alive alongside the returned signature.
    _LOADED_MODELS.append(model)
    return serving


def _maybe_labels() -> List[str]:
    """Best-effort load of the class-name list from LABELS_FILE.

    When MODEL_REPO is set, the file is downloaded from that Hugging Face Hub
    repo; otherwise LABELS_FILE is read as a local path. Blank lines are
    skipped. Any error is printed and an empty list is returned, so callers
    must tolerate missing labels.
    """
    # NOTE(review): V1 and V2 both call this and therefore share one
    # LABELS_FILE; if the two models have different classes this needs a
    # per-model setting — confirm.
    if not LABELS_FILE:
        return []
    try:
        if MODEL_REPO:
            path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LABELS_FILE,
                repo_type=_hub_repo_type(),
            )
        else:
            path = LABELS_FILE
        with open(path, "r", encoding="utf-8") as f:
            return [x.strip() for x in f if x.strip()]
    except Exception as e:
        print(f"[labels] ignorando erro: {e}")
        return []


# =========================
# Initialization
# =========================
def _init_v1() -> Tuple[bool, str]:
    """Load the V1 ("Sextante") model once, from the local path in MODELO_V1.

    Returns ``(True, "ok")`` on success or when already loaded. On failure it
    returns ``(False, message)`` and also stores the message in
    ``_LAST_INIT_ERROR_V1``.
    """
    global _SERVING_V1, _LABELS_V1, _LAST_INIT_ERROR_V1
    if _SERVING_V1 is not None:
        return True, "ok"
    try:
        if not os.path.exists(MODEL_V1_FILE):
            raise FileNotFoundError(f"MODELO_V1 não encontrado: {MODEL_V1_FILE}")
        _SERVING_V1 = _load_model_from_file(MODEL_V1_FILE)
        _LABELS_V1 = _maybe_labels()
        return True, "ok"
    except Exception as e:
        _LAST_INIT_ERROR_V1 = f"{type(e).__name__}: {e}"
        return False, _LAST_INIT_ERROR_V1


def _init_v2() -> Tuple[bool, str]:
    """Load the V2 ("Necessidade") model once.

    Sources are tried in order: MODEL_URL, then MODEL_FILE inside MODEL_REPO on
    the Hub, then MODEL_FILE as a local path. Returns ``(True, "ok")`` on
    success or when already loaded. On failure it returns ``(False, message)``
    and also stores the message in ``_LAST_INIT_ERROR_V2``.
    """
    global _SERVING_V2, _LABELS_V2, _LAST_INIT_ERROR_V2
    if _SERVING_V2 is not None:
        return True, "ok"
    try:
        if MODEL_URL:
            pb_path = _download_from_url(MODEL_URL)
        elif MODEL_REPO:
            pb_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_V2_FILE,
                repo_type=_hub_repo_type(),
            )
        elif os.path.exists(MODEL_V2_FILE):
            pb_path = MODEL_V2_FILE
        else:
            raise FileNotFoundError(f"MODEL_FILE não encontrado: {MODEL_V2_FILE}")
        _SERVING_V2 = _load_model_from_file(pb_path)
        _LABELS_V2 = _maybe_labels()
        return True, "ok"
    except Exception as e:
        _LAST_INIT_ERROR_V2 = f"{type(e).__name__}: {e}"
        return False, _LAST_INIT_ERROR_V2


# =========================
# Processing
# =========================
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Convert to RGB, resize to IMG_SIZE x IMG_SIZE and return the JPEG-encoded bytes."""
    img = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    buf = io.BytesIO()
    img.save(buf, format="JPEG")
    return buf.getvalue()


# Display names keyed by the normalized label (lower-case, accents and
# separators removed — see _label_key).
_PRETTY_LABELS: Dict[str, str] = {
    "necessario": "Necessário",
    "naonecessario": "Não necessário",
    "s1": "S1",
    "s2": "S2",
    "s3": "S3",
    "s4": "S4",
    "s5": "S5",
    "s6": "S6",
}


def _label_key(text: str) -> str:
    """Normalize a label for lookup: lower-case, strip accents, drop spaces/underscores/hyphens."""
    decomposed = unicodedata.normalize("NFKD", text.lower())
    no_accents = "".join(c for c in decomposed if not unicodedata.combining(c))
    return no_accents.replace(" ", "").replace("_", "").replace("-", "")


def _pretty_label(raw: str) -> str:
    """Map a raw model label to its display name.

    Known labels (e.g. "nao_necessario", "Não necessário", "s3") map to a fixed
    display form. Anything else is stripped and capitalized. ``None`` or an
    empty label yields "".
    """
    text = (raw or "").strip()
    return _PRETTY_LABELS.get(_label_key(text), text.capitalize())


def _format_result(label: str, score: float, tipo: str) -> str:
    """Format one prediction line as ``"<tipo>: <label> (<score as %>)"``."""
    return f"{tipo}: {_pretty_label(label)} ({score * 100:.1f}%)"


def _label_for(labels: Sequence[str], idx: int) -> str:
    """Return ``labels[idx]``, or a generic ``"classe_<idx>"`` when no label exists for it."""
    return labels[idx] if 0 <= idx < len(labels) else f"classe_{idx}"


def _predict(serving, fallback_labels: Sequence[str], image_bytes: bytes, key: str, tipo: str) -> str:
    """Run one serving signature on a single encoded image and format its prediction(s).

    The signature is called with ``image_bytes`` and ``key`` batch-of-one
    tensors. Its output must contain ``"scores"``; when it also contains
    ``"labels"``, those names are used instead of ``fallback_labels``.
    Returns one formatted line, or TOPK lines when TOPK > 1.
    """
    res = serving(
        image_bytes=tf.convert_to_tensor([image_bytes]),
        key=tf.convert_to_tensor([key]),
    )
    scores = res["scores"].numpy()[0]
    if "labels" in res:
        labels = [x.decode("utf-8") for x in res["labels"].numpy()[0]]
    else:
        labels = list(fallback_labels)

    if TOPK > 1:
        # Stable sort on negated scores keeps the lowest index first among
        # ties, matching np.argmax for the top entry.
        order = [int(i) for i in np.argsort(-scores, kind="stable")[:TOPK]]
    else:
        order = [int(np.argmax(scores))]
    return "\n".join(_format_result(_label_for(labels, i), scores[i], tipo) for i in order)


# =========================
# Combined inference
# =========================
@spaces.GPU(duration=120)
def infer(image: Image.Image):
    """Classify ``image`` with both models and return their results on two lines.

    The first line is V1 ("Sextante") and the second is V2 ("Necessidade").
    Models are loaded lazily on the first call.

    Raises:
        ValueError: if no image was provided.
        RuntimeError: if either model fails to load.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")

    image_bytes = _preprocess_image_to_bytes(image)

    # V1 - Sextante
    ok1, err1 = _init_v1()
    if not ok1:
        raise RuntimeError(f"Erro ao carregar modelo V1: {err1}")
    sextante = _predict(_SERVING_V1, _LABELS_V1, image_bytes, "v1", "Sextante")

    # V2 - Necessidade
    ok2, err2 = _init_v2()
    if not ok2:
        raise RuntimeError(f"Erro ao carregar modelo V2: {err2}")
    necessidade = _predict(_SERVING_V2, _LABELS_V2, image_bytes, "v2", "Necessidade")

    return f"{sextante}\n{necessidade}"


# =========================
# Gradio UI
# =========================
demo = gr.Blocks(title="RaspagemTF - V1 + V2")
with demo:
    gr.Markdown("## RaspagemTF — Inferência em dois modelos (Sextante, Necessidade)")
    with gr.Row():
        img = gr.Image(type="pil", label="Imagem")
        res = gr.Textbox(label="Resultados", lines=4)
    btn = gr.Button("Rodar inferência")
    btn.click(fn=infer, inputs=img, outputs=res)

if __name__ == "__main__":
    demo.queue()
    demo.launch()