Anurag Prasad committed on
Commit
506884b
·
1 Parent(s): 0d5ac71

Added basic dashboard layout

Browse files
Files changed (2) hide show
  1. .gradio/certificate.pem +31 -0
  2. app.py +541 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
app.py CHANGED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import asyncio
3
+ from typing import Optional, List, Dict
4
+ from contextlib import AsyncExitStack
5
+ from mcp import ClientSession, StdioServerParameters
6
+ from mcp.client.stdio import stdio_client
7
+ import json
8
+ from datetime import datetime
9
+ import plotly.graph_objects as go
10
+ import plotly.express as px
11
+
12
class MCPClient:
    """Async client wrapper around an MCP stdio server connection.

    Owns an AsyncExitStack so that the stdio transport and the
    ClientSession opened in connect_to_server() are torn down together
    (in reverse order) by close().
    """

    def __init__(self):
        # No session until connect_to_server() completes successfully;
        # call_tool() guards on this.
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self, server_script_path: str = "mcp_server.py"):
        """Spawn the server script as a subprocess and open an MCP session.

        Args:
            server_script_path: Path to a .py or .js server entry point.

        Raises:
            ValueError: If the script is neither a .py nor a .js file.
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')

        if not (is_python or is_js):
            raise ValueError("Server script must be a .py or .js file")

        # Interpreter is picked from the file extension.
        # NOTE(review): assumes "python"/"node" are on PATH — confirm for deployment.
        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )

        # Both async context managers are entered on the exit stack so
        # close() can unwind them safely even if a later step fails.
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )

        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        tools = response.tools
        print("Connected to server with tools:", [tool.name for tool in tools])

    async def call_tool(self, tool_name: str, arguments: dict):
        """Invoke a named tool on the connected server and return its content.

        Raises:
            RuntimeError: If connect_to_server() has not been called yet.
        """
        if not self.session:
            raise RuntimeError("Not connected to server")

        response = await self.session.call_tool(tool_name, arguments)
        return response.content

    async def close(self):
        """Release the session and stdio transport via the exit stack."""
        await self.exit_stack.aclose()
58
+
59
# Global MCP client instance
# Shared by every Gradio callback below; connected lazily by
# initialize_mcp_connection() when the interface loads.
mcp_client = MCPClient()
61
+
62
# Async wrapper functions for Gradio
def run_async(coro):
    """Run *coro* to completion on a thread-local event loop and return its result.

    Gradio callbacks are synchronous, so the async MCP calls are driven here.

    Fixes over the naive ``asyncio.get_event_loop()`` version: a loop that
    exists but has been closed cannot run coroutines, so it is treated the
    same as a missing loop — a fresh one is created and installed. (Since
    Python 3.10 ``get_event_loop()`` with no current loop is deprecated and
    eventually raises ``RuntimeError``, which this handler also covers.)

    Args:
        coro: Any awaitable accepted by ``loop.run_until_complete``.

    Returns:
        Whatever the coroutine returns.
    """
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            # A closed loop raises on run_until_complete; replace it.
            raise RuntimeError("event loop is closed")
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(coro)
72
+
73
# Auto-connect to MCP server on startup
def initialize_mcp_connection():
    """Attempt the startup connection to the MCP server.

    Returns:
        True when the connection succeeded, False otherwise. Never raises:
        the data helpers below fall back to demo data when no server is up.
    """
    try:
        run_async(mcp_client.connect_to_server())
    except Exception as e:
        print(f"Failed to connect to MCP server on startup: {e}")
        return False
    print("Successfully connected to MCP server on startup")
    return True
83
+
84
# MCP client functions
def get_models_from_db():
    """Fetch every model record via the MCP ``get_all_models`` tool.

    Returns:
        A list of model dicts with ``name``/``created``/``description`` keys.
        On any MCP failure a small hard-coded catalogue is returned instead,
        keeping the dashboard usable without a server.
    """
    try:
        models = run_async(mcp_client.call_tool("get_all_models", {}))
    except Exception as e:
        print(f"Error getting models: {e}")
    else:
        return models if isinstance(models, list) else []
    # Fallback data for demonstration
    return [
        {"name": "llama-3.1-8b-instant", "created": "2025-01-15", "description": "Fast and efficient model for instant responses."},
        {"name": "llama3-8b-8192", "created": "2025-02-10", "description": "Extended context window model with 8192 tokens."},
        {"name": "gemini-2.5-pro-preview-06-05", "created": "2025-06-05", "description": "Professional preview version of Gemini 2.5."},
        {"name": "gemini-2.5-flash-preview-05-20", "created": "2025-05-20", "description": "Flash preview with optimized speed."},
        {"name": "gemini-1.5-pro", "created": "2024-12-01", "description": "Stable professional release of Gemini 1.5."}
    ]
100
+
101
def get_available_model_names():
    """Return only the model names, for populating dropdown choices."""
    return [entry["name"] for entry in get_models_from_db()]
105
+
106
def search_models_in_db(search_term: str):
    """Search models via the MCP ``search_models`` tool.

    Falls back to a local case-insensitive substring match over name and
    description when the MCP call fails; an empty term matches everything.
    """
    try:
        found = run_async(mcp_client.call_tool("search_models", {"search_term": search_term}))
        return found if isinstance(found, list) else []
    except Exception as e:
        print(f"Error searching models: {e}")
    # Fallback search for demonstration
    catalogue = get_models_from_db()
    if not search_term:
        return catalogue
    needle = search_term.lower()
    return [
        entry for entry in catalogue
        if needle in entry["name"].lower() or needle in entry["description"].lower()
    ]
119
+
120
def format_dropdown_items(models):
    """Build display labels and a label -> model-name lookup table.

    Each label shows the model name, creation date, and a description
    preview truncated at 40 characters.

    Args:
        models: Iterable of dicts with ``name``/``created``/``description``.

    Returns:
        ``(labels, mapping)`` where ``mapping[label] == model_name``.
    """
    labels = []
    mapping = {}

    for entry in models:
        description = entry["description"]
        preview = description[:40] + ("..." if len(description) > 40 else "")
        label = f"{entry['name']} (Created: {entry['created']}) - {preview}"
        labels.append(label)
        mapping[label] = entry["name"]

    return labels, mapping
132
+
133
def extract_model_name_from_dropdown(dropdown_value, model_mapping):
    """Resolve a formatted dropdown label back to the raw model name.

    Prefers the exact mapping built by format_dropdown_items(); unknown
    labels fall back to stripping the " (Created: ...)" suffix, and a
    falsy value yields an empty string.
    """
    if dropdown_value in model_mapping:
        return model_mapping[dropdown_value]
    return dropdown_value.split(" (")[0] if dropdown_value else ""
136
+
137
def get_model_details(model_name: str):
    """Fetch a single model's record via the MCP ``get_model_details`` tool.

    On failure, returns a minimal default record so callers (e.g. the
    chatbot) always get a ``system_prompt`` to work with.
    """
    try:
        return run_async(mcp_client.call_tool("get_model_details", {"model_name": model_name}))
    except Exception as e:
        print(f"Error getting model details: {e}")
        return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
145
+
146
def enhance_prompt_via_mcp(prompt: str):
    """Ask the MCP ``enhance_prompt`` tool to rewrite *prompt*.

    Returns the enhanced text, or a canned locally-built enhancement when
    the MCP call fails.
    """
    try:
        outcome = run_async(mcp_client.call_tool("enhance_prompt", {"prompt": prompt}))
        # NOTE(review): assumes the tool result is dict-like with an
        # "enhanced_prompt" key — confirm against the server's output.
        return outcome.get("enhanced_prompt", prompt)
    except Exception as e:
        print(f"Error enhancing prompt: {e}")
        return f"Enhanced: {prompt}\n\nAdditional context: Be more specific, helpful, and provide detailed responses while maintaining a professional tone."
154
+
155
def save_model_to_db(model_name: str, system_prompt: str):
    """Persist a model and its system prompt via the MCP ``save_model`` tool.

    Returns:
        A human-readable status message (success or error text), never raises.
    """
    payload = {
        "model_name": model_name,
        "system_prompt": system_prompt
    }
    try:
        outcome = run_async(mcp_client.call_tool("save_model", payload))
        return outcome.get("message", "Model saved successfully!")
    except Exception as e:
        print(f"Error saving model: {e}")
        return f"Error saving model: {e}"
166
+
167
def calculate_drift_via_mcp(model_name: str):
    """Trigger a drift calculation via the MCP ``calculate_drift`` tool.

    On failure, fabricates a random demo score so the UI stays interactive.
    """
    try:
        return run_async(mcp_client.call_tool("calculate_drift", {"model_name": model_name}))
    except Exception as e:
        print(f"Error calculating drift: {e}")
        import random
        return {
            "drift_score": round(random.uniform(0.05, 0.25), 3),
            "message": f"Drift calculated and saved for {model_name}",
        }
177
+
178
def get_drift_history_from_db(model_name: str):
    """Fetch drift history via the MCP ``get_drift_history`` tool.

    Returns:
        A list of ``{"date", "drift_score"}`` dicts; a static demo series
        when the MCP call fails.
    """
    try:
        history = run_async(mcp_client.call_tool("get_drift_history", {"model_name": model_name}))
        return history if isinstance(history, list) else []
    except Exception as e:
        print(f"Error getting drift history: {e}")
        # Fallback data for demonstration
        return [
            {"date": "2025-06-01", "drift_score": 0.12},
            {"date": "2025-06-05", "drift_score": 0.18},
            {"date": "2025-06-09", "drift_score": 0.15}
        ]
191
+
192
def create_drift_chart(drift_history):
    """Render drift history as a plotly line chart.

    Args:
        drift_history: List of ``{"date", "drift_score"}`` dicts.

    Returns:
        A plotly Figure, or a ``gr.update`` clearing the plot when the
        history is empty.
    """
    if not drift_history:
        return gr.update(value=None)

    figure = go.Figure()
    figure.add_trace(go.Scatter(
        x=[point["date"] for point in drift_history],
        y=[point["drift_score"] for point in drift_history],
        mode='lines+markers',
        name='Drift Score',
        line=dict(color='#ff6b6b', width=3),
        marker=dict(size=8, color='#ff6b6b')
    ))

    figure.update_layout(
        title='Model Drift Over Time',
        xaxis_title='Date',
        yaxis_title='Drift Score',
        template='plotly_white',
        height=400,
        showlegend=True
    )

    return figure
220
+
221
# Global variable to store model mapping
# Maps formatted dropdown labels -> raw model names. Rebuilt whenever the
# dropdown choices change (search, model save, initial load).
current_model_mapping = {}
223
+
224
# Gradio interface functions
def update_model_dropdown(search_term):
    """Refresh the model dropdown from a search term.

    An empty/whitespace term lists every model. Also rebuilds the global
    label -> name mapping and preselects the first match (if any).
    """
    global current_model_mapping

    term = search_term.strip()
    models = search_models_in_db(term) if term else get_models_from_db()

    choices, mapping = format_dropdown_items(models)
    current_model_mapping = mapping

    return gr.update(choices=choices, value=choices[0] if choices else None)
238
+
239
def on_model_select(dropdown_value):
    """Mirror the selected model name into the chat and drift displays."""
    if not dropdown_value:
        return "", ""

    name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
    # Same name feeds both the chatbot header and the drift-analysis header.
    return name, name
246
+
247
def toggle_create_new():
    """Toggle create new model section visibility"""
    # NOTE(review): not wired to any event in this file — the button uses
    # show_create_new (defined in the UI wiring), which also repopulates
    # the model-name dropdown. Candidate for removal.
    return gr.update(visible=True)
250
+
251
def cancel_create_new():
    """Hide and reset every widget of the create-new-model section.

    Values map positionally onto the outputs wired to the Cancel button.
    """
    hidden = dict(visible=False)
    return [
        gr.update(**hidden),  # create_new_section collapses again
        None,                 # new_model_name selection cleared
        "",                   # new_system_prompt emptied
        gr.update(**hidden),  # enhanced_prompt_display hidden
        gr.update(**hidden),  # prompt_choice hidden
        gr.update(**hidden),  # save_model_button hidden
        gr.update(**hidden),  # save_status hidden
    ]
262
+
263
def enhance_prompt(original_prompt):
    """Run prompt enhancement and reveal the follow-up controls.

    A blank prompt keeps the enhanced-prompt display, choice radio, and
    save button hidden; otherwise all three are shown.
    """
    text = original_prompt.strip()
    if not text:
        return [
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False)
        ]

    return [
        gr.update(value=enhance_prompt_via_mcp(text), visible=True),
        gr.update(visible=True),
        gr.update(visible=True)
    ]
278
+
279
def save_new_model(selected_model_name, original_prompt, enhanced_prompt, choice):
    """Persist the new model, then refresh the main dropdown choices.

    Returns (status text, status visibility update, dropdown choices update);
    validation failures report a message without touching the database.
    """
    global current_model_mapping

    if not selected_model_name or not original_prompt.strip():
        return [
            "Please select a model and enter a system prompt",
            gr.update(visible=True),
            gr.update()
        ]

    chosen_prompt = enhanced_prompt if choice == "Keep Enhanced" else original_prompt
    status = save_model_to_db(selected_model_name, chosen_prompt)

    # Rebuild dropdown choices so the saved model shows up immediately.
    labels, mapping = format_dropdown_items(get_models_from_db())
    current_model_mapping = mapping

    return [
        status,
        gr.update(visible=True),
        gr.update(choices=labels)
    ]
302
+
303
def chatbot_response(message, history, dropdown_value):
    """Append a simulated model reply to the chat history.

    Resolves the selected model's stored system prompt and echoes a stub
    reply; the message box is cleared by returning "" as the second value.
    """
    if not message.strip() or not dropdown_value:
        return history, ""

    model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
    system_prompt = get_model_details(model_name).get("system_prompt", "")

    # Simulate response (replace with actual LLM call)
    reply = f"[{model_name}] Response to: {message}\n(Using system prompt: {system_prompt[:50]}...)"
    history.append([message, reply])
    return history, ""
316
+
317
def calculate_drift(dropdown_value):
    """Run a drift calculation for the currently selected model.

    Returns a formatted score-plus-message string for the result textbox.
    """
    if not dropdown_value:
        return "Please select a model first"

    model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
    outcome = calculate_drift_via_mcp(model_name)
    score = outcome.get("drift_score", 0.0)
    note = outcome.get("message", "")

    return f"Drift Score: {score:.3f}\n{note}"
328
+
329
def refresh_drift_history(dropdown_value):
    """Reload the drift history list and chart for the selected model."""
    if not dropdown_value:
        return [], gr.update(value=None)

    model_name = extract_model_name_from_dropdown(dropdown_value, current_model_mapping)
    records = get_drift_history_from_db(model_name)
    return records, create_drift_chart(records)
339
+
340
def initialize_interface():
    """Connect to MCP and compute the initial widget values for demo.load().

    Returns, in order: dropdown choices, dropdown value, create-new name
    choices, chatbot model display, drift model display.
    """
    # Connection failure is tolerated — the data helpers fall back to
    # built-in demo data.
    initialize_mcp_connection()

    global current_model_mapping
    labels, mapping = format_dropdown_items(get_models_from_db())
    current_model_mapping = mapping

    first_name = labels[0].split(" (")[0] if labels else ""
    return (
        labels,                          # model_dropdown choices
        labels[0] if labels else None,   # model_dropdown value
        get_available_model_names(),     # new_model_name choices
        first_name,                      # selected_model_display
        first_name                       # drift_model_display
    )
361
+
362
# Create Gradio interface
# Layout: left column = model selection + hidden "create new" form;
# right column = tabbed chatbot and drift-analysis views. Event wiring
# follows the layout block.
with gr.Blocks(title="AI Model Management & Interaction Platform") as demo:
    gr.Markdown("# AI Model Management & Interaction Platform")

    with gr.Row():
        # Left Column - Model Selection
        with gr.Column(scale=1):
            gr.Markdown("### Model Selection")

            # Choices are populated on demo.load and on search changes.
            model_dropdown = gr.Dropdown(
                choices=[],
                label="Select Model",
                interactive=True
            )

            search_box = gr.Textbox(
                placeholder="Search by model name or description...",
                label="Search Models"
            )

            create_new_button = gr.Button("Create New Model", variant="secondary")

            # Create New Model Section (Initially Hidden)
            with gr.Group(visible=False) as create_new_section:
                gr.Markdown("#### Create New Model")
                new_model_name = gr.Dropdown(
                    choices=[],
                    label="Select Model Name",
                    interactive=True
                )
                new_system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Enter system prompt",
                    lines=3
                )

                with gr.Row():
                    enhance_button = gr.Button("Enhance Prompt", variant="primary")
                    cancel_button = gr.Button("Cancel", variant="secondary")

                # Revealed only after a successful enhance_prompt() call.
                enhanced_prompt_display = gr.Textbox(
                    label="Enhanced Prompt",
                    interactive=False,
                    lines=4,
                    visible=False
                )

                prompt_choice = gr.Radio(
                    choices=["Keep Enhanced", "Keep Original"],
                    label="Choose Prompt to Use",
                    visible=False
                )

                save_model_button = gr.Button("Save Model", variant="primary", visible=False)
                save_status = gr.Textbox(label="Status", interactive=False, visible=False)

        # Right Column - Model Operations
        with gr.Column(scale=2):
            gr.Markdown("### Model Operations")

            with gr.Tabs():
                # Chatbot Tab
                with gr.TabItem("Chatbot"):
                    selected_model_display = gr.Textbox(
                        label="Currently Selected Model",
                        interactive=False
                    )

                    # NOTE(review): chatbot_response appends [user, bot] pairs —
                    # the legacy tuples format; confirm against the installed
                    # gradio version's Chatbot default.
                    chatbot_interface = gr.Chatbot(height=400)

                    with gr.Row():
                        msg_input = gr.Textbox(
                            placeholder="Enter your message...",
                            label="Message",
                            scale=4
                        )
                        send_button = gr.Button("Send", variant="primary", scale=1)

                    clear_chat = gr.Button("Clear Chat", variant="secondary")

                # Drift Analysis Tab
                with gr.TabItem("Drift Analysis"):
                    drift_model_display = gr.Textbox(
                        label="Model for Drift Analysis",
                        interactive=False
                    )

                    with gr.Row():
                        calculate_drift_button = gr.Button("Calculate New Drift", variant="primary")
                        refresh_history_button = gr.Button("Refresh History", variant="secondary")

                    drift_result = gr.Textbox(label="Latest Drift Calculation", interactive=False)

                    gr.Markdown("#### Drift History")
                    drift_history_display = gr.JSON(label="Drift History Data")

                    gr.Markdown("#### Drift Chart")
                    drift_chart = gr.Plot(label="Drift Over Time")

    # Event Handlers

    # Search functionality - Dynamic update
    search_box.change(
        update_model_dropdown,
        inputs=[search_box],
        outputs=[model_dropdown]
    )

    # Model selection updates
    model_dropdown.change(
        on_model_select,
        inputs=[model_dropdown],
        outputs=[selected_model_display, drift_model_display]
    )

    # Create new model functionality
    # Reveals the hidden section and fills its name dropdown.
    def show_create_new():
        available_models = get_available_model_names()
        return gr.update(visible=True), gr.update(choices=available_models)

    create_new_button.click(
        show_create_new,
        outputs=[create_new_section, new_model_name]
    )

    cancel_button.click(cancel_create_new, outputs=[
        create_new_section, new_model_name, new_system_prompt,
        enhanced_prompt_display, prompt_choice, save_model_button, save_status
    ])

    # Enhance prompt
    enhance_button.click(
        enhance_prompt,
        inputs=[new_system_prompt],
        outputs=[enhanced_prompt_display, prompt_choice, save_model_button]
    )

    # Save model
    # NOTE(review): save_status appears twice in outputs (text + visibility);
    # some gradio versions reject duplicate output components — verify.
    save_model_button.click(
        save_new_model,
        inputs=[new_model_name, new_system_prompt, enhanced_prompt_display, prompt_choice],
        outputs=[save_status, save_status, model_dropdown]
    )

    # Chatbot functionality
    send_button.click(
        chatbot_response,
        inputs=[msg_input, chatbot_interface, model_dropdown],
        outputs=[chatbot_interface, msg_input]
    )

    # Enter key submits the same handler as the Send button.
    msg_input.submit(
        chatbot_response,
        inputs=[msg_input, chatbot_interface, model_dropdown],
        outputs=[chatbot_interface, msg_input]
    )

    clear_chat.click(lambda: [], outputs=[chatbot_interface])

    # Drift analysis functionality
    calculate_drift_button.click(
        calculate_drift,
        inputs=[model_dropdown],
        outputs=[drift_result]
    )

    refresh_history_button.click(
        refresh_drift_history,
        inputs=[model_dropdown],
        outputs=[drift_history_display, drift_chart]
    )

    # Initialize interface on load
    # NOTE(review): model_dropdown is listed twice (choices then value) and
    # initialize_interface returns a raw list for the first slot — confirm
    # the installed gradio version applies both as intended.
    demo.load(
        initialize_interface,
        outputs=[model_dropdown, model_dropdown, new_model_name, selected_model_display, drift_model_display]
    )
539
+
540
if __name__ == "__main__":
    # share=True requests a public gradio.live tunnel on launch.
    # NOTE(review): consider share=False for private deployments.
    demo.launch(share=True)