# baldhead.py import os import cv2 import numpy as np from PIL import Image import tensorflow as tf import gradio as gr # Keras imports (note: keras-contrib must be installed) import keras.backend as K from keras.layers import ( Input, Conv2D, UpSampling2D, LeakyReLU, GlobalAveragePooling2D, Dense, Reshape, Dropout, Concatenate, multiply, # ← Thêm import multiply ) from keras.models import Model from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization # RetinaFace + skimage for face alignment from retinaface import RetinaFace from skimage import transform as trans # Hugging Face Hub helper from huggingface_hub import hf_hub_download # --- Face‐alignment helpers (giống code gốc) --- image_size = [256, 256] src_landmarks = np.array([ [30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041], ], dtype=np.float32) src_landmarks[:, 0] += 8.0 src_landmarks[:, 0] += 15.0 src_landmarks[:, 1] += 30.0 src_landmarks /= 112 src_landmarks *= 200 def list2array(values): return np.array(list(values)) def align_face(img: np.ndarray): """ Detect faces + landmarks in `img` via RetinaFace. Returns lists of aligned face patches (256×256 RGB), corresponding binary masks, and the transformation matrices. """ faces = RetinaFace.detect_faces(img) bboxes = np.array([list2array(faces[f]['facial_area']) for f in faces]) landmarks = np.array([list2array(faces[f]['landmarks'].values()) for f in faces]) white_canvas = np.ones(img.shape, dtype=np.uint8) * 255 aligned_faces, masks, matrices = [], [], [] if bboxes.shape[0] > 0: for i in range(bboxes.shape[0]): dst = landmarks[i] # detected landmarks tform = trans.SimilarityTransform() tform.estimate(dst, src_landmarks) M = tform.params[0:2, :] warped_face = cv2.warpAffine( img, M, (image_size[1], image_size[0]), borderValue=0.0 ) warped_mask = cv2.warpAffine( white_canvas, M, (image_size[1], image_size[0]), borderValue=0.0 ) aligned_faces.append(warped_face) masks.append(warped_mask) matrices.append(tform.params[0:3, :]) return aligned_faces, masks, matrices def put_face_back( orig_img: np.ndarray, processed_faces: list[np.ndarray], masks: list[np.ndarray], matrices: list[np.ndarray], ): """ Warp each processed face back onto the original `orig_img` using the inverse of the transformation matrices. """ result = orig_img.copy() h, w = orig_img.shape[:2] for i in range(len(processed_faces)): invM = np.linalg.inv(matrices[i])[0:2] warped = cv2.warpAffine(processed_faces[i], invM, (w, h), borderValue=0.0) mask = cv2.warpAffine(masks[i], invM, (w, h), borderValue=0.0) binary_mask = (mask // 255).astype(np.uint8) # Composite: result = result * (1 - mask) + warped * mask result = result * (1 - binary_mask) result = result.astype(np.uint8) result = result + warped * binary_mask return result # ---------------------------- # 2. GENERATOR ARCHITECTURE # ---------------------------- def squeeze_excite_block(x, ratio=4): """ Squeeze-and-Excitation block: channel-wise attention. """ init = x channel_axis = 1 if K.image_data_format() == "channels_first" else -1 filters = init.shape[channel_axis] se_shape = (1, 1, filters) se = GlobalAveragePooling2D()(init) se = Reshape(se_shape)(se) se = Dense(filters // ratio, activation="relu", kernel_initializer="he_normal", use_bias=False)(se) se = Dense(filters, activation="sigmoid", kernel_initializer="he_normal", use_bias=False)(se) return multiply([init, se]) def conv2d(layer_input, filters, f_size=4, bn=True, se=False): """ Downsampling block: Conv2D → LeakyReLU → (InstanceNorm) → (SE block) """ d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same")(layer_input) d = LeakyReLU(alpha=0.2)(d) if bn: d = InstanceNormalization()(d) if se: d = squeeze_excite_block(d) return d def atrous(layer_input, filters, f_size=4, bn=True): """ Atrous (dilated) convolution block with dilation rates [2,4,8]. """ a_list = [] for rate in [2, 4, 8]: a = Conv2D(filters, f_size, dilation_rate=rate, padding="same")(layer_input) a_list.append(a) a = Concatenate()(a_list) a = LeakyReLU(alpha=0.2)(a) if bn: a = InstanceNormalization()(a) return a def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0): """ Upsampling block: UpSampling2D → Conv2D → (Dropout) → InstanceNorm → Concatenate(skip) """ u = UpSampling2D(size=2)(layer_input) u = Conv2D(filters, kernel_size=f_size, strides=1, padding="same", activation="relu")(u) if dropout_rate: u = Dropout(dropout_rate)(u) u = InstanceNormalization()(u) u = Concatenate()([u, skip_input]) return u def build_generator(): """ Reconstruct the generator architecture exactly as in the notebook, then return a Keras Model object. """ d0 = Input(shape=(256, 256, 3)) gf = 64 # Downsampling d1 = conv2d(d0, gf, bn=False, se=True) d2 = conv2d(d1, gf * 2, se=True) d3 = conv2d(d2, gf * 4, se=True) d4 = conv2d(d3, gf * 8) d5 = conv2d(d4, gf * 8) # Atrous block a1 = atrous(d5, gf * 8) # Upsampling u3 = deconv2d(a1, d4, gf * 8) u4 = deconv2d(u3, d3, gf * 4) u5 = deconv2d(u4, d2, gf * 2) u6 = deconv2d(u5, d1, gf) # Final upsample + conv u7 = UpSampling2D(size=2)(u6) output_img = Conv2D(3, kernel_size=4, strides=1, padding="same", activation="tanh")(u7) model = Model(d0, output_img) return model # ---------------------------- # 3. LOAD MODEL WEIGHTS # ---------------------------- HF_REPO_ID = "VanNguyen1214/baldhead" HF_FILENAME = "model_G_5_170.hdf5" HF_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"] def load_generator_from_hub(): """ Download the .hdf5 weights from HF Hub into cache, rebuild the generator, then load weights. """ local_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME,token=HF_TOKEN) gen = build_generator() gen.load_weights(local_path) return gen # Load once at startup try: GENERATOR = load_generator_from_hub() print(f"[INFO] Loaded generator weights from {HF_REPO_ID}/{HF_FILENAME}") except Exception as e: print("[ERROR] Could not load generator:", e) GENERATOR = None # ---------------------------- # 4. INFERENCE FUNCTION # ---------------------------- def inference(image: Image.Image) -> Image.Image: """ Gradio-compatible inference function: - Convert PIL→ numpy RGB - Align faces - For each face: normalize to [-1,1], run through generator, denormalize to uint8 - Put processed faces back onto original image - Return full-image PIL """ if GENERATOR is None: return image orig = np.array(image.convert("RGB")) faces, masks, mats = align_face(orig) if len(faces) == 0: return image processed_faces = [] for face in faces: face_input = face.astype(np.float32) face_input = (face_input / 127.5) - 1.0 # scale to [-1,1] face_input = np.expand_dims(face_input, axis=0) # (1,256,256,3) pred = GENERATOR.predict(face_input)[0] # (256,256,3) in [-1,1] pred = ((pred + 1.0) * 127.5).astype(np.uint8) processed_faces.append(pred) output_np = put_face_back(orig, processed_faces, masks, mats) output_pil = Image.fromarray(output_np) return output_pil