Spaces:
Sleeping
Sleeping
File size: 11,679 Bytes
dee681d 5e47f17 dee681d 5e47f17 dee681d a02e9eb dee681d a02e9eb dee681d a02e9eb dee681d 5e47f17 a02e9eb dee681d a02e9eb a599a8c 5e47f17 a7c6838 5e47f17 a599a8c 5e47f17 a599a8c 4fbd7b6 a599a8c 5e47f17 a599a8c 5e47f17 a599a8c 5e47f17 a599a8c 5e47f17 a599a8c 0edebe3 a599a8c fa2c247 a599a8c 8294e62 a599a8c 8294e62 a599a8c 8294e62 a599a8c 4fbd7b6 a599a8c 5e47f17 a599a8c 4fbd7b6 a599a8c 5e47f17 a599a8c 5e47f17 a599a8c 0edebe3 a599a8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 |
import gradio as gr
import pandas as pd
from openai import OpenAI
import os
import plotly.graph_objects as go
import networkx as nx
# OpenAI-compatible client pointed at NVIDIA's hosted inference endpoint.
# The API key is taken from the NVIDIA_API_KEY environment variable.
client = OpenAI(
    api_key=os.environ.get("NVIDIA_API_KEY"),
    base_url="https://integrate.api.nvidia.com/v1",
)

# Input tables: the curated entity (topic) list, and the relations table whose
# head / relation / tail / sentence / country columns are queried below.
validated_entities_df = pd.read_csv("validated_entities_final.csv")
relations_df = pd.read_csv("extracted_entities_relations_countries.csv")

# Distinct non-null entity names, in first-seen order, for the topic dropdown.
entity_list = validated_entities_df["entity"].dropna().drop_duplicates().tolist()

# Distinct non-null countries across the whole relations table; this is the
# base set of countries drawn on the Africa map.
all_countries_in_data = relations_df["country"].dropna().drop_duplicates().tolist()
def get_entity_triples(entity, df=None):
    """Collect the triples, sentences and countries linked to an entity.

    A row of the relations table is linked to ``entity`` when the entity is
    equal to the row's ``head`` or ``tail`` value.

    Args:
        entity: Entity (topic) name to look up.
        df: Relations table to search. Defaults to the module-level
            ``relations_df``. Any frame with ``head``, ``relation``, ``tail``,
            ``sentence`` and ``country`` columns works.

    Returns:
        A tuple ``(triples, sentences, countries)``:

        * ``triples``: unique ``(head, relation, tail)`` tuples, in the order
          the rows first appear.
        * ``sentences``: unique non-null ``sentence`` values, in first-seen order.
        * ``countries``: unique non-null ``country`` values, in first-seen order.

        All three are empty lists when no row matches.
    """
    if df is None:
        df = relations_df

    filtered_df = df[(df['head'] == entity) | (df['tail'] == entity)]
    if filtered_df.empty:
        return [], [], []

    # dict.fromkeys de-duplicates while preserving first-seen order in O(n).
    # A `triple not in list` check per row is O(n^2).
    triples = list(dict.fromkeys(
        zip(filtered_df['head'], filtered_df['relation'], filtered_df['tail'])
    ))

    sentences = filtered_df['sentence'].dropna().unique().tolist()
    countries = filtered_df['country'].dropna().unique().tolist()
    return triples, sentences, countries
def create_knowledge_graph(entity, triples):
    """Render an entity's triples as an interactive plotly network graph.

    Every distinct head/tail becomes a node. Each (head, tail) pair becomes one
    line, labelled at its midpoint with every relation linking that pair.
    The selected ``entity`` node is drawn larger and in red.

    Args:
        entity: The selected topic; it is highlighted and used in the title.
        triples: Sequence of ``(head, relation, tail)`` tuples.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``triples`` is empty.
    """
    if not triples:
        return None

    # nx.DiGraph holds at most one edge per (head, tail) pair, and add_edge on
    # an existing pair overwrites its attributes. Keep a list of all relations
    # for the pair, so a second relation between the same two nodes does not
    # silently replace the first.
    G = nx.DiGraph()
    for head, relation, tail in triples:
        if G.has_edge(head, tail):
            relations = G.edges[head, tail]['relations']
            if relation not in relations:
                relations.append(relation)
        else:
            G.add_edge(head, tail, relations=[relation])

    # A fixed seed makes the layout deterministic for the same graph.
    pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    # All edge segments share one line trace, with None separating segments.
    # All relation labels share one text trace. This avoids creating two
    # traces per edge.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for head, tail, data in G.edges(data=True):
        x0, y0 = pos[head]
        x1, y1 = pos[tail]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(", ".join(str(r) for r in data['relations']))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=2, color='#888'),
        hoverinfo='none',
        showlegend=False
    )
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode='text',
        text=label_text,
        textposition='middle center',
        textfont=dict(size=8, color='#666'),
        hoverinfo='text',
        hovertext=label_text,
        showlegend=False
    )

    # Node markers; the selected entity gets a larger red marker.
    node_x, node_y, node_text, node_colors, node_sizes = [], [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)
        if node == entity:
            node_colors.append('#FF6B6B')
            node_sizes.append(30)
        else:
            node_colors.append('#4ECDC4')
            node_sizes.append(20)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        text=node_text,
        textposition='top center',
        textfont=dict(size=10, color='#000'),
        marker=dict(
            size=node_sizes,
            color=node_colors,
            line=dict(width=2, color='#fff')
        ),
        hoverinfo='text',
        hovertext=node_text,
        showlegend=False
    )

    # Nodes are added last so they are drawn on top of edges and labels.
    fig = go.Figure(data=[edge_trace, label_trace, node_trace])
    fig.update_layout(
        title=dict(
            text=f"Knowledge Graph for: {entity}",
            x=0.5,
            xanchor='center',
            font=dict(size=16)
        ),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='#f9f9f9',
        height=600
    )
    return fig
def create_africa_map(countries):
    """Build a choropleth of Africa highlighting the given countries.

    Every country in the module-level ``all_countries_in_data`` is drawn.
    Those in ``countries`` get z=1 (red, hover label suffixed with
    "(Selected Entity)"); the rest get z=0 (light grey).

    Args:
        countries: Country names linked to the selected topic, as matched by
            plotly's ``'country names'`` location mode.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when ``countries`` is empty.
    """
    if not countries:
        return None

    # Set for O(1) membership tests while walking the full country list.
    selected = set(countries)
    locations = list(all_countries_in_data)
    z_values = [1 if country in selected else 0 for country in locations]
    hover_text = [
        f"{country} (Selected Entity)" if country in selected else country
        for country in locations
    ]

    fig = go.Figure(data=go.Choropleth(
        locations=locations,
        locationmode='country names',
        z=z_values,
        # Pin the colour range to the 0/1 encoding. Without this, when every
        # z is equal (e.g. all countries selected), the auto-computed range
        # collapses and the highlight colour is not guaranteed.
        zmin=0,
        zmax=1,
        text=hover_text,
        colorscale=[[0, '#E8E8E8'], [1, '#FF6B6B']],
        showscale=False,
        marker_line_color='white',
        marker_line_width=1,
        hovertemplate='<b>%{text}</b><extra></extra>'
    ))
    # Single source of truth for the geo background. Previously '#f0f0f0' was
    # set here, then overridden by update_layout(geo=dict(bgcolor='#f9f9f9')).
    fig.update_geos(
        scope='africa',
        showframe=True,
        showcoastlines=True,
        projection_type='natural earth',
        bgcolor='#f9f9f9'
    )
    fig.update_layout(
        title=dict(
            text=f"Countries with Topic Mentions ({len(countries)} of {len(all_countries_in_data)} total)",
            x=0.5,
            xanchor='center',
            font=dict(size=16)
        ),
        margin=dict(l=0, r=0, t=40, b=0),
        height=500
    )
    return fig
def generate_entity_description(entity, triples):
    """Ask the Llama 3.1 405B endpoint to describe ``entity`` from its triples.

    Args:
        entity: Topic name being described.
        triples: ``(head, relation, tail)`` tuples about the topic; each is
            rendered as a ``- head | relation | tail`` line in the prompt.

    Returns:
        The model's reply text, or a fixed notice when ``triples`` is empty.
        If the API call raises, an ``"Error generating description: ..."``
        string is returned instead, so the UI shows the problem rather than
        crashing.
    """
    if not triples:
        return "No triples found for this entity."

    bullet_lines = []
    for head, rel, tail in triples:
        bullet_lines.append(f"- {head} | {rel} | {tail}")

    prompt = "\n".join([
        f'Based on the following knowledge graph triples, provide a comprehensive description of "{entity}":',
        "\n".join(bullet_lines),
        f"Please synthesize this information into a clear, coherent description that explains what {entity} is, its relationships, and its role based on the triples provided.",
    ])

    try:
        response = client.chat.completions.create(
            model="meta/llama-3.1-405b-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=1024,
        )
        return response.choices[0].message.content
    except Exception as exc:
        return f"Error generating description: {str(exc)}"
def generate_country_paragraphs(entity, sentences, countries, df=None):
    """Write one LLM-generated paragraph per country about ``entity``.

    For each country, the unique non-null sentences come from rows where
    ``entity`` is the head or tail and ``country`` matches. At most the first
    10 are sent to the model as context.

    Args:
        entity: Topic name.
        sentences: Sentences linked to the topic; only checked for emptiness.
        countries: Countries to write about, in output order.
        df: Relations table to draw per-country sentences from. Defaults to
            the module-level ``relations_df``.

    Returns:
        Markdown with one ``**Country:**`` section per country that has at
        least one sentence. A failed API call produces an inline error for that
        country without aborting the others. A fixed notice is returned when
        ``sentences`` or ``countries`` is empty, or when no country has any
        sentence.
    """
    no_data_msg = "No sentences or countries found for this topic."
    if not sentences or not countries:
        return no_data_msg
    if df is None:
        df = relations_df

    # The entity filter does not depend on the country, so apply it once
    # instead of re-scanning the whole table for every country.
    entity_rows = df[(df['head'] == entity) | (df['tail'] == entity)]

    country_paragraphs = []
    for country in countries:
        country_sentences = (
            entity_rows.loc[entity_rows['country'] == country, 'sentence']
            .dropna()
            .unique()
            .tolist()
        )
        if not country_sentences:
            continue

        sentences_text = "\n".join(f"- {sent}" for sent in country_sentences[:10])  # Limit to 10 sentences
        prompt = (
            f"Based on the following sentences about \"{entity}\" in {country}, generate a "
            f"comprehensive paragraph that describes the entity's role, activities, and "
            f"significance in {country}:\n"
            f"{sentences_text}\n"
            f"Please create a well-structured paragraph that synthesizes this information "
            f"specifically for {country}."
        )
        try:
            completion = client.chat.completions.create(
                model="meta/llama-3.1-405b-instruct",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7,
                max_tokens=1024
            )
            country_paragraphs.append(
                f"**{country}:**\n{completion.choices[0].message.content}\n"
            )
        except Exception as e:
            # Best-effort per country: record the error and continue with the rest.
            country_paragraphs.append(f"**{country}:** Error generating paragraph: {str(e)}\n")

    if not country_paragraphs:
        # Every country was skipped, which would otherwise render a blank section.
        return no_data_msg
    return "\n".join(country_paragraphs)
def process_entity(entity):
    """Run the full lookup pipeline for the topic selected in the UI.

    Returns:
        A 5-tuple in the order of the Gradio outputs it is wired to:
        ``(knowledge_graph_fig, africa_map_fig, description_md,
        country_paragraphs_md, metadata_md)``. An empty selection yields
        ``(None, None, "", "", "")``. A topic with no triples puts
        "No data found for this entity." in the description slot.
    """
    if not entity:
        return None, None, "", "", ""

    triples, sentences, countries = get_entity_triples(entity)
    if not triples:
        return None, None, "No data found for this entity.", "", ""

    # Figures first, then the two LLM-backed text sections.
    graph_figure = create_knowledge_graph(entity, triples)
    africa_map = create_africa_map(countries)
    description_md = generate_entity_description(entity, triples)
    countries_md = generate_country_paragraphs(entity, sentences, countries)

    metadata_md = "\n\n".join([
        f"**Associated Countries:** {', '.join(countries)}",
        f"**Number of Triples:** {len(triples)}",
        f"**Number of Sentences:** {len(sentences)}",
    ])
    return graph_figure, africa_map, description_md, countries_md, metadata_md
# Create Gradio interface
# Layout:
#   - Left column (scale=1): topic dropdown, Search button and metadata panel.
#   - Right column (scale=3): knowledge graph and Africa map side by side,
#     followed by the topic definition and the country-specific insights.
# Gradio Blocks renders components in the order they are created inside each
# `with` context, so the statement order here defines the page layout.
with gr.Blocks(title="AU Education Policy Glossary", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# AU Education Policy Glossary (AU-EPG)
Select an Education Policy topic to examine its context and implementation across Africa.
"""
    )
    with gr.Row():
        # Left column: input controls and summary metadata.
        with gr.Column(scale=1):
            entity_dropdown = gr.Dropdown(
                choices=sorted(entity_list),
                label="Select Topic",
                filterable=True,
                info="Start typing to search for a topic"
            )
            search_btn = gr.Button("Search", variant="primary", size="lg")
            gr.Markdown("### Metadata")
            metadata_output = gr.Markdown()
        # Right column: visualisations and LLM-generated text.
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Knowledge Graph Visualization")
                    kg_plot = gr.Plot(label="Knowledge Graph")
                with gr.Column():
                    gr.Markdown("### Geographic Distribution")
                    map_plot = gr.Plot(label="Africa Map")
            gr.Markdown("### Topic Definition")
            description_output = gr.Markdown()
            gr.Markdown("### Country-Specific Insights")
            country_output = gr.Markdown()
    # Connect button to processing function
    # The outputs list must stay in the same order as the 5-tuple that
    # process_entity returns.
    search_btn.click(
        fn=process_entity,
        inputs=[entity_dropdown],
        outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
    )
    # Also trigger on dropdown change
    # NOTE(review): selecting a topic already runs process_entity through this
    # handler, so pressing Search afterwards runs the whole pipeline (including
    # the LLM calls) a second time.
    entity_dropdown.change(
        fn=process_entity,
        inputs=[entity_dropdown],
        outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
    )
    gr.Markdown(
        """
---
💡 **About this app:**
This app is open source and built to help explore education policy initiatives across African countries.
You’re welcome to view, use, and contribute to its codebase or adapt it for your own research and data projects.
"""
    )
# Launch the app only when this file is run as a script, not when it is imported.
# (A stray trailing "|" after demo.launch() made this line a syntax error.)
if __name__ == "__main__":
    demo.launch()