dashingzombie commited on
Commit
76ae127
·
verified ·
1 Parent(s): 82e254c

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitignore +3 -1
  2. handler.py +100 -0
  3. requiremnts.txt +3 -1
.gitignore CHANGED
@@ -208,4 +208,6 @@ __marimo__/
208
 
209
  *.jpg`
210
  runs/
211
- data/
 
 
 
208
 
209
  *.jpg`
210
  runs/
211
+ data/
212
+
213
+ *yolov11-segmentation_earth-worm/
handler.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ # Hugging Face Inference Endpoints - Custom Handler for Ultralytics YOLOv11-seg
3
+ # Returns: {"instances":[{"label":str,"score":float,"polygon":[[x,y],...]},...],
4
+ # "width": int, "height": int}
5
+
6
+ import io
7
+ import base64
8
+ from typing import Any, Dict, List, Union
9
+
10
+ from PIL import Image
11
+ from huggingface_hub import hf_hub_download
12
+ from ultralytics import YOLO
13
+
14
+
15
+ class EndpointHandler:
16
+ def __init__(self, path: str = "."):
17
+ """
18
+ Called once on container startup.
19
+ `path` points to the repo root mounted in the endpoint.
20
+ """
21
+ # Resolve weights using Hub API to get the raw binary (handles LFS/private).
22
+ self.repo_id = "dashingzombie/yolov11-segmentation_earth-worm"
23
+ self.filename = "best.pt" # change if you prefer last.pt
24
+
25
+ weights_path = hf_hub_download(
26
+ repo_id=self.repo_id,
27
+ filename=self.filename,
28
+ repo_type="model"
29
+ )
30
+ self.model = YOLO(weights_path)
31
+
32
+ # If class names were not baked into the checkpoint, you can force them:
33
+ if not getattr(self.model, "names", None):
34
+ self.model.names = {0: "body_mask"} # single-class fallback
35
+
36
+ def _to_image(self, payload: Dict[str, Any]) -> Image.Image:
37
+ """
38
+ Accepts either:
39
+ - {"inputs": {"image": <base64-string>}} (Serverless-style)
40
+ - {"inputs": <base64-string>}
41
+ - {"image_bytes": <raw-bytes>} (Toolkit raw)
42
+ """
43
+ if "image_bytes" in payload:
44
+ return Image.open(io.BytesIO(payload["image_bytes"])).convert("RGB")
45
+
46
+ inputs = payload.get("inputs", payload.get("instances", None))
47
+ if isinstance(inputs, dict):
48
+ img_b64 = inputs.get("image") or inputs.get("img") or inputs.get("data")
49
+ else:
50
+ img_b64 = inputs
51
+
52
+ if isinstance(img_b64, str):
53
+ # strip possible 'data:image/jpeg;base64,' prefix
54
+ if "," in img_b64:
55
+ img_b64 = img_b64.split(",", 1)[1]
56
+ data = base64.b64decode(img_b64)
57
+ return Image.open(io.BytesIO(data)).convert("RGB")
58
+
59
+ raise ValueError("No image provided. Expected 'image_bytes' or base64 string under 'inputs'.")
60
+
61
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
62
+ """
63
+ Runs per request. `data` is the incoming JSON/body parsed by the Toolkit.
64
+ Returns JSON-serializable dict.
65
+ """
66
+ image = self._to_image(data)
67
+ W, H = image.size
68
+
69
+ # confidence threshold can be overridden via params
70
+ params = data.get("parameters", {}) or data.get("options", {})
71
+ conf = float(params.get("conf", 0.25))
72
+
73
+ results = self.model(image, conf=conf, verbose=False)[0]
74
+ names = results.names
75
+
76
+ instances: List[Dict[str, Any]] = []
77
+ if results.masks is not None:
78
+ # polygons per instance: results.masks.xy (list of Nx2 arrays)
79
+ for i, poly in enumerate(results.masks.xy):
80
+ cls_id = int(results.boxes.cls[i].item())
81
+ score = float(results.boxes.conf[i].item())
82
+ polygon = [[float(x), float(y)] for x, y in poly]
83
+ instances.append({
84
+ "label": names[cls_id],
85
+ "score": score,
86
+ "polygon": polygon
87
+ })
88
+ else:
89
+ # Fallback to boxes if masks missing (rare for -seg)
90
+ for i, b in enumerate(results.boxes.xyxy.tolist()):
91
+ x1, y1, x2, y2 = [float(v) for v in b]
92
+ cls_id = int(results.boxes.cls[i].item())
93
+ score = float(results.boxes.conf[i].item())
94
+ instances.append({
95
+ "label": names[cls_id],
96
+ "score": score,
97
+ "bbox_xyxy": [x1, y1, x2, y2]
98
+ })
99
+
100
+ return {"instances": instances, "width": W, "height": H}
requiremnts.txt CHANGED
@@ -1,4 +1,6 @@
1
  ultralytics>=8.3
2
  torch
3
  torchvision
4
- pillow
 
 
 
1
  ultralytics>=8.3
2
  torch
3
  torchvision
4
+ pillow
5
+ huggingface_hub
6
+ fastapi