Spaces:
Sleeping
Sleeping
File size: 6,256 Bytes
fb4e7b9 6d5cb89 fb4e7b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# app.py
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F # 🚨 CORRECTION : Nécessaire pour F.pad
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import sys
import os
# --- 1. DÉFINITION DE L'ARCHITECTURE DRANINA (U-NET) ---
# Copie exacte de l'architecture entraînée pour garantir la compatibilité
# Bloc de Convolution de Base
def double_conv(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
# Bloc d'Augmentation (Upsampling)
class Up(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# 🚨 CORRECTION : Assurez-vous que l'Upsampling correspond exactement à l'entraînement
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = double_conv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# Gestion des bords (padding) via torch.nn.functional
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# Concaténation le long de la dimension des canaux (dim=1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
# Le Modèle Dranina U-Net Complet
class DraninaUnet(nn.Module):
def __init__(self, n_channels=3, n_classes=3):
super(DraninaUnet, self).__init__()
# Encodeur
self.inc = double_conv(n_channels, 64)
self.down1 = nn.MaxPool2d(2); self.conv1 = double_conv(64, 128)
self.down2 = nn.MaxPool2d(2); self.conv2 = double_conv(128, 256)
self.down3 = nn.MaxPool2d(2); self.conv3 = double_conv(256, 512)
self.down4 = nn.MaxPool2d(2);
self.conv4 = double_conv(512, 1024)
# Décodeur
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
# Chemin avant (skip connections)
x1 = self.inc(x)
x2 = self.down1(x1); x2 = self.conv1(x2)
x3 = self.down2(x2); x3 = self.conv2(x3)
x4 = self.down3(x3); x4 = self.conv3(x4)
x5 = self.down4(x4)
x5 = self.conv4(x5) # Bas du U
# Chemin retour (avec skip connections)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
# --- 2. INITIALISATION ET CHARGEMENT DU MODÈLE ---
MODEL_ID = "Clemylia/Dranina-Mandala-Colorizer" # L'ID de ton dépôt de modèle publié
MODEL_FILENAME = "pytorch_model.bin"
IMAGE_SIZE = 256
DEVICE = torch.device("cpu")
model = None
try:
# Télécharger le fichier de poids du modèle depuis le Hub
model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
print(f"✅ Modèle téléchargé depuis le Hub : {model_path}")
# Instancier le modèle et charger les poids
model = DraninaUnet(n_channels=3, n_classes=3)
# L'argument map_location est CRUCIAL pour que ça tourne sur CPU (le standard du Space Gratuit)
state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
print("✅ Modèle Dranina prêt pour la prédiction.")
except Exception as e:
print(f"❌ Erreur lors du chargement du modèle Dranina : {e}", file=sys.stderr)
# --- 3. FONCTION DE PRÉDICTION ---
def colorize_mandala(input_image_pil):
"""
Prend une image PIL (non coloriée), la passe dans le modèle et retourne
l'image PIL coloriée.
"""
if model is None:
return Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color = 'red')
# Enregistre la taille originale pour la redimensionner à la fin (meilleur rendu)
original_size = input_image_pil.size
# 1. Préparation de l'image (Transformation)
transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
])
# Assurez-vous que l'entrée est en RGB (3 canaux)
input_tensor = transform(input_image_pil.convert('RGB')).unsqueeze(0).to(DEVICE)
# 2. Inférence
with torch.no_grad():
output_tensor = model(input_tensor)
# 3. Post-traitement (Conversion en image PIL)
# Clamp pour s'assurer que les valeurs sont entre 0 et 1
output_tensor = torch.clamp(output_tensor.squeeze(0), 0, 1)
# Convertir le tensor (C, H, W) en PIL Image
output_image = transforms.ToPILImage()(output_tensor.cpu())
# Redimensionner l'image de sortie à la taille de l'entrée originale
output_image = output_image.resize(original_size)
return output_image
# --- 4. INTERFACE GRADIO ---
if model is not None:
# Description pour le Space
title = "🎨 Dranina : Mandala Colorizer"
description = (
"Ceci est une démonstration du modèle **Dranina**, entraîné sur notre dataset pour colorier automatiquement des mandalas. "
"Téléversez une image de mandala en noir et blanc pour voir la prédiction du modèle."
)
# Création de l'interface
iface = gr.Interface(
fn=colorize_mandala,
inputs=gr.Image(type="pil", label="Mandala Non Colorié (Entrée)"),
outputs=gr.Image(type="pil", label="Mandala Colorié (Prédiction)"),
title=title,
description=description,
# Ajoute des exemples ici si tu en as dans le Space
# examples=["votre_dossier_dans_le_space/exemple1.jpg"],
allow_flagging="auto",
)
# Lancement de l'application
iface.launch()
else:
print("Application non lancée car le modèle n'a pas pu être chargé. Vérifiez l'ID du modèle et les dépendances.") |