michsethowusu commited on
Commit
5e47f17
·
verified ·
1 Parent(s): bb6e757

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -18
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import pandas as pd
3
  from openai import OpenAI
4
  import os
 
 
5
 
6
  # Initialize NVIDIA API client
7
  client = OpenAI(
@@ -16,6 +18,9 @@ relations_df = pd.read_csv("extracted_entities_relations_countries.csv")
16
  # Get list of entities for dropdown
17
  entity_list = validated_entities_df['entity'].dropna().unique().tolist()
18
 
 
 
 
19
  def get_entity_triples(entity):
20
  """Extract all triples associated with the selected entity"""
21
  # Filter rows where entity appears in head or tail
@@ -29,7 +34,7 @@ def get_entity_triples(entity):
29
  # Extract unique triples
30
  triples = []
31
  for _, row in filtered_df.iterrows():
32
- triple = f"{row['head']} | {row['relation']} | {row['tail']}"
33
  if triple not in triples:
34
  triples.append(triple)
35
 
@@ -41,12 +46,171 @@ def get_entity_triples(entity):
41
 
42
  return triples, sentences, countries
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def generate_entity_description(entity, triples):
45
  """Generate entity description using Llama 405B based on triples"""
46
  if not triples:
47
  return "No triples found for this entity."
48
 
49
- triples_text = "\n".join([f"- {triple}" for triple in triples])
50
 
51
  prompt = f"""Based on the following knowledge graph triples, provide a comprehensive description of "{entity}":
52
 
@@ -107,16 +271,19 @@ Please create a well-structured paragraph that synthesizes this information spec
107
  def process_entity(entity):
108
  """Main function to process selected entity"""
109
  if not entity:
110
- return "Please select an entity.", "", "", ""
111
 
112
  # Get triples, sentences, and countries
113
  triples, sentences, countries = get_entity_triples(entity)
114
 
115
  if not triples:
116
- return "No data found for this entity.", "", "", ""
 
 
 
117
 
118
- # Format triples for display
119
- triples_display = "\n".join(triples)
120
 
121
  # Generate entity description
122
  description = generate_entity_description(entity, triples)
@@ -127,15 +294,15 @@ def process_entity(entity):
127
  # Format metadata
128
  metadata = f"**Associated Countries:** {', '.join(countries)}\n\n**Number of Triples:** {len(triples)}\n\n**Number of Sentences:** {len(sentences)}"
129
 
130
- return triples_display, description, country_paragraphs, metadata
131
 
132
  # Create Gradio interface
133
- with gr.Blocks(title="Entity Knowledge Graph Explorer") as demo:
134
  gr.Markdown(
135
  """
136
  # 🔍 Entity Knowledge Graph Explorer
137
 
138
- Select an entity to explore its knowledge graph triples, generate AI descriptions,
139
  and view country-specific insights using Llama 405B.
140
  """
141
  )
@@ -153,13 +320,15 @@ with gr.Blocks(title="Entity Knowledge Graph Explorer") as demo:
153
  gr.Markdown("### Metadata")
154
  metadata_output = gr.Markdown()
155
 
156
- with gr.Column(scale=2):
157
- gr.Markdown("### Knowledge Graph Triples")
158
- triples_output = gr.Textbox(
159
- label="Extracted Triples (Head | Relation | Tail)",
160
- lines=10,
161
- max_lines=15
162
- )
 
 
163
 
164
  gr.Markdown("### AI-Generated Entity Description")
165
  description_output = gr.Markdown()
@@ -171,14 +340,14 @@ with gr.Blocks(title="Entity Knowledge Graph Explorer") as demo:
171
  search_btn.click(
172
  fn=process_entity,
173
  inputs=[entity_dropdown],
174
- outputs=[triples_output, description_output, country_output, metadata_output]
175
  )
176
 
177
  # Also trigger on dropdown change
178
  entity_dropdown.change(
179
  fn=process_entity,
180
  inputs=[entity_dropdown],
181
- outputs=[triples_output, description_output, country_output, metadata_output]
182
  )
183
 
184
  gr.Markdown(
 
2
  import pandas as pd
3
  from openai import OpenAI
4
  import os
5
+ import plotly.graph_objects as go
6
+ import networkx as nx
7
 
8
  # Initialize NVIDIA API client
9
  client = OpenAI(
 
18
  # Get list of entities for dropdown
19
  entity_list = validated_entities_df['entity'].dropna().unique().tolist()
20
 
21
+ # Get all unique countries from the data
22
+ all_countries_in_data = relations_df['country'].dropna().unique().tolist()
23
+
24
  def get_entity_triples(entity):
25
  """Extract all triples associated with the selected entity"""
26
  # Filter rows where entity appears in head or tail
 
34
  # Extract unique triples
35
  triples = []
36
  for _, row in filtered_df.iterrows():
37
+ triple = (row['head'], row['relation'], row['tail'])
38
  if triple not in triples:
39
  triples.append(triple)
40
 
 
46
 
47
  return triples, sentences, countries
48
 
49
+ def create_knowledge_graph(entity, triples):
50
+ """Create interactive knowledge graph visualization using plotly"""
51
+ if not triples:
52
+ return None
53
+
54
+ # Create directed graph
55
+ G = nx.DiGraph()
56
+
57
+ # Add edges (triples)
58
+ for head, relation, tail in triples:
59
+ G.add_edge(head, tail, label=relation)
60
+
61
+ # Generate layout
62
+ pos = nx.spring_layout(G, k=2, iterations=50, seed=42)
63
+
64
+ # Create edge traces
65
+ edge_traces = []
66
+ edge_labels = []
67
+
68
+ for edge in G.edges(data=True):
69
+ x0, y0 = pos[edge[0]]
70
+ x1, y1 = pos[edge[1]]
71
+
72
+ # Edge line
73
+ edge_trace = go.Scatter(
74
+ x=[x0, x1, None],
75
+ y=[y0, y1, None],
76
+ mode='lines',
77
+ line=dict(width=2, color='#888'),
78
+ hoverinfo='none',
79
+ showlegend=False
80
+ )
81
+ edge_traces.append(edge_trace)
82
+
83
+ # Edge label (relation)
84
+ edge_labels.append(go.Scatter(
85
+ x=[(x0 + x1) / 2],
86
+ y=[(y0 + y1) / 2],
87
+ mode='text',
88
+ text=[edge[2]['label']],
89
+ textposition='middle center',
90
+ textfont=dict(size=8, color='#666'),
91
+ hoverinfo='text',
92
+ hovertext=edge[2]['label'],
93
+ showlegend=False
94
+ ))
95
+
96
+ # Create node trace
97
+ node_x = []
98
+ node_y = []
99
+ node_text = []
100
+ node_colors = []
101
+ node_sizes = []
102
+
103
+ for node in G.nodes():
104
+ x, y = pos[node]
105
+ node_x.append(x)
106
+ node_y.append(y)
107
+ node_text.append(node)
108
+
109
+ # Highlight the main entity
110
+ if node == entity:
111
+ node_colors.append('#FF6B6B')
112
+ node_sizes.append(30)
113
+ else:
114
+ node_colors.append('#4ECDC4')
115
+ node_sizes.append(20)
116
+
117
+ node_trace = go.Scatter(
118
+ x=node_x,
119
+ y=node_y,
120
+ mode='markers+text',
121
+ text=node_text,
122
+ textposition='top center',
123
+ textfont=dict(size=10, color='#000'),
124
+ marker=dict(
125
+ size=node_sizes,
126
+ color=node_colors,
127
+ line=dict(width=2, color='#fff')
128
+ ),
129
+ hoverinfo='text',
130
+ hovertext=node_text,
131
+ showlegend=False
132
+ )
133
+
134
+ # Create figure
135
+ fig = go.Figure(data=edge_traces + edge_labels + [node_trace])
136
+
137
+ fig.update_layout(
138
+ title=dict(
139
+ text=f"Knowledge Graph for: {entity}",
140
+ x=0.5,
141
+ xanchor='center',
142
+ font=dict(size=16)
143
+ ),
144
+ showlegend=False,
145
+ hovermode='closest',
146
+ margin=dict(b=20, l=5, r=5, t=40),
147
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
148
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
149
+ plot_bgcolor='#f9f9f9',
150
+ height=600
151
+ )
152
+
153
+ return fig
154
+
155
+ def create_africa_map(countries):
156
+ """Create Africa map with highlighted countries"""
157
+ if not countries:
158
+ return None
159
+
160
+ # Use all countries from the dataset for the base map
161
+ locations = []
162
+ z_values = []
163
+ hover_text = []
164
+
165
+ for country in all_countries_in_data:
166
+ locations.append(country)
167
+ if country in countries:
168
+ z_values.append(1)
169
+ hover_text.append(f"{country} (Selected Entity)")
170
+ else:
171
+ z_values.append(0)
172
+ hover_text.append(country)
173
+
174
+ fig = go.Figure(data=go.Choropleth(
175
+ locations=locations,
176
+ locationmode='country names',
177
+ z=z_values,
178
+ text=hover_text,
179
+ colorscale=[[0, '#E8E8E8'], [1, '#FF6B6B']],
180
+ showscale=False,
181
+ marker_line_color='white',
182
+ marker_line_width=1,
183
+ hovertemplate='<b>%{text}</b><extra></extra>'
184
+ ))
185
+
186
+ fig.update_geos(
187
+ scope='africa',
188
+ showframe=True,
189
+ showcoastlines=True,
190
+ projection_type='natural earth',
191
+ bgcolor='#f0f0f0'
192
+ )
193
+
194
+ fig.update_layout(
195
+ title=dict(
196
+ text=f"Countries with Entity Mentions ({len(countries)} of {len(all_countries_in_data)} total)",
197
+ x=0.5,
198
+ xanchor='center',
199
+ font=dict(size=16)
200
+ ),
201
+ margin=dict(l=0, r=0, t=40, b=0),
202
+ height=500,
203
+ geo=dict(bgcolor='#f9f9f9')
204
+ )
205
+
206
+ return fig
207
+
208
  def generate_entity_description(entity, triples):
209
  """Generate entity description using Llama 405B based on triples"""
210
  if not triples:
211
  return "No triples found for this entity."
212
 
213
+ triples_text = "\n".join([f"- {head} | {rel} | {tail}" for head, rel, tail in triples])
214
 
215
  prompt = f"""Based on the following knowledge graph triples, provide a comprehensive description of "{entity}":
216
 
 
271
  def process_entity(entity):
272
  """Main function to process selected entity"""
273
  if not entity:
274
+ return None, None, "", "", ""
275
 
276
  # Get triples, sentences, and countries
277
  triples, sentences, countries = get_entity_triples(entity)
278
 
279
  if not triples:
280
+ return None, None, "No data found for this entity.", "", ""
281
+
282
+ # Create knowledge graph
283
+ kg_fig = create_knowledge_graph(entity, triples)
284
 
285
+ # Create Africa map
286
+ map_fig = create_africa_map(countries)
287
 
288
  # Generate entity description
289
  description = generate_entity_description(entity, triples)
 
294
  # Format metadata
295
  metadata = f"**Associated Countries:** {', '.join(countries)}\n\n**Number of Triples:** {len(triples)}\n\n**Number of Sentences:** {len(sentences)}"
296
 
297
+ return kg_fig, map_fig, description, country_paragraphs, metadata
298
 
299
  # Create Gradio interface
300
+ with gr.Blocks(title="Entity Knowledge Graph Explorer", theme=gr.themes.Soft()) as demo:
301
  gr.Markdown(
302
  """
303
  # 🔍 Entity Knowledge Graph Explorer
304
 
305
+ Select an entity to explore its knowledge graph, generate AI descriptions,
306
  and view country-specific insights using Llama 405B.
307
  """
308
  )
 
320
  gr.Markdown("### Metadata")
321
  metadata_output = gr.Markdown()
322
 
323
+ with gr.Column(scale=3):
324
+ with gr.Row():
325
+ with gr.Column():
326
+ gr.Markdown("### Knowledge Graph Visualization")
327
+ kg_plot = gr.Plot(label="Knowledge Graph")
328
+
329
+ with gr.Column():
330
+ gr.Markdown("### Geographic Distribution")
331
+ map_plot = gr.Plot(label="Africa Map")
332
 
333
  gr.Markdown("### AI-Generated Entity Description")
334
  description_output = gr.Markdown()
 
340
  search_btn.click(
341
  fn=process_entity,
342
  inputs=[entity_dropdown],
343
+ outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
344
  )
345
 
346
  # Also trigger on dropdown change
347
  entity_dropdown.change(
348
  fn=process_entity,
349
  inputs=[entity_dropdown],
350
+ outputs=[kg_plot, map_plot, description_output, country_output, metadata_output]
351
  )
352
 
353
  gr.Markdown(