File size: 1,116 Bytes
8dff9a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

# SegFormer checkpoint used by extract_hair(). The image processor and the
# segmentation model are both loaded from this same Hugging Face Hub repo,
# processor first, then model.
_HAIR_SEG_CHECKPOINT = "VanNguyen1214/get_face_and_hair"

processor = SegformerImageProcessor.from_pretrained(_HAIR_SEG_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_HAIR_SEG_CHECKPOINT)

def extract_hair(image: Image.Image, *, hair_label: int = 2) -> Image.Image:
    """Cut the hair region out of *image* as an RGBA layer.

    The image is segmented with the module-level SegFormer ``processor`` and
    ``model``. Pixels whose predicted class equals ``hair_label`` get
    alpha=255; all other pixels get alpha=0. The RGB values are copied from
    the input unchanged; only the alpha channel encodes the mask.

    Args:
        image: Input picture in any mode PIL can convert to RGB. An existing
            alpha channel is discarded by that conversion.
        hair_label: Class index treated as "hair" in the argmax segmentation.
            Defaults to 2, the id the original code used for this checkpoint.
            NOTE(review): confirm against ``model.config.id2label``.

    Returns:
        An RGBA image with the same width and height as *image*.
    """
    rgb = image.convert("RGB")
    pixels = np.array(rgb)
    height, width = pixels.shape[:2]

    # Send the inputs to wherever the model lives (CPU or accelerator). The
    # logits are brought back to the CPU for the NumPy post-processing below.
    inputs = processor(images=rgb, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # Upsample the logits to the original image size before the per-pixel
    # argmax, so the mask lines up with the input pixels.
    upsampled = F.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    seg = upsampled.argmax(dim=1)[0].numpy()

    # Hair pixels are fully opaque; everything else is fully transparent.
    alpha = np.where(seg == hair_label, 255, 0).astype(np.uint8)
    rgba = np.dstack([pixels, alpha])
    return Image.fromarray(rgba)