Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torchvision.transforms as transforms | |
| from torchvision.transforms import InterpolationMode | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from model import Model | |
# Load Model
# Pick the inference device up front: GPU when CUDA is available, else CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Architecture first, then the fine-tuned weights from the Hugging Face Hub.
model = Model('convnext_base')
model_path = hf_hub_download(
    filename="model.pt",
    repo_id="itserr/exvoto_classifier_convnext_base_224",
)
# Deserialize onto the CPU regardless of where the checkpoint was saved, so it
# also loads on machines without a GPU; the model is moved to `device` below.
# NOTE(review): torch.load unpickles the downloaded file, so only point this at
# a trusted repository. If the checkpoint holds nothing but tensors, consider
# weights_only=True (confirm against the checkpoint contents first).
ckpt = torch.load(model_path, map_location="cpu")
# The checkpoint is a dict whose 'model' entry is the state dict.
model.load_state_dict(ckpt['model'])

model.to(device)
# Presumably a torch nn.Module: eval() switches layers such as dropout and
# batch-norm to inference behaviour.
model.eval()
# Image Transformations
# Preprocessing applied to every image before inference, in order:
#   1. bicubic resize to 224x224 (matching the "_224" checkpoint name),
#   2. PIL image -> float tensor,
#   3. per-channel normalization with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(
            size=(224, 224),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Classification Function
def classify_img(img, threshold):
    """Classify a single image as ex-voto or not.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when "Classify" is clicked before any image has been loaded.
        threshold: Minimum sigmoid probability (0-1) required for an
            "Ex-Voto" verdict.

    Returns:
        A ``(label, confidence)`` tuple of strings for the two output boxes.
        When ``img`` is ``None`` the label asks for an upload and the
        confidence string is empty.
    """
    # Without an image, transform(None) would raise inside the event handler.
    if img is None:
        return "Please upload an image first.", ""
    # Normalize below uses 3-channel statistics, so RGBA, grayscale and palette
    # images must be converted before preprocessing.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        pred = model(img_tensor)
    # .item() requires a single-element output; presumably one logit for the
    # ex-voto class. Sigmoid maps it to a probability.
    score = torch.sigmoid(pred).item()
    if score >= threshold:
        label = "✅ This is an **Ex-Voto** image!"
    else:
        label = "❌ This is **NOT** an Ex-Voto image."
    confidence = f"The probability that the image is an ex-voto is: {score:.2%}"
    return label, confidence
# Example rows for gr.Examples: one value per input component per row. The
# examples are wired to the image input only, so each row holds just the
# image path. The former trailing None had no matching input and was ignored.
example_images = [
    ['examples/exvoto1.jpg'],
    ['examples/exvoto2.jpg'],
    ['examples/nonexvoto1.jpg'],
    ['examples/nonexvoto2.jpg'],
    ['examples/natural1.jpg'],
    ['examples/natural2.jpg'],
]
# Reset both result boxes whenever the image input changes.
def clear_outputs(img):
    """Return two updates that blank the prediction and confidence boxes.

    ``img`` is part of the signature only because the image component is
    registered as this handler's input; its value is not used.
    """
    blank_prediction = gr.update(value="")
    blank_confidence = gr.update(value="")
    return blank_prediction, blank_confidence
# Gradio Interface
# Layout: title and instructions on top, then one row with the image upload,
# threshold slider and "Classify" button on the left (scale 2) and the two
# read-only result boxes on the right (scale 1). Components render in the
# order they are created inside these context managers.
with gr.Blocks() as demo:
    gr.Markdown("## Ex-Voto Image Classifier")
    gr.Markdown("📸 **Upload an image** to check if it's an **Ex-Voto** painting!")
    with gr.Row():
        with gr.Column(scale=2):  # Left section: image upload & slider
            img_input = gr.Image(type="pil")
            # Probability cut-off forwarded to classify_img (0.5-1.0, 0.1 steps).
            threshold_slider = gr.Slider(
                minimum=0.5, maximum=1.0, value=0.7, step=0.1, label="Classification Threshold"
            )
            submit_btn = gr.Button("Classify")
        with gr.Column(scale=1):  # Right section: prediction & confidence
            # NOTE(review): a Textbox shows the "**...**" Markdown in the
            # prediction label literally; consider gr.Markdown if bold is wanted.
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            confidence_output = gr.Textbox(label="Confidence Score", interactive=False)
    # Clear stale results whenever a different image is loaded
    img_input.change(fn=clear_outputs, inputs=[img_input], outputs=[prediction_output, confidence_output])
    # Classification runs only on an explicit button click
    submit_btn.click(fn=classify_img, inputs=[img_input, threshold_slider], outputs=[prediction_output, confidence_output])
    # Clickable examples populate only the image input; the slider is not an
    # input here, so it keeps its current value.
    gr.Examples(
        examples=example_images,
        inputs=[img_input]
    )
# Launch App
# Start the (blocking) web server only when this file is executed as a script,
# so that importing the module (tests, notebooks, other tooling) does not
# launch it as a side effect.
if __name__ == "__main__":
    demo.launch()