Clemylia commited on
Commit
fb4e7b9
·
verified ·
1 Parent(s): 74e7c11

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F # 🚨 CORRECTION : Nécessaire pour F.pad
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ from huggingface_hub import hf_hub_download
10
+ import sys
11
+ import os
12
+
13
+ # --- 1. DÉFINITION DE L'ARCHITECTURE DRANINA (U-NET) ---
14
+ # Copie exacte de l'architecture entraînée pour garantir la compatibilité
15
+
16
+ # Bloc de Convolution de Base
17
+ def double_conv(in_channels, out_channels):
18
+ return nn.Sequential(
19
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True),
22
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
23
+ nn.BatchNorm2d(out_channels),
24
+ nn.ReLU(inplace=True)
25
+ )
26
+
27
+ # Bloc d'Augmentation (Upsampling)
28
+ class Up(nn.Module):
29
+ def __init__(self, in_channels, out_channels):
30
+ super().__init__()
31
+ # 🚨 CORRECTION : Assurez-vous que l'Upsampling correspond exactement à l'entraînement
32
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
33
+ self.conv = double_conv(in_channels, out_channels)
34
+
35
+ def forward(self, x1, x2):
36
+ x1 = self.up(x1)
37
+
38
+ # Gestion des bords (padding) via torch.nn.functional
39
+ diffY = x2.size()[2] - x1.size()[2]
40
+ diffX = x2.size()[3] - x1.size()[3]
41
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
42
+ diffY // 2, diffY - diffY // 2])
43
+
44
+ # Concaténation le long de la dimension des canaux (dim=1)
45
+ x = torch.cat([x2, x1], dim=1)
46
+ return self.conv(x)
47
+
48
+ # Le Modèle Dranina U-Net Complet
49
+ class DraninaUnet(nn.Module):
50
+ def __init__(self, n_channels=3, n_classes=3):
51
+ super(DraninaUnet, self).__init__()
52
+
53
+ # Encodeur
54
+ self.inc = double_conv(n_channels, 64)
55
+ self.down1 = nn.MaxPool2d(2); self.conv1 = double_conv(64, 128)
56
+ self.down2 = nn.MaxPool2d(2); self.conv2 = double_conv(128, 256)
57
+ self.down3 = nn.MaxPool2d(2); self.conv3 = double_conv(256, 512)
58
+ self.down4 = nn.MaxPool2d(2);
59
+ self.conv4 = double_conv(512, 1024)
60
+
61
+ # Décodeur
62
+ self.up1 = Up(1024, 512)
63
+ self.up2 = Up(512, 256)
64
+ self.up3 = Up(256, 128)
65
+ self.up4 = Up(128, 64)
66
+
67
+ self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
68
+
69
+ def forward(self, x):
70
+ # Chemin avant (skip connections)
71
+ x1 = self.inc(x)
72
+ x2 = self.down1(x1); x2 = self.conv1(x2)
73
+ x3 = self.down2(x2); x3 = self.conv2(x3)
74
+ x4 = self.down3(x3); x4 = self.conv3(x4)
75
+ x5 = self.down4(x4)
76
+
77
+ x5 = self.conv4(x5) # Bas du U
78
+
79
+ # Chemin retour (avec skip connections)
80
+ x = self.up1(x5, x4)
81
+ x = self.up2(x, x3)
82
+ x = self.up3(x, x2)
83
+ x = self.up4(x, x1)
84
+
85
+ logits = self.outc(x)
86
+ return logits
87
+
88
+ # --- 2. INITIALISATION ET CHARGEMENT DU MODÈLE ---
89
+
90
+ MODEL_ID = "Clemylia/Dranina-Mandala-Colorizer" # L'ID de ton dépôt de modèle publié
91
+ MODEL_FILENAME = "pytorch_model.bin"
92
+ IMAGE_SIZE = 256
93
+ DEVICE = torch.device("cpu")
94
+
95
+ model = None
96
+ try:
97
+ # Télécharger le fichier de poids du modèle depuis le Hub
98
+ model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
99
+ print(f"✅ Modèle téléchargé depuis le Hub : {model_path}")
100
+
101
+ # Instancier le modèle et charger les poids
102
+ model = DraninaUnet(n_channels=3, n_classes=3)
103
+ # L'argument map_location est CRUCIAL pour que ça tourne sur CPU (le standard du Space Gratuit)
104
+ state_dict = torch.load(model_path, map_location=DEVICE)
105
+ model.load_state_dict(state_dict)
106
+ model.to(DEVICE)
107
+ model.eval()
108
+ print("✅ Modèle Dranina prêt pour la prédiction.")
109
+
110
+ except Exception as e:
111
+ print(f"❌ Erreur lors du chargement du modèle Dranina : {e}", file=sys.stderr)
112
+
113
+ # --- 3. FONCTION DE PRÉDICTION ---
114
+
115
+ def colorize_mandala(input_image_pil):
116
+ """
117
+ Prend une image PIL (non coloriée), la passe dans le modèle et retourne
118
+ l'image PIL coloriée.
119
+ """
120
+ if model is None:
121
+ return Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color = 'red')
122
+
123
+ # Enregistre la taille originale pour la redimensionner à la fin (meilleur rendu)
124
+ original_size = input_image_pil.size
125
+
126
+ # 1. Préparation de l'image (Transformation)
127
+ transform = transforms.Compose([
128
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
129
+ transforms.ToTensor(),
130
+ ])
131
+
132
+ # Assurez-vous que l'entrée est en RGB (3 canaux)
133
+ input_tensor = transform(input_image_pil.convert('RGB')).unsqueeze(0).to(DEVICE)
134
+
135
+ # 2. Inférence
136
+ with torch.no_grad():
137
+ output_tensor = model(input_tensor)
138
+
139
+ # 3. Post-traitement (Conversion en image PIL)
140
+ # Clamp pour s'assurer que les valeurs sont entre 0 et 1
141
+ output_tensor = torch.clamp(output_tensor.squeeze(0), 0, 1)
142
+
143
+ # Convertir le tensor (C, H, W) en PIL Image
144
+ output_image = transforms.ToPILImage()(output_tensor.cpu())
145
+
146
+ # Redimensionner l'image de sortie à la taille de l'entrée originale
147
+ output_image = output_image.resize(original_size)
148
+
149
+ return output_image
150
+
151
+ # --- 4. INTERFACE GRADIO ---
152
+
153
+ if model is not None:
154
+ # Description pour le Space
155
+ title = "🎨 Dranina : Mandala Colorizer"
156
+ description = (
157
+ "Ceci est une démonstration du modèle **Dranina U-Net**, entraîné sur votre dataset pour colorier automatiquement des mandalas. "
158
+ "Téléversez une image de mandala en noir et blanc pour voir la prédiction du modèle."
159
+ )
160
+
161
+ # Création de l'interface
162
+ iface = gr.Interface(
163
+ fn=colorize_mandala,
164
+ inputs=gr.Image(type="pil", label="Mandala Non Colorié (Entrée)"),
165
+ outputs=gr.Image(type="pil", label="Mandala Colorié (Prédiction)"),
166
+ title=title,
167
+ description=description,
168
+ # Ajoute des exemples ici si tu en as dans le Space
169
+ # examples=["votre_dossier_dans_le_space/exemple1.jpg"],
170
+ allow_flagging="auto",
171
+ )
172
+
173
+ # Lancement de l'application
174
+ iface.launch()
175
+ else:
176
+ print("Application non lancée car le modèle n'a pas pu être chargé. Vérifiez l'ID du modèle et les dépendances.")