# Spaces: Sleeping
| import gradio as gr | |
| import pandas as pd | |
| from openai import OpenAI | |
| import os | |
| import plotly.graph_objects as go | |
| import networkx as nx | |
# OpenAI-compatible client aimed at NVIDIA's hosted endpoint; the key is taken
# from the NVIDIA_API_KEY environment variable.
client = OpenAI(
    api_key=os.getenv("NVIDIA_API_KEY"),
    base_url="https://integrate.api.nvidia.com/v1",
)

# Both tables are read once, at import time, from the working directory.
validated_entities_df = pd.read_csv("validated_entities_final.csv")
relations_df = pd.read_csv("extracted_entities_relations_countries.csv")

# Distinct non-null entity names; these populate the topic dropdown.
_entity_column = validated_entities_df["entity"]
entity_list = _entity_column.dropna().unique().tolist()

# Distinct non-null country names found anywhere in the relations table.
_country_column = relations_df["country"]
all_countries_in_data = _country_column.dropna().unique().tolist()
def get_entity_triples(entity, relations=None):
    """Collect the triples, sentences and countries that mention ``entity``.

    A row is selected when ``entity`` equals its ``head`` or its ``tail``.

    Args:
        entity: Value compared against the ``head`` and ``tail`` columns.
        relations: Optional DataFrame with ``head``, ``relation``, ``tail``,
            ``sentence`` and ``country`` columns.  When omitted, the
            module-level ``relations_df`` is used.

    Returns:
        A ``(triples, sentences, countries)`` tuple of lists.  ``triples``
        holds the distinct ``(head, relation, tail)`` tuples in first-seen
        order; ``sentences`` and ``countries`` hold the distinct non-null
        values of those columns.  All three lists are empty when no row
        matches.
    """
    source = relations_df if relations is None else relations

    mentions_entity = (source['head'] == entity) | (source['tail'] == entity)
    matched = source[mentions_entity]
    if matched.empty:
        return [], [], []

    # dict.fromkeys de-duplicates in a single pass and keeps first-seen order,
    # avoiding a linear "not in list" scan for every row.
    triples = list(
        dict.fromkeys(zip(matched['head'], matched['relation'], matched['tail']))
    )
    sentences = matched['sentence'].dropna().unique().tolist()
    countries = matched['country'].dropna().unique().tolist()
    return triples, sentences, countries
def create_knowledge_graph(entity, triples):
    """Build an interactive Plotly figure of the triples surrounding ``entity``.

    Args:
        entity: Node to highlight; it is drawn larger and in a different
            colour, and its name appears in the figure title.
        triples: List of ``(head, relation, tail)`` tuples.  Each tuple
            contributes a directed edge from ``head`` to ``tail``.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``triples`` is
        empty.
    """
    if not triples:
        return None

    # A DiGraph holds one edge per (head, tail) pair, and adding the same pair
    # again replaces its attributes.  Merge the relations of such pairs into a
    # single label so that none of them is silently dropped from the picture.
    graph = nx.DiGraph()
    for head, relation, tail in triples:
        if graph.has_edge(head, tail):
            graph[head][tail]['label'] += f", {relation}"
        else:
            graph.add_edge(head, tail, label=str(relation))

    # A fixed seed keeps the layout identical across renders of the same graph.
    pos = nx.spring_layout(graph, k=2, iterations=50, seed=42)

    # All edges share one line trace and all relation labels share one text
    # trace, instead of two traces per edge.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for head, tail, attrs in graph.edges(data=True):
        x0, y0 = pos[head]
        x1, y1 = pos[tail]
        # The None entry ends the segment so consecutive edges are not joined.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        # The relation label sits at the midpoint of its edge.
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(attrs['label'])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        showlegend=False,
    )
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode='text',
        text=label_text,
        textposition='middle center',
        textfont=dict(size=8, color='#666'),
        hoverinfo='text',
        hovertext=label_text,
        showlegend=False,
    )

    nodes = list(graph.nodes())
    node_trace = go.Scatter(
        x=[pos[node][0] for node in nodes],
        y=[pos[node][1] for node in nodes],
        mode='markers+text',
        text=nodes,
        textposition='top center',
        textfont=dict(size=10, color='#000'),
        marker=dict(
            # The selected entity gets a larger marker in its own colour.
            size=[30 if node == entity else 20 for node in nodes],
            color=['#FF6B6B' if node == entity else '#4ECDC4' for node in nodes],
            line=dict(width=2, color='#fff'),
        ),
        hoverinfo='text',
        hovertext=nodes,
        showlegend=False,
    )

    # Draw order: edges first, then relation labels, then nodes on top.
    fig = go.Figure(data=[edge_trace, label_trace, node_trace])
    fig.update_layout(
        title=dict(
            text=f"Knowledge Graph for: {entity}",
            x=0.5,
            xanchor='center',
            font=dict(size=16),
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='#f9f9f9',
        height=600,
    )
    return fig
def create_africa_map(countries):
    """Render an Africa-scoped choropleth that highlights ``countries``.

    Every name in the module-level ``all_countries_in_data`` list is placed
    on the map.  Names that also appear in ``countries`` get z=1 and a
    "(Selected Entity)" hover suffix; all others get z=0.

    Args:
        countries: Collection of country names to highlight.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``countries`` is
        empty.
    """
    if not countries:
        return None

    locations = list(all_countries_in_data)
    highlighted = [name in countries for name in locations]
    z_values = [1 if flag else 0 for flag in highlighted]
    hover_text = [
        f"{name} (Selected Entity)" if flag else name
        for name, flag in zip(locations, highlighted)
    ]

    # Two-stop colour scale: z=0 maps to '#E8E8E8', z=1 maps to '#FF6B6B'.
    choropleth = go.Choropleth(
        locations=locations,
        locationmode='country names',
        z=z_values,
        text=hover_text,
        colorscale=[[0, '#E8E8E8'], [1, '#FF6B6B']],
        showscale=False,
        marker_line_color='white',
        marker_line_width=1,
        hovertemplate='<b>%{text}</b><extra></extra>',
    )
    fig = go.Figure(data=choropleth)

    # NOTE(review): the geo background colour is set here and again through
    # update_layout(geo=...) below; presumably the later value is the one that
    # is rendered - confirm and keep only one.
    fig.update_geos(
        scope='africa',
        showframe=True,
        showcoastlines=True,
        projection_type='natural earth',
        bgcolor='#f0f0f0',
    )

    heading = (
        f"Countries with Topic Mentions "
        f"({len(countries)} of {len(all_countries_in_data)} total)"
    )
    fig.update_layout(
        title=dict(text=heading, x=0.5, xanchor='center', font=dict(size=16)),
        margin=dict(l=0, r=0, t=40, b=0),
        height=500,
        geo=dict(bgcolor='#f9f9f9'),
    )
    return fig
def generate_entity_description(entity, triples):
    """Ask the chat model for a prose description of ``entity``.

    The triples are listed in the prompt as ``- head | relation | tail``
    bullet lines and sent to the module-level ``client``.

    Args:
        entity: Name of the entity to describe.
        triples: List of ``(head, relation, tail)`` tuples.

    Returns:
        The content of the first completion choice; a fixed message when
        ``triples`` is empty; or an ``"Error generating description: ..."``
        string when the request raises.
    """
    if not triples:
        return "No triples found for this entity."

    bullet_lines = []
    for head, rel, tail in triples:
        bullet_lines.append(f"- {head} | {rel} | {tail}")
    triples_text = "\n".join(bullet_lines)

    prompt = (
        'Based on the following knowledge graph triples, provide a comprehensive '
        f'description of "{entity}":\n'
        f"{triples_text}\n"
        "Please synthesize this information into a clear, coherent description "
        f"that explains what {entity} is, its relationships, and its role based "
        "on the triples provided."
    )
    request = dict(
        model="meta/llama-3.1-405b-instruct",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
        max_tokens=1024,
    )

    # Any failure is reported as text so the UI shows a message instead of
    # raising.
    try:
        completion = client.chat.completions.create(**request)
        return completion.choices[0].message.content
    except Exception as exc:
        return f"Error generating description: {exc!s}"
def generate_country_paragraphs(entity, sentences, countries, max_sentences=10):
    """Generate one model-written paragraph about ``entity`` per country.

    For every country the sentences are looked up again in the module-level
    ``relations_df`` (rows whose ``head`` or ``tail`` equals ``entity`` and
    whose ``country`` equals that country) and sent to the module-level
    ``client``.

    Args:
        entity: Name of the entity the paragraphs are about.
        sentences: Sentences associated with the entity.  Only checked for
            emptiness; the per-country sentences come from ``relations_df``.
        countries: Country names to generate paragraphs for.
        max_sentences: Maximum number of sentences placed in each prompt
            (``None`` for no limit).

    Returns:
        The paragraphs joined with newlines, each introduced by a bold
        ``**country:**`` heading.  Countries without sentences are skipped,
        and a failed request yields an error line for that country only.
        A fixed message is returned when ``sentences`` or ``countries`` is
        empty.
    """
    if not sentences or not countries:
        return "No sentences or countries found for this topic."

    # The entity filter does not depend on the country, so apply it once here
    # rather than rescanning the whole table on every loop iteration.
    entity_rows = relations_df[
        (relations_df['head'] == entity) | (relations_df['tail'] == entity)
    ]

    country_paragraphs = []
    for country in countries:
        country_sentences = (
            entity_rows.loc[entity_rows['country'] == country, 'sentence']
            .dropna()
            .unique()
            .tolist()
        )
        if not country_sentences:
            continue

        # Cap the prompt size.
        sentences_text = "\n".join(
            f"- {sent}" for sent in country_sentences[:max_sentences]
        )
        prompt = (
            f"Based on the following sentences about \"{entity}\" in {country}, "
            f"generate a comprehensive paragraph that describes the entity's role, "
            f"activities, and significance in {country}:\n"
            f"{sentences_text}\n"
            f"Please create a well-structured paragraph that synthesizes this "
            f"information specifically for {country}."
        )

        # A failure for one country must not discard the other paragraphs.
        try:
            completion = client.chat.completions.create(
                model="meta/llama-3.1-405b-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7,
                max_tokens=1024,
            )
            country_paragraphs.append(
                f"**{country}:**\n{completion.choices[0].message.content}\n"
            )
        except Exception as e:
            country_paragraphs.append(
                f"**{country}:** Error generating paragraph: {str(e)}\n"
            )

    return "\n".join(country_paragraphs)
def process_entity(entity):
    """Produce every output shown in the UI for the selected ``entity``.

    Args:
        entity: Value chosen in the topic dropdown; may be empty.

    Returns:
        A 5-tuple ``(kg_fig, map_fig, description, country_paragraphs,
        metadata)``.  With no selection the figures are ``None`` and the
        strings are empty; when the entity has no triples the description
        slot carries a "no data" message instead.
    """
    if not entity:
        return None, None, "", "", ""

    triples, sentences, countries = get_entity_triples(entity)
    if not triples:
        return None, None, "No data found for this entity.", "", ""

    # Figures first, then the two model-generated texts.
    kg_fig = create_knowledge_graph(entity, triples)
    map_fig = create_africa_map(countries)
    description = generate_entity_description(entity, triples)
    country_paragraphs = generate_country_paragraphs(entity, sentences, countries)

    # Markdown summary; the blank lines separate the three entries.
    metadata = "\n\n".join(
        [
            f"**Associated Countries:** {', '.join(countries)}",
            f"**Number of Triples:** {len(triples)}",
            f"**Number of Sentences:** {len(sentences)}",
        ]
    )
    return kg_fig, map_fig, description, country_paragraphs, metadata
# NOTE(review): this copy of the file has lost its indentation and wraps every
# line in "| ... | |" cell markers, so the nesting of the layout containers
# below cannot be read off the text - restore it from the original file.
| # Create Gradio interface | |
| with gr.Blocks(title="AU Education Policy Glossary", theme=gr.themes.Soft()) as demo: | |
# Page heading and a one-line instruction for the user.
| gr.Markdown( | |
| """ | |
| # AU Education Policy Glossary (AU-EPG) | |
| Select an Education Policy topic to examine its context and implementation across Africa. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
# Filterable topic picker; its choices are entity_list in sorted order.
| entity_dropdown = gr.Dropdown( | |
| choices=sorted(entity_list), | |
| label="Select Topic", | |
| filterable=True, | |
| info="Start typing to search for a topic" | |
| ) | |
| search_btn = gr.Button("Search", variant="primary", size="lg") | |
| gr.Markdown("### Metadata") | |
# Receives the metadata string (fifth value returned by process_entity).
| metadata_output = gr.Markdown() | |
| with gr.Column(scale=3): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("### Knowledge Graph Visualization") | |
# Receives the knowledge-graph figure (first value returned by process_entity).
| kg_plot = gr.Plot(label="Knowledge Graph") | |
| with gr.Column(): | |
| gr.Markdown("### Geographic Distribution") | |
# Receives the map figure (second value returned by process_entity).
| map_plot = gr.Plot(label="Africa Map") | |
| gr.Markdown("### Topic Definition") | |
# Receives the generated description (third value returned by process_entity).
| description_output = gr.Markdown() | |
| gr.Markdown("### Country-Specific Insights") | |
# Receives the per-country paragraphs (fourth value returned by process_entity).
| country_output = gr.Markdown() | |
# Both events below call process_entity with the dropdown value and write its
# five return values, in order, to the five listed output components.
# NOTE(review): presumably choosing a topic and then pressing Search runs
# process_entity twice for the same value - confirm whether that is intended.
| # Connect button to processing function | |
| search_btn.click( | |
| fn=process_entity, | |
| inputs=[entity_dropdown], | |
| outputs=[kg_plot, map_plot, description_output, country_output, metadata_output] | |
| ) | |
| # Also trigger on dropdown change | |
| entity_dropdown.change( | |
| fn=process_entity, | |
| inputs=[entity_dropdown], | |
| outputs=[kg_plot, map_plot, description_output, country_output, metadata_output] | |
| ) | |
# Footer text describing the app.
| gr.Markdown( | |
| """ | |
| --- | |
| 💡 **About this app:** | |
| This app is open source and built to help explore education policy initiatives across African countries. | |
| You’re welcome to view, use, and contribute to its codebase or adapt it for your own research and data projects. | |
| """ | |
| ) | |
# demo.launch() runs only when this file is executed as a script, not when it
# is imported.
| # Launch the app | |
| if __name__ == "__main__": | |
| demo.launch() | |