Dranina-demo / app.py
Clemylia's picture
Update app.py
6d5cb89 verified
# app.py
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F # 🚨 CORRECTION : Nécessaire pour F.pad
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
import sys
import os
# --- 1. DÉFINITION DE L'ARCHITECTURE DRANINA (U-NET) ---
# Copie exacte de l'architecture entraînée pour garantir la compatibilité
# Bloc de Convolution de Base
def double_conv(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
# Bloc d'Augmentation (Upsampling)
class Up(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# 🚨 CORRECTION : Assurez-vous que l'Upsampling correspond exactement à l'entraînement
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = double_conv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# Gestion des bords (padding) via torch.nn.functional
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# Concaténation le long de la dimension des canaux (dim=1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
# Le Modèle Dranina U-Net Complet
class DraninaUnet(nn.Module):
def __init__(self, n_channels=3, n_classes=3):
super(DraninaUnet, self).__init__()
# Encodeur
self.inc = double_conv(n_channels, 64)
self.down1 = nn.MaxPool2d(2); self.conv1 = double_conv(64, 128)
self.down2 = nn.MaxPool2d(2); self.conv2 = double_conv(128, 256)
self.down3 = nn.MaxPool2d(2); self.conv3 = double_conv(256, 512)
self.down4 = nn.MaxPool2d(2);
self.conv4 = double_conv(512, 1024)
# Décodeur
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
# Chemin avant (skip connections)
x1 = self.inc(x)
x2 = self.down1(x1); x2 = self.conv1(x2)
x3 = self.down2(x2); x3 = self.conv2(x3)
x4 = self.down3(x3); x4 = self.conv3(x4)
x5 = self.down4(x4)
x5 = self.conv4(x5) # Bas du U
# Chemin retour (avec skip connections)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
# --- 2. INITIALISATION ET CHARGEMENT DU MODÈLE ---
MODEL_ID = "Clemylia/Dranina-Mandala-Colorizer" # L'ID de ton dépôt de modèle publié
MODEL_FILENAME = "pytorch_model.bin"
IMAGE_SIZE = 256
DEVICE = torch.device("cpu")
model = None
try:
# Télécharger le fichier de poids du modèle depuis le Hub
model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
print(f"✅ Modèle téléchargé depuis le Hub : {model_path}")
# Instancier le modèle et charger les poids
model = DraninaUnet(n_channels=3, n_classes=3)
# L'argument map_location est CRUCIAL pour que ça tourne sur CPU (le standard du Space Gratuit)
state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
print("✅ Modèle Dranina prêt pour la prédiction.")
except Exception as e:
print(f"❌ Erreur lors du chargement du modèle Dranina : {e}", file=sys.stderr)
# --- 3. FONCTION DE PRÉDICTION ---
def colorize_mandala(input_image_pil):
"""
Prend une image PIL (non coloriée), la passe dans le modèle et retourne
l'image PIL coloriée.
"""
if model is None:
return Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color = 'red')
# Enregistre la taille originale pour la redimensionner à la fin (meilleur rendu)
original_size = input_image_pil.size
# 1. Préparation de l'image (Transformation)
transform = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
])
# Assurez-vous que l'entrée est en RGB (3 canaux)
input_tensor = transform(input_image_pil.convert('RGB')).unsqueeze(0).to(DEVICE)
# 2. Inférence
with torch.no_grad():
output_tensor = model(input_tensor)
# 3. Post-traitement (Conversion en image PIL)
# Clamp pour s'assurer que les valeurs sont entre 0 et 1
output_tensor = torch.clamp(output_tensor.squeeze(0), 0, 1)
# Convertir le tensor (C, H, W) en PIL Image
output_image = transforms.ToPILImage()(output_tensor.cpu())
# Redimensionner l'image de sortie à la taille de l'entrée originale
output_image = output_image.resize(original_size)
return output_image
# --- 4. INTERFACE GRADIO ---
if model is not None:
# Description pour le Space
title = "🎨 Dranina : Mandala Colorizer"
description = (
"Ceci est une démonstration du modèle **Dranina**, entraîné sur notre dataset pour colorier automatiquement des mandalas. "
"Téléversez une image de mandala en noir et blanc pour voir la prédiction du modèle."
)
# Création de l'interface
iface = gr.Interface(
fn=colorize_mandala,
inputs=gr.Image(type="pil", label="Mandala Non Colorié (Entrée)"),
outputs=gr.Image(type="pil", label="Mandala Colorié (Prédiction)"),
title=title,
description=description,
# Ajoute des exemples ici si tu en as dans le Space
# examples=["votre_dossier_dans_le_space/exemple1.jpg"],
allow_flagging="auto",
)
# Lancement de l'application
iface.launch()
else:
print("Application non lancée car le modèle n'a pas pu être chargé. Vérifiez l'ID du modèle et les dépendances.")