Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from constants import * | |
| from mpl_data_plotter import MatplotlibDataPlotter | |
def convert_int64_to_int32(df):
    """Downcast int64 columns of ``df`` to int32 where it is lossless.

    The frame is modified in place and also returned, so the call can be
    used either as a statement or in an assignment.

    Args:
        df: DataFrame whose ``int64`` columns should be narrowed.

    Returns:
        The same DataFrame object that was passed in.
    """
    int32_limits = np.iinfo(np.int32)
    for col in df.columns:
        column = df[col]
        if column.dtype != 'int64':
            continue
        # astype('int32') silently wraps out-of-range values, so only narrow
        # columns whose values all fit; wider columns stay int64.
        fits_int32 = column.empty or (
            column.min() >= int32_limits.min
            and column.max() <= int32_limits.max
        )
        if fits_int32:
            df[col] = column.astype('int32')
    return df
def create_color_legend(class_to_color):
    """Build a Gradio HTML component showing one colour swatch per class.

    Args:
        class_to_color: Mapping of class name to CSS colour value. Entries
            are rendered in the mapping's iteration order.

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    # Local import: the file does not import ``html`` at module level.
    from html import escape

    # Outer card plus the flex container that holds the individual entries.
    legend_header = """
    <div style="
        margin: 10px 0;
        padding: 10px;
        border: 1px solid #ddd;
        border-radius: 4px;
        background: white;
    ">
        <div style="
            font-weight: bold;
            margin-bottom: 8px;
        ">Color Legend:</div>
        <div style="
            display: flex;
            flex-wrap: wrap;
            gap: 15px;
            align-items: center;
        ">
    """

    # One swatch + label per class. Both values are escaped so that a name or
    # colour containing '<', '&' or '"' cannot break the surrounding markup.
    legend_entries = [
        f"""
            <div style="
                display: flex;
                align-items: center;
                gap: 5px;
            ">
                <div style="
                    width: 20px;
                    height: 20px;
                    background-color: {escape(str(color), quote=True)};
                    border-radius: 3px;
                "></div>
                <span>{escape(str(class_name))}</span>
            </div>
        """
        for class_name, color in class_to_color.items()
    ]

    legend_footer = """
        </div>
    </div>
    """

    # Join once instead of growing the string with += inside the loop.
    legend_html = legend_header + "".join(legend_entries) + legend_footer
    return gr.HTML(legend_html)
def update_all_plots(frequency, split_name):
    """Regenerate both domain plots for the given control values.

    Relies on the module-level ``data_plotter``, which is only created when
    the file runs as a script.

    Args:
        frequency: Forwarded unchanged to both plotter methods.
        split_name: Forwarded unchanged to both plotter methods.

    Returns:
        A ``(single_domains_plot, pair_domains_plot)`` tuple, in that order.
    """
    # Single-domain plot is produced first, then the pair-domain plot.
    single_plot = data_plotter.plot_single_domains(frequency, split_name)
    pair_plot = data_plotter.plot_pair_domains(frequency, split_name)
    return single_plot, pair_plot
def _load_domain_table(csv_path):
    """Read a gzip-compressed domains CSV and prepare it for plotting.

    Renames ``bgc_class`` to ``biosyn_class``, adds ``biosyn_class_index``
    (the position of each row's class in ``BIOSYN_CLASS_NAMES``) and passes
    the frame through ``convert_int64_to_int32``.
    """
    table = pd.read_csv(csv_path, compression='gzip')
    table = table.rename(columns={'bgc_class': 'biosyn_class'})
    # ``.index`` raises ValueError for a class missing from BIOSYN_CLASS_NAMES.
    table['biosyn_class_index'] = table.biosyn_class.apply(BIOSYN_CLASS_NAMES.index)
    return convert_int64_to_int32(table)


if __name__ == "__main__":
    print("Loading domains data...")
    single_df = _load_domain_table(SINGLE_DOMAINS_FILE)
    pair_df = _load_domain_table(PAIR_DOMAINS_FILE)

    # Number of domain rows per CDS region; drives the slider bounds below.
    num_domains_in_region_df = (
        single_df.groupby('cds_region_id', as_index=False)
        .agg({'as_domain_id': 'count'})
        .rename(columns={'as_domain_id': 'num_domains'})
    )
    unique_domain_lengths = num_domains_in_region_df.num_domains.unique()

    print("Initializing data plotter...")
    # Module-level name: update_all_plots() looks it up as a global.
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")
    min_domains = int(unique_domain_lengths.min())
    max_domains = int(unique_domain_lengths.max())

    # Create Gradio interface
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
        color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        with gr.Row():
            frequency_slider = gr.Slider(
                minimum=min_domains,
                maximum=max_domains,
                step=1,
                value=min_domains,
                label="Min number of domains"
            )
            model_selector = gr.Radio(
                choices=["stratified"] + BIOSYN_CLASS_NAMES,
                value="stratified",
                label="Model name"
            )

        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot"
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # Redraw both plots when the slider is released, when the page loads,
        # and when the user changes the radio selection. All three triggers
        # share the same callback, inputs and outputs.
        for register_event in (frequency_slider.release, demo.load, model_selector.input):
            register_event(
                fn=update_all_plots,
                inputs=[frequency_slider, model_selector],
                outputs=[single_domains_plot, pair_domains_plot]
            )

    print("Launching!...")
    demo.launch()