Spaces:
Build error
Build error
| import base64 | |
| import io | |
| import random | |
| from io import BytesIO | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| from datasets import load_dataset | |
| import gradio as gr | |
| from score_db import Battle | |
| from score_db import Model as ModelEnum, Winner | |
def make_plot(seismic, predicted_image):
    """Render a seismic section with a fault-probability overlay.

    The seismic array is drawn in grayscale and ``predicted_image`` is drawn
    on top of it with the "Reds" colormap at 50% opacity, clipped to the
    range [0, 1].

    Args:
        seismic: Array accepted by ``PIL.Image.fromarray``.
        predicted_image: Array-like overlay, e.g. the result of
            ``call_endpoint``.

    Returns:
        The rendered PNG decoded into a NumPy array.

    Raises:
        ValueError: If ``predicted_image`` is ``None`` (``call_endpoint``
            returns ``None`` when no prediction could be decoded).
    """
    if predicted_image is None:
        # Fail early with a readable message instead of an opaque Matplotlib
        # error, and before a figure has been allocated.
        raise ValueError("predicted_image is None: no prediction available to plot")

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.imshow(Image.fromarray(seismic), cmap="gray")
        ax.imshow(predicted_image, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
        ax.set_axis_off()

        # Save through the figure object rather than pyplot's implicit
        # "current figure" so the right figure is always written.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one 10x10 inch figure.
        plt.close(fig)

    buf.seek(0)
    # Open the PNG image from the buffer and convert it to a NumPy array
    return np.array(Image.open(buf))
def call_endpoint(model: ModelEnum, img_array, url: str = "https://lukasmosser--seisbase-endpoints-predict.modal.run", timeout: float = 60.0):
    """Request a fault prediction for an image from the inference endpoint.

    Args:
        model: Model identifier sent in the ``"model"`` field of the JSON
            payload.
        img_array: Array with a ``tolist()`` method; sent in the ``"img"``
            field of the JSON payload.
        url: Endpoint that receives the POST request.
        timeout: Seconds to wait for the endpoint before ``requests`` raises
            a timeout error. Without a timeout a stalled endpoint would
            block the caller indefinitely.

    Returns:
        The decoded prediction as a NumPy array, or ``None`` when the
        response has an error status or its body does not start with the
        ``data:image/tiff;base64,`` prefix.
    """
    response = requests.post(
        url,
        json={"img": img_array.tolist(), "model": model},
        timeout=timeout,
    )
    # A requests.Response is falsy for 4xx/5xx status codes.
    if not response:
        return None

    # Parse the base64-encoded image data
    if not response.text.startswith("data:image/tiff;base64,"):
        return None

    img_data_out = base64.b64decode(response.text.split(",")[1])
    predicted_image = np.array(Image.open(BytesIO(img_data_out)))
    return predicted_image
def select_random_image(dataset):
    """Pick a random sample from ``dataset`` and return its seismic image.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items map
            the key ``"seismic"`` to image data.

    Returns:
        A tuple ``(idx, image)`` where ``idx`` is the chosen index and
        ``image`` is the sample's ``"seismic"`` entry as a NumPy array.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    # randrange excludes the upper bound; randint(0, len(dataset)) could
    # return len(dataset) itself and index one past the end.
    idx = random.randrange(len(dataset))
    return idx, np.array(dataset[idx]["seismic"])
def select_random_models():
    """Choose the two models that compete in the next battle.

    Returns:
        A tuple ``(model_a, model_b)`` of ``ModelEnum`` members. The two
        members are distinct whenever ``ModelEnum`` has at least two
        members.
    """
    models = list(ModelEnum)
    if len(models) < 2:
        # A distinct pair is impossible; fall back to independent draws so
        # the function still returns a pair.
        return random.choice(models), random.choice(models)

    # Sampling without replacement prevents a model from being compared
    # against itself, which would carry no information for the ratings.
    model_a, model_b = random.sample(models, 2)
    return model_a, model_b
# Create a Gradio interface
with gr.Blocks() as evaluation:
    gr.Markdown("""
    ## Seismic Fault Detection Model Evaluation
    This application allows you to compare the performance of different seismic fault detection models.
    Two models are selected randomly, and their predictions are displayed side by side.
    You can choose the better model or mark it as a tie. The results are recorded and used to update the model ratings.
    """)
    # Battles of the current session; the last entry is the one on screen.
    # The state is passed to the handlers as an input (and written back as an
    # output) so that every visitor works on their own list. Reading
    # `battle.value` inside a handler would instead touch the single initial
    # object shared by all sessions.
    battle = gr.State([])
    radio = gr.Radio(choices=["Less than 5 years", "5 to 20 years", "more than 20 years"], label="How much experience do you have in seismic fault interpretation?")
    with gr.Row():
        output_img1 = gr.Image(label="Model A Image")
        output_img2 = gr.Image(label="Model B Image")

    def _experience_level(experience_choice):
        """Map the experience radio selection to the level stored on a battle (1-3)."""
        if experience_choice == "5 to 20 years":
            return 2
        if experience_choice == "more than 20 years":
            return 3
        # "Less than 5 years" and "nothing selected" both map to level 1.
        return 1

    def show_images(battles=None, experience_choice=None):
        """Start a new battle and render both model predictions.

        Args:
            battles: The session's list of battles (value of the ``battle``
                state). A new list is created when it is ``None``.
            experience_choice: Current selection of the experience radio.

        Returns:
            ``(img_1, img_2, battles)``: the two overlay plots and the
            battle list with the new, still undecided battle appended.
        """
        battles = [] if battles is None else battles
        dataset = load_dataset("porestar/crossdomainfoundationmodeladaption-deepfault", split="valid")
        idx, image_1 = select_random_image(dataset)
        model_a, model_b = select_random_models()
        fault_probability_1 = call_endpoint(model_a, image_1)
        fault_probability_2 = call_endpoint(model_b, image_1)
        img_1 = make_plot(image_1, fault_probability_1)
        img_2 = make_plot(image_1, fault_probability_2)
        # The winner is a placeholder until the user votes.
        battles.append(
            Battle(
                model_a=model_a,
                model_b=model_b,
                winner="tie",
                judge="None",
                experience=_experience_level(experience_choice),
                image_idx=idx,
            )
        )
        return img_1, img_2, battles

    # Define the function to make an API call
    def make_api_call(choice: Winner, battles=None, experience_choice=None):
        """Record the vote for the battle on screen and submit it.

        Args:
            choice: Outcome selected by the user.
            battles: The session's list of battles; the last entry is the
                one being judged.
            experience_choice: Current selection of the experience radio.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        if not battles:
            # A vote arrived before any battle was created (e.g. the initial
            # images are still loading); there is nothing to submit.
            return
        api_url = "https://lukasmosser--seisbase-eval-add-battle.modal.run"
        current = battles[-1]
        current.winner = choice
        # Re-read the radio at vote time: the user may have changed it after
        # the images were loaded.
        current.experience = _experience_level(experience_choice)
        response = requests.post(api_url, json=current.dict(), timeout=30)
        # Surface a rejected submission instead of silently dropping the vote.
        response.raise_for_status()

    session_inputs = [battle, radio]
    refresh_outputs = [output_img1, output_img2, battle]

    # Load images on startup
    evaluation.load(show_images, inputs=session_inputs, outputs=refresh_outputs)
    with gr.Row():
        btn_winner_a = gr.Button("Winner Model A")
        btn_tie = gr.Button("Tie")
        btn_winner_b = gr.Button("Winner Model B")

    def _wire_vote(button, outcome):
        """Submit ``outcome`` when ``button`` is clicked, then show the next battle."""
        def _submit(battles, experience_choice):
            make_api_call(outcome, battles, experience_choice)

        button.click(_submit, inputs=session_inputs, outputs=[]).then(
            show_images, inputs=session_inputs, outputs=refresh_outputs
        )

    # Define button click events
    _wire_vote(btn_winner_a, Winner.model_a)
    _wire_vote(btn_tie, Winner.tie)
    _wire_vote(btn_winner_b, Winner.model_b)
with gr.Blocks() as leaderboard:
    def get_results():
        """Fetch the current ratings and render them as a bar chart.

        Returns:
            The rendered PNG of the horizontal bar chart, decoded into a
            NumPy array.

        Raises:
            requests.HTTPError: If the ratings endpoint answers with an
                error status.
        """
        response = requests.get(
            "https://lukasmosser--seisbase-eval-compute-ratings.modal.run",
            timeout=30,  # do not block the page load on a stalled endpoint
        )
        # Fail with the HTTP error itself rather than a confusing JSON or
        # key error when the body is not the expected list of entries.
        response.raise_for_status()
        data = response.json()
        models = [entry["model"] for entry in data]
        ratings = [entry["elo_rating"] for entry in data]

        fig, ax = plt.subplots()
        try:
            ax.barh(models, ratings, color='skyblue')
            ax.set_xlabel('ELO Rating')
            ax.set_title('Model ELO Ratings')
            fig.tight_layout()
            # Create a bytes buffer to save the plot
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight')
        finally:
            # pyplot keeps figures alive until closed; release this one so
            # repeated leaderboard loads do not accumulate figures.
            plt.close(fig)

        buf.seek(0)
        # Open the PNG image from the buffer and convert it to a NumPy array
        return np.array(Image.open(buf))

    with gr.Row():
        elo_ratings = gr.Image(label="ELO Ratings")
    leaderboard.load(get_results, inputs=[], outputs=[elo_ratings])
# Combine the arena and the leaderboard pages into one tabbed application.
demo = gr.TabbedInterface(
    interface_list=[evaluation, leaderboard],
    tab_names=["Arena", "Leaderboard"],
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)