# baldhead.py
import os

import cv2
import numpy as np
from PIL import Image
import tensorflow as tf
import gradio as gr

# Keras imports (note: keras-contrib must be installed)
import keras.backend as K
from keras.layers import (
    Input,
    Conv2D,
    UpSampling2D,
    LeakyReLU,
    GlobalAveragePooling2D,
    Dense,
    Reshape,
    Dropout,
    Concatenate,
    multiply,  # ← added the multiply import
)
from keras.models import Model
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization

# RetinaFace + skimage for face alignment
from retinaface import RetinaFace
from skimage import transform as trans

# Hugging Face Hub helper
from huggingface_hub import hf_hub_download
# ----------------------------
# 1. FACE-ALIGNMENT HELPERS (same as the original code)
# ----------------------------
image_size = [256, 256]

# Standard 5-point landmark template (eyes, nose, mouth corners) for a
# 112x112 crop, shifted and rescaled below to fit the 256x256 crop.
src_landmarks = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32)
src_landmarks[:, 0] += 8.0
src_landmarks[:, 0] += 15.0
src_landmarks[:, 1] += 30.0
src_landmarks /= 112
src_landmarks *= 200
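
# Net effect: each template point p maps to (p + offset) / 112 * 200, e.g. the
# first (left-eye) point lands at roughly
# ((30.2946 + 8 + 15) / 112 * 200, (51.6963 + 30) / 112 * 200) ≈ (95.2, 145.9)
# inside the 256x256 crop.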
def list2array(values):
    return np.array(list(values))
def align_face(img: np.ndarray):
    """
    Detect faces + landmarks in `img` via RetinaFace.
    Returns lists of aligned face patches (256×256 RGB),
    corresponding binary masks, and the transformation matrices.
    """
    faces = RetinaFace.detect_faces(img)
    # RetinaFace returns a dict of detections; anything else means no face,
    # and indexing into it below would fail.
    if not isinstance(faces, dict):
        return [], [], []
    bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces])
    landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces])
    white_canvas = np.ones(img.shape, dtype=np.uint8) * 255
    aligned_faces, masks, matrices = [], [], []
    if bboxes.shape[0] > 0:
        for i in range(bboxes.shape[0]):
            dst = landmarks[i]  # detected landmarks
            tform = trans.SimilarityTransform()
            tform.estimate(dst, src_landmarks)
            M = tform.params[0:2, :]
            warped_face = cv2.warpAffine(
                img, M, (image_size[1], image_size[0]), borderValue=0.0
            )
            warped_mask = cv2.warpAffine(
                white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0
            )
            aligned_faces.append(warped_face)
            masks.append(warped_mask)
            matrices.append(tform.params[0:3, :])
    return aligned_faces, masks, matrices
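
# Example usage (hypothetical image path), matching the RGB flow used in
# `inference` below:
#   img = np.array(Image.open("photo.jpg").convert("RGB"))
#   faces, masks, mats = align_face(img)  # one entry per detected face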
def put_face_back(
    orig_img: np.ndarray,
    processed_faces: list[np.ndarray],
    masks: list[np.ndarray],
    matrices: list[np.ndarray],
):
    """
    Warp each processed face back onto the original `orig_img`
    using the inverse of the transformation matrices.
    """
    result = orig_img.copy()
    h, w = orig_img.shape[:2]
    for i in range(len(processed_faces)):
        invM = np.linalg.inv(matrices[i])[0:2]
        warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0)
        mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0)
        binary_mask = (mask // 255).astype(np.uint8)
        # Composite: result = result * (1 - mask) + warped * mask
        result = result * (1 - binary_mask)
        result = result.astype(np.uint8)
        result = result + warped * binary_mask
    return result
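
# Continuing the hypothetical example above: pasting the *unmodified* aligned
# faces back is a near-identity operation, which is a quick way to eyeball the
# alignment round trip:
#   out = put_face_back(img, faces, masks, mats)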
# ----------------------------
# 2. GENERATOR ARCHITECTURE
# ----------------------------
def squeeze_excite_block(x, ratio=4):
    """
    Squeeze-and-Excitation block: channel-wise attention.
    """
    init = x
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)
    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation="relu", kernel_initializer="he_normal", use_bias=False)(se)
    se = Dense(filters, activation="sigmoid", kernel_initializer="he_normal", use_bias=False)(se)
    return multiply([init, se])

def conv2d(layer_input, filters, f_size=4, bn=True, se=False):
    """
    Downsampling block: Conv2D → LeakyReLU → (InstanceNorm) → (SE block)
    """
    d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input)
    d = LeakyReLU(alpha=0.2)(d)
    if bn:
        d = InstanceNormalization()(d)
    if se:
        d = squeeze_excite_block(d)
    return d
def atrous(layer_input, filters, f_size=4, bn=True):
    """
    Atrous (dilated) convolution block with dilation rates [2, 4, 8].
    """
    a_list = []
    for rate in [2, 4, 8]:
        a = Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input)
        a_list.append(a)
    a = Concatenate()(a_list)
    a = LeakyReLU(alpha=0.2)(a)
    if bn:
        a = InstanceNormalization()(a)
    return a

def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
    """
    Upsampling block: UpSampling2D → Conv2D → (Dropout) → InstanceNorm → Concatenate(skip)
    """
    u = UpSampling2D(size=2)(layer_input)
    u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(u)
    if dropout_rate:
        u = Dropout(dropout_rate)(u)
    u = InstanceNormalization()(u)
    u = Concatenate()([u, skip_input])
    return u
def build_generator():
    """
    Reconstruct the generator architecture exactly as in the notebook,
    then return a Keras Model object.
    """
    d0 = Input(shape=(256, 256, 3))
    gf = 64

    # Downsampling
    d1 = conv2d(d0, gf, bn=False, se=True)
    d2 = conv2d(d1, gf * 2, se=True)
    d3 = conv2d(d2, gf * 4, se=True)
    d4 = conv2d(d3, gf * 8)
    d5 = conv2d(d4, gf * 8)

    # Atrous block
    a1 = atrous(d5, gf * 8)

    # Upsampling
    u3 = deconv2d(a1, d4, gf * 8)
    u4 = deconv2d(u3, d3, gf * 4)
    u5 = deconv2d(u4, d2, gf * 2)
    u6 = deconv2d(u5, d1, gf)

    # Final upsample + conv
    u7 = UpSampling2D(size=2)(u6)
    output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7)

    model = Model(d0, output_img)
    return model
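
# Quick shape sanity check (a sketch, not part of the original notebook):
# five stride-2 downsamples take 256 → 8, the atrous block is stride-1, and
# four deconv blocks plus the final UpSampling2D return to 256.
#   g = build_generator()
#   assert g.output_shape == (None, 256, 256, 3)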
# ----------------------------
# 3. LOAD MODEL WEIGHTS
# ----------------------------
HF_REPO_ID = "VanNguyen1214/baldhead"
HF_FILENAME = "model_G_5_170.hdf5"
# Use .get() so a missing token doesn't raise KeyError at import time;
# token=None falls back to anonymous access, which works for public repos.
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

def load_generator_from_hub():
    """
    Download the .hdf5 weights from HF Hub into cache,
    rebuild the generator, then load weights.
    """
    local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, token=HF_TOKEN)
    gen = build_generator()
    gen.load_weights(local_path)
    return gen
# Load once at startup
try:
    GENERATOR = load_generator_from_hub()
    print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}")
except Exception as e:
    print("[ERROR] Could not load generator:", e)
    GENERATOR = None
# ----------------------------
# 4. INFERENCE FUNCTION
# ----------------------------
def inference(image: Image.Image) -> Image.Image:
    """
    Gradio-compatible inference function:
      - Convert PIL → numpy RGB
      - Align faces
      - For each face: normalize to [-1, 1], run through generator, denormalize to uint8
      - Put processed faces back onto the original image
      - Return full-image PIL
    """
    if GENERATOR is None:
        return image
    orig = np.array(image.convert("RGB"))
    faces, masks, mats = align_face(orig)
    if len(faces) == 0:
        return image
    processed_faces = []
    for face in faces:
        face_input = face.astype(np.float32)
        face_input = (face_input / 127.5) - 1.0          # scale to [-1, 1]
        face_input = np.expand_dims(face_input, axis=0)  # (1, 256, 256, 3)
        pred = GENERATOR.predict(face_input)[0]          # (256, 256, 3) in [-1, 1]
        pred = ((pred + 1.0) * 127.5).astype(np.uint8)
        processed_faces.append(pred)
    output_np = put_face_back(orig, processed_faces, masks, mats)
    output_pil = Image.fromarray(output_np)
    return output_pil
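
# ----------------------------
# 5. GRADIO APP
# ----------------------------
# `gradio` is imported above but never wired up, which leaves the Space with
# nothing to serve. A minimal sketch (labels and title are placeholders, not
# from the original notebook):
demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil", label="Input photo"),
    outputs=gr.Image(type="pil", label="Bald result"),
    title="Bald Head Generator",
)

if __name__ == "__main__":
    demo.launch()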