"""
MobileCLIP-B Zero-Shot Image Classifier (Hugging Face Inference Endpoint)
==========================================================================

* One container instance is created per replica; the `EndpointHandler`
  object below is instantiated exactly **once** at start-up.
* At request time (`__call__`) we receive a base64-encoded image, run a
  **single forward pass**, and return class probabilities.

Design choices
--------------
1. **Model & transform come from OpenCLIP.**
   This guarantees we apply **identical preprocessing** to what the model
   was trained with (224 × 224 crop + mean/std normalisation).
2. **Re-parameterisation for inference.**
   MobileCLIP uses MobileOne blocks that have extra convolution branches
   for training; `reparameterize_model` fuses them so inference is fast
   and deterministic.
3. **Text embeddings are cached.**
   The class "prompts" (e.g. "a photo of a cat") are encoded **once at
   start-up**. Each request therefore encodes *only* the image and
   performs a single matrix multiplication.
4. **Mixed precision on GPU.**
   If the container has CUDA, we cast the model **and** inputs to
   `float16`. That halves memory and roughly doubles throughput on most
   modern GPUs. On CPU we stay in `float32` for numerical stability.
"""

import base64
import io
import json
from pathlib import Path
from typing import Any, Dict, List

import torch
from PIL import Image

import open_clip
from reparam import reparameterize_model  # local copy (~60 LoC) of Apple's helper


class EndpointHandler:
    """
    Hugging Face entry-point. The toolkit instantiates this class once per
    replica and calls it for every HTTP request.

    Parameters
    ----------
    path : str, optional
        Root directory of the repository. HF mounts the code under
        `/repository`; we use this path to locate `items.json`.
    """

    # ------------------------------------------------------------------ #
    #  INITIALISATION (runs once, at container start-up)                  #
    # ------------------------------------------------------------------ #
    def __init__(self, path: str = "") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 1. Load MobileCLIP-B weights & transforms.
        #    `pretrained="datacompdr"` makes OpenCLIP download the official
        #    checkpoint from the Hub (cached in the image layer).
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()                         # disable dropout / BN updates
        model = reparameterize_model(model)  # fuse MobileOne branches
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)  # FP16 for throughput
        self.model = model

        # 2. Build the tokenizer once.
        tokenizer = open_clip.get_tokenizer("MobileCLIP-B")

        # 3. Load the class metadata.
        #    Expected JSON layout: [{"id": 3, "name": "cat", "prompt": "cat"}, ...]
        #    - 'prompt' is used to create the text embeddings
        #    - 'id' and 'name' are kept to structure the response later
        items_path = Path(path) / "items.json"
        with items_path.open("r", encoding="utf-8") as f:
            class_defs: List[Dict[str, Any]] = json.load(f)

        prompts = [item["prompt"] for item in class_defs]
        self.class_ids: List[int] = [item["id"] for item in class_defs]
        self.class_names: List[str] = [item["name"] for item in class_defs]

        # 4. Encode all prompts once and cache the L2-normalised features.
        with torch.no_grad():
            text_tokens = tokenizer(prompts).to(self.device)
            text_feats = self.model.encode_text(text_tokens)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        self.text_features = text_feats      # [num_classes, 512]

    # ------------------------------------------------------------------ #
    #  INFERENCE CALL                                                     #
    # ------------------------------------------------------------------ #
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Parameters
        ----------
        data : dict
            Either the raw payload `{"image": "<base64-encoded image>"}` or
            the Hugging Face convention `{"inputs": {...}}`.

        Returns
        -------
        list of dict
            Sorted list of `{"id": int, "label": str, "score": float}`.
            Scores are the softmax probabilities over the *provided* class
            list (they sum to 1.0).
        """
        # 1. Unpack the request payload; only the image is needed.
        payload: Dict[str, Any] = data.get("inputs", data)
        img_b64: str = payload["image"]

        # 2. Decode + preprocess.
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)  # [1, 3, 224, 224]
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # 3. Forward pass (image only, very fast).
        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)             # [1, 512]
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)  # L2-normalise
            # cosine similarity -> logits -> softmax probabilities
            probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0]  # [num_classes]

        # 4. Combine scores with the stored class IDs and names, and return a
        #    JSON-serialisable list sorted by descending score.
        results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
        return sorted(
            [{"id": cid, "label": name, "score": float(p)} for cid, name, p in results],
            key=lambda x: x["score"],
            reverse=True,
        )
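
# ---------------------------------------------------------------------- #
#  Example client call (a hedged sketch, not used by the endpoint code).  #
#  ENDPOINT_URL and HF_TOKEN are placeholders for your own deployment;    #
#  the JSON body matches what `__call__` above expects.                   #
# ---------------------------------------------------------------------- #
#
#   import base64, requests
#
#   with open("example.jpg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   response = requests.post(
#       ENDPOINT_URL,                               # e.g. https://<id>.endpoints.huggingface.cloud
#       headers={
#           "Authorization": f"Bearer {HF_TOKEN}",  # Hugging Face access token
#           "Content-Type": "application/json",
#       },
#       json={"inputs": {"image": img_b64}},
#   )
#   print(response.json())  # [{"id": ..., "label": ..., "score": ...}, ...] sorted by score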