# AU-EPG / app.py
# Author: michsethowusu
# Last change: commit 4fbd7b6 ("Update app.py")
import gradio as gr
import pandas as pd
from openai import OpenAI
import os
import plotly.graph_objects as go
import networkx as nx
# Initialize NVIDIA API client
# The stock OpenAI SDK client is pointed at NVIDIA's integrate endpoint; the
# key comes from the NVIDIA_API_KEY environment variable.
# NOTE(review): when NVIDIA_API_KEY is unset, api_key is None and the SDK
# applies its own default lookup (presumably OPENAI_API_KEY) — confirm the
# deployment always sets NVIDIA_API_KEY.
client = OpenAI(
    base_url="https://integrate.api.nvidia.com/v1",
    api_key=os.environ.get("NVIDIA_API_KEY")
)
# Load CSV files
# Both files are read relative to the working directory at import time, so a
# missing file fails here, before the interface below is built.
validated_entities_df = pd.read_csv("validated_entities_final.csv")
# Relations table; the rest of this file reads its 'head', 'relation', 'tail',
# 'sentence' and 'country' columns.
relations_df = pd.read_csv("extracted_entities_relations_countries.csv")
# Get list of entities for dropdown
# Distinct non-null values of the 'entity' column, in order of appearance.
entity_list = validated_entities_df['entity'].dropna().unique().tolist()
# Get all unique countries from the data
# Used as the base layer of the Africa map (every country drawn, selected ones
# highlighted).
all_countries_in_data = relations_df['country'].dropna().unique().tolist()
def get_entity_triples(entity, df=None):
    """Collect the triples, source sentences and countries linked to *entity*.

    A row matches when *entity* equals the value in its ``head`` or ``tail``
    column.

    Args:
        entity: The topic to look up.
        df: Relations table to search. Defaults to the module-level
            ``relations_df`` loaded at start-up. Any DataFrame with ``head``,
            ``relation``, ``tail``, ``sentence`` and ``country`` columns works.

    Returns:
        ``(triples, sentences, countries)``:
        ``triples`` is a list of distinct ``(head, relation, tail)`` tuples in
        first-seen order. ``sentences`` and ``countries`` are the distinct
        non-null values of those columns among the matching rows. All three
        are empty lists when nothing matches.
    """
    if df is None:
        df = relations_df

    filtered_df = df[(df['head'] == entity) | (df['tail'] == entity)]
    if filtered_df.empty:
        return [], [], []

    # dict.fromkeys de-duplicates in O(n) while keeping first-seen order; a
    # list-membership test per row would be O(n^2) for well-connected topics.
    triples = list(dict.fromkeys(
        zip(filtered_df['head'], filtered_df['relation'], filtered_df['tail'])
    ))

    sentences = filtered_df['sentence'].dropna().unique().tolist()
    countries = filtered_df['country'].dropna().unique().tolist()
    return triples, sentences, countries
def create_knowledge_graph(entity, triples):
    """Build an interactive Plotly figure of the knowledge graph around *entity*.

    Nodes are placed with a seeded spring layout, so the same triples give
    the same picture. The selected entity is drawn larger and in red, and every
    other node is teal. Each directed edge is labelled at its midpoint with its
    relation(s). Several relations between the same pair are joined with "; ".

    Args:
        entity: The selected topic, used for highlighting and in the title.
        triples: Sequence of ``(head, relation, tail)`` tuples.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when *triples* is empty.
    """
    if not triples:
        return None

    # A DiGraph stores a single edge per (head, tail) pair. Gather every
    # distinct relation for a pair first, because adding one edge per triple
    # would let the last relation silently overwrite the earlier ones.
    pair_relations = {}
    for head, relation, tail in triples:
        relations = pair_relations.setdefault((head, tail), [])
        if relation not in relations:
            relations.append(relation)

    G = nx.DiGraph()
    for (head, tail), relations in pair_relations.items():
        G.add_edge(head, tail, label="; ".join(str(r) for r in relations))

    pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    # Draw every edge in one line trace (segments separated by None) and every
    # relation label in one text trace, rather than two traces per edge.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for head, tail, data in G.edges(data=True):
        x0, y0 = pos[head]
        x1, y1 = pos[tail]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(data['label'])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        showlegend=False
    )
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode='text',
        text=label_text,
        textposition='middle center',
        textfont=dict(size=8, color='#666'),
        hoverinfo='text',
        hovertext=label_text,
        showlegend=False
    )

    node_x = []
    node_y = []
    node_text = []
    node_colors = []
    node_sizes = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        # Highlight the main entity
        if node == entity:
            node_colors.append('#FF6B6B')
            node_sizes.append(30)
        else:
            node_colors.append('#4ECDC4')
            node_sizes.append(20)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        text=node_text,
        textposition='top center',
        textfont=dict(size=10, color='#000'),
        marker=dict(
            size=node_sizes,
            color=node_colors,
            line=dict(width=2, color='#fff')
        ),
        hoverinfo='text',
        hovertext=node_text,
        showlegend=False
    )

    # Nodes go last so they render above the edges and labels.
    fig = go.Figure(data=[edge_trace, label_trace, node_trace])
    fig.update_layout(
        title=dict(
            text=f"Knowledge Graph for: {entity}",
            x=0.5,
            xanchor='center',
            font=dict(size=16)
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='#f9f9f9',
        height=600
    )
    return fig
def create_africa_map(countries):
    """Draw a choropleth of Africa highlighting the countries linked to a topic.

    Every country in the dataset (``all_countries_in_data``) is drawn. Those
    listed in *countries* are filled red and the rest light grey.

    Args:
        countries: Country names (Plotly ``'country names'`` location mode)
            where the selected topic is mentioned.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` if *countries* is empty.
    """
    if not countries:
        return None

    # Set lookup: O(1) per base-map country instead of scanning the list.
    selected = set(countries)
    locations = list(all_countries_in_data)
    z_values = [1 if country in selected else 0 for country in locations]
    hover_text = [
        f"{country} (Selected Entity)" if country in selected else country
        for country in locations
    ]

    fig = go.Figure(data=go.Choropleth(
        locations=locations,
        locationmode='country names',
        z=z_values,
        text=hover_text,
        colorscale=[[0, '#E8E8E8'], [1, '#FF6B6B']],
        showscale=False,
        marker_line_color='white',
        marker_line_width=1,
        hovertemplate='<b>%{text}</b><extra></extra>'
    ))
    # Set the geo background in one place. A second geo bgcolor in
    # update_layout would be merged recursively and override this one.
    fig.update_geos(
        scope='africa',
        showframe=True,
        showcoastlines=True,
        projection_type='natural earth',
        bgcolor='#f9f9f9'
    )
    fig.update_layout(
        title=dict(
            text=f"Countries with Topic Mentions ({len(countries)} of {len(all_countries_in_data)} total)",
            x=0.5,
            xanchor='center',
            font=dict(size=16)
        ),
        margin=dict(l=0, r=0, t=40, b=0),
        height=500
    )
    return fig
def generate_entity_description(entity, triples, *, model="meta/llama-3.1-405b-instruct",
                                temperature=0.7, max_tokens=1024):
    """Ask the LLM (Llama 3.1 405B by default) to describe *entity* from its triples.

    Args:
        entity: Topic name, interpolated into the prompt.
        triples: Sequence of ``(head, relation, tail)`` tuples.
        model: Chat model identifier passed to ``client.chat.completions.create``.
        temperature: Sampling temperature for the completion.
        max_tokens: Upper bound on the length of the generated answer.

    Returns:
        Markdown text for the UI. This is the model's answer, or a readable
        message when there are no triples, the API call fails, or the model
        returns no content.
    """
    if not triples:
        return "No triples found for this entity."

    triples_text = "\n".join(f"- {head} | {rel} | {tail}" for head, rel, tail in triples)
    prompt = (
        f'Based on the following knowledge graph triples, provide a comprehensive description of "{entity}":\n'
        f"{triples_text}\n"
        f"Please synthesize this information into a clear, coherent description that explains what "
        f"{entity} is, its relationships, and its role based on the triples provided."
    )
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        content = completion.choices[0].message.content
    except Exception as e:
        # UI boundary: show the failure in the page instead of crashing the app.
        return f"Error generating description: {str(e)}"

    # The SDK types message.content as optional; don't hand None to the UI.
    if not content:
        return "No description was generated for this entity."
    return content
def generate_country_paragraphs(entity, sentences, countries, *, df=None, max_sentences=10):
    """Generate one LLM-written paragraph per country describing *entity* there.

    For each country, the distinct non-null sentences from rows mentioning
    *entity* (as ``head`` or ``tail``) in that country are sent to the model.
    Only the first *max_sentences* are sent. Countries without any such
    sentence are skipped.

    Args:
        entity: Topic name.
        sentences: Sentences linked to the topic. Only checked for emptiness.
        countries: Countries to write paragraphs for, in output order.
        df: Relations table to draw sentences from. Defaults to the
            module-level ``relations_df``.
        max_sentences: Per-country cap on sentences included in the prompt.

    Returns:
        Markdown with one ``**Country:**`` section per country, joined by
        newlines. A per-country error line replaces a section whose API call
        failed. The "No sentences or countries..." message is returned when
        there is nothing to write about.
    """
    no_data_message = "No sentences or countries found for this topic."
    if not sentences or not countries:
        return no_data_message
    if df is None:
        df = relations_df

    # Filter on the entity once instead of re-scanning the whole table per country.
    entity_rows = df[(df['head'] == entity) | (df['tail'] == entity)]

    country_paragraphs = []
    for country in countries:
        country_sentences = (
            entity_rows.loc[entity_rows['country'] == country, 'sentence']
            .dropna()
            .unique()
            .tolist()
        )
        if not country_sentences:
            continue

        sentences_text = "\n".join(f"- {sent}" for sent in country_sentences[:max_sentences])
        prompt = (
            f'Based on the following sentences about "{entity}" in {country}, generate a comprehensive '
            f"paragraph that describes the entity's role, activities, and significance in {country}:\n"
            f"{sentences_text}\n"
            f"Please create a well-structured paragraph that synthesizes this information specifically for {country}."
        )
        try:
            completion = client.chat.completions.create(
                model="meta/llama-3.1-405b-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7,
                max_tokens=1024
            )
            content = completion.choices[0].message.content
        except Exception as e:
            # Best-effort per country: record the failure and keep going.
            country_paragraphs.append(f"**{country}:** Error generating paragraph: {str(e)}\n")
            continue

        # message.content is optional in the SDK; avoid rendering a literal "None".
        if not content:
            content = "No paragraph was generated."
        country_paragraphs.append(f"**{country}:**\n{content}\n")

    # Every country was skipped, so report that instead of returning "".
    if not country_paragraphs:
        return no_data_message
    return "\n".join(country_paragraphs)
def process_entity(entity):
    """Run the full lookup pipeline for the topic picked in the UI.

    Returns a 5-tuple in the order of the Gradio outputs: knowledge-graph
    figure, Africa-map figure, description markdown, country-paragraphs
    markdown, metadata markdown. A falsy *entity* yields blank outputs. An
    entity with no triples yields a "no data" note in the description slot.
    """
    if not entity:
        return None, None, "", "", ""

    triples, sentences, countries = get_entity_triples(entity)
    if not triples:
        return None, None, "No data found for this entity.", "", ""

    # Figures first, then the (slow) LLM-backed texts, in this fixed order.
    graph_figure = create_knowledge_graph(entity, triples)
    map_figure = create_africa_map(countries)
    description_md = generate_entity_description(entity, triples)
    countries_md = generate_country_paragraphs(entity, sentences, countries)

    metadata_md = "\n\n".join([
        f"**Associated Countries:** {', '.join(countries)}",
        f"**Number of Triples:** {len(triples)}",
        f"**Number of Sentences:** {len(sentences)}",
    ])
    return graph_figure, map_figure, description_md, countries_md, metadata_md
# Create Gradio interface
# Builds the page at import time and binds it to `demo`. Layout: a narrow
# left column (topic picker, Search button, metadata) and a wide right column
# (graph and map side by side, then the generated texts).
with gr.Blocks(title="AU Education Policy Glossary", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# AU Education Policy Glossary (AU-EPG)
Select an Education Policy topic to examine its context and implementation across Africa.
"""
    )
    with gr.Row():
        # Left column (scale 1): input controls and metadata summary.
        with gr.Column(scale=1):
            entity_dropdown = gr.Dropdown(
                choices=sorted(entity_list),
                label="Select Topic",
                filterable=True,
                info="Start typing to search for a topic"
            )
            search_btn = gr.Button("Search", variant="primary", size="lg")
            gr.Markdown("### Metadata")
            metadata_output = gr.Markdown()
        # Right column (scale 3): visualizations and generated text.
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Knowledge Graph Visualization")
                    kg_plot = gr.Plot(label="Knowledge Graph")
                with gr.Column():
                    gr.Markdown("### Geographic Distribution")
                    map_plot = gr.Plot(label="Africa Map")
            # NOTE(review): confirm these two sections are meant to sit inside
            # the right-hand column rather than span the full page width.
            gr.Markdown("### Topic Definition")
            description_output = gr.Markdown()
            gr.Markdown("### Country-Specific Insights")
            country_output = gr.Markdown()
    # Connect button to processing function
    # The outputs order must match the 5-tuple returned by process_entity.
    search_btn.click(
        fn=process_entity,
        inputs=[entity_dropdown],
        outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
    )
    # Also trigger on dropdown change
    # Because both listeners call process_entity, choosing a topic and then
    # clicking Search runs the whole pipeline (including the LLM calls) twice.
    entity_dropdown.change(
        fn=process_entity,
        inputs=[entity_dropdown],
        outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
    )
    gr.Markdown(
        """
---
💡 **About this app:**
This app is open source and built to help explore education policy initiatives across African countries.
You’re welcome to view, use, and contribute to its codebase or adapt it for your own research and data projects.
"""
    )
# Launch the app
# Serve only when executed as a script; importing the module just builds `demo`.
if __name__ == "__main__":
    demo.launch()