"""Gradio demo for browsing the ``nlphuji/wmtis`` test split with a slider.

The page shows one dataset example at a time: the image on the left and the
remaining (textual) columns on the right.  Moving the slider swaps in the
example at the selected index.
"""
from datasets import load_dataset
import gradio as gr
import os
import random

wmtis = load_dataset("nlphuji/wmtis")['test']
dataset_size = len(wmtis)

# Column names used by this script.
IMAGE = 'image'
IMAGE_DESIGNER = 'image_designer'
DESIGNER_EXPLANATION = 'designer_explanation'
CROWD_CAPTIONS = 'crowd_captions'
CROWD_EXPLANATIONS = 'crowd_explanations'
CROWD_NEGATIVE_CAPTIONS = 'crowd_negative_captions'
CROWD_NEGATIVE_EXPLANATIONS = 'crowd_negative_explanations'
QA = 'question_answering_pairs'
IMAGE_ID = 'image_id'

# Columns rendered in the left-hand UI column.
# Alternative layout that was considered:
#   [IMAGE, DESIGNER_EXPLANATION, IMAGE_DESIGNER, IMAGE_ID]
left_side_columns = [IMAGE]

# Every other dataset column goes on the right, except the two that are
# deliberately hidden from the UI.
right_side_columns = [
    x for x in wmtis.features.keys()
    if x not in left_side_columns
    and x not in [CROWD_NEGATIVE_EXPLANATIONS, QA]
]

# Columns whose value is a list and is rendered as a numbered list.
enumerate_cols = [
    CROWD_CAPTIONS,
    CROWD_EXPLANATIONS,
    CROWD_NEGATIVE_CAPTIONS,
    CROWD_NEGATIVE_EXPLANATIONS,
]


def func(index):
    """Return the display values for the example at ``index``.

    This is the slider's change callback; the returned list is ordered like
    ``left_side_columns + right_side_columns`` so it lines up with the output
    components.

    Args:
        index: Position of the example in the dataset.  Cast to ``int``
            because the slider value may arrive as a float.

    Returns:
        One display value per UI component.
    """
    example = wmtis[int(index)]
    return get_instance_values(example)


def get_instance_values(example):
    """Convert one dataset example into the values shown in the UI.

    Args:
        example: Mapping from column name to value for a single example.

    Returns:
        A list with one entry per column in
        ``left_side_columns + right_side_columns``: list columns named in
        ``enumerate_cols`` become numbered multi-line strings, the QA column
        becomes numbered "Q: ... A: ..." lines, and everything else is passed
        through unchanged.
    """
    values = []
    for k in left_side_columns + right_side_columns:
        if k in enumerate_cols:
            value = list_to_string(example[k])
        elif k == QA:
            # Each QA entry is indexed as (question, answer).
            qa_list = [f"Q: {x[0]} A: {x[1]}" for x in example[k]]
            value = list_to_string(qa_list)
        else:
            value = example[k]
        values.append(value)
    return values


def list_to_string(lst):
    """Render ``lst`` as a numbered list, one "N. item" entry per line."""
    return '\n'.join(f'{i}. {item}' for i, item in enumerate(lst, start=1))


def _label_for(key):
    """Turn a column name such as ``crowd_captions`` into ``Crowd captions``."""
    return key.capitalize().replace("_", " ")


demo = gr.Blocks()

with demo:
    gr.Markdown("# Slide to iterate WMTIS")

    # Pick the initially displayed example up front so the slider can start
    # at the same position as the content shown on the page.
    index = random.randrange(dataset_size)
    example = wmtis[index]
    instance_values = get_instance_values(example)
    left_values = instance_values[:len(left_side_columns)]
    right_values = instance_values[len(left_side_columns):]

    with gr.Column():
        # Valid indices are 0 .. dataset_size - 1; using dataset_size as the
        # maximum would let the user select an out-of-range index.
        slider = gr.Slider(
            minimum=0,
            maximum=dataset_size - 1,
            step=1,
            value=index,
            label='index',
        )

    with gr.Row():
        with gr.Column():
            inputs_left = []
            for key, value in zip(left_side_columns, left_values):
                if key == IMAGE:
                    input_k = gr.Image(value=value)
                else:
                    input_k = gr.Textbox(value=value, label=_label_for(key))
                inputs_left.append(input_k)

        with gr.Column():
            text_inputs_right = []
            for key, value in zip(right_side_columns, right_values):
                text_input_k = gr.Textbox(value=value, label=_label_for(key))
                text_inputs_right.append(text_input_k)

    # Output order must match the order of values returned by ``func``.
    slider.change(func, inputs=[slider], outputs=inputs_left + text_inputs_right)

demo.launch()