HarshBhati committed on
Commit
825a805
·
1 Parent(s): 9f67843

Add database integration: SQLite fallback, model and drift handlers

Browse files
app.py CHANGED
@@ -1,34 +1,38 @@
 
1
  import gradio as gr
2
  import asyncio
3
  from typing import Optional, List, Dict
4
  from contextlib import AsyncExitStack
5
  from mcp import ClientSession, StdioServerParameters
6
  from mcp.client.stdio import stdio_client
 
7
  import json
8
  from datetime import datetime
9
  import plotly.graph_objects as go
10
- import plotly.express as px
 
 
 
 
 
 
 
 
11
 
12
  class MCPClient:
13
  def __init__(self):
14
  self.session: Optional[ClientSession] = None
15
  self.exit_stack = AsyncExitStack()
16
-
17
- async def connect_to_server(self, server_script_path: str = "mcp_server.py"):
18
  """Connect to MCP server"""
19
  is_python = server_script_path.endswith('.py')
20
- is_js = server_script_path.endswith('.js')
21
-
22
- if not (is_python or is_js):
23
- raise ValueError("Server script must be a .py or .js file")
24
-
25
  command = "python" if is_python else "node"
26
  server_params = StdioServerParameters(
27
  command=command,
28
  args=[server_script_path],
29
  env=None
30
  )
31
-
32
  stdio_transport = await self.exit_stack.enter_async_context(
33
  stdio_client(server_params)
34
  )
@@ -36,87 +40,71 @@ class MCPClient:
36
  self.session = await self.exit_stack.enter_async_context(
37
  ClientSession(self.stdio, self.write)
38
  )
39
-
40
  await self.session.initialize()
41
-
42
- # List available tools
43
- response = await self.session.list_tools()
44
- tools = response.tools
45
- print("Connected to server with tools:", [tool.name for tool in tools])
46
-
47
  async def call_tool(self, tool_name: str, arguments: dict):
48
  """Call a tool on the MCP server"""
49
  if not self.session:
50
  raise RuntimeError("Not connected to server")
51
-
52
- response = await self.session.call_tool(tool_name, arguments)
53
- return response.content
54
-
55
  async def close(self):
56
  """Close the MCP client connection"""
57
  await self.exit_stack.aclose()
58
 
 
59
  # Global MCP client instance
60
  mcp_client = MCPClient()
61
 
62
- # Async wrapper functions for Gradio
 
63
  def run_async(coro):
64
- """Helper to run async functions in Gradio"""
65
  try:
66
- loop = asyncio.get_event_loop()
67
  except RuntimeError:
68
  loop = asyncio.new_event_loop()
69
  asyncio.set_event_loop(loop)
70
-
71
- return loop.run_until_complete(coro)
 
 
 
 
72
 
73
- # Auto-connect to MCP server on startup
74
  def initialize_mcp_connection():
75
- """Initialize MCP connection on startup"""
76
  try:
77
  run_async(mcp_client.connect_to_server())
78
- print("Successfully connected to MCP server on startup")
79
  return True
80
  except Exception as e:
81
- print(f"Failed to connect to MCP server on startup: {e}")
82
  return False
83
 
84
- # MCP client functions
 
85
  def get_models_from_db():
86
- """Get all models from database via MCP"""
87
  try:
88
  result = run_async(mcp_client.call_tool("get_all_models", {}))
89
  return result if isinstance(result, list) else []
90
  except Exception as e:
91
  print(f"Error getting models: {e}")
92
- # Fallback data for demonstration
93
- return [
94
- {"name": "llama-3.1-8b-instant", "created": "2025-01-15", "description": "Fast and efficient model for instant responses."},
95
- {"name": "llama3-8b-8192", "created": "2025-02-10", "description": "Extended context window model with 8192 tokens."},
96
- {"name": "gemini-2.5-pro-preview-06-05", "created": "2025-06-05", "description": "Professional preview version of Gemini 2.5."},
97
- {"name": "gemini-2.5-flash-preview-05-20", "created": "2025-05-20", "description": "Flash preview with optimized speed."},
98
- {"name": "gemini-1.5-pro", "created": "2024-12-01", "description": "Stable professional release of Gemini 1.5."}
99
- ]
100
 
101
  def get_available_model_names():
102
- """Get list of available model names for dropdown"""
103
- models = get_models_from_db()
104
- return [model["name"] for model in models]
105
 
106
  def search_models_in_db(search_term: str):
107
- """Search models in database via MCP"""
108
  try:
109
  result = run_async(mcp_client.call_tool("search_models", {"search_term": search_term}))
110
  return result if isinstance(result, list) else []
111
  except Exception as e:
112
  print(f"Error searching models: {e}")
113
- # Fallback search for demonstration
114
- all_models = get_models_from_db()
115
- if not search_term:
116
- return all_models
117
- term = search_term.lower()
118
- return [model for model in all_models if term in model["name"].lower() or term in model["description"].lower()]
119
-
120
  def format_dropdown_items(models):
121
  """Format dropdown items to show model name, creation date, and description preview"""
122
  formatted_items = []
 
1
+ import os
2
  import gradio as gr
3
  import asyncio
4
  from typing import Optional, List, Dict
5
  from contextlib import AsyncExitStack
6
  from mcp import ClientSession, StdioServerParameters
7
  from mcp.client.stdio import stdio_client
8
+ from database_module import init_db, get_all_models_handler, search_models_handler
9
  import json
10
  from datetime import datetime
11
  import plotly.graph_objects as go
12
+
13
+ # --- Initialize database and MCP tool registration ---
14
+ # Create tables and register MCP handlers
15
+ init_db()
16
+
17
+
18
+ # Ensure server.py imports and registers these tools:
19
+ # app.register_tool("get_all_models", get_all_models_handler)
20
+ # app.register_tool("search_models", search_models_handler)
21
 
22
class MCPClient:
    """Thin wrapper managing one MCP server connection over stdio."""

    def __init__(self):
        # Session stays None until connect_to_server() succeeds.
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self, server_script_path: str = "server.py"):
        """Spawn the server script and open an initialized MCP session.

        Raises:
            ValueError: if the script is neither a .py nor a .js file.
        """
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            # Restored guard: without it any other extension would
            # silently be handed to node.
            raise ValueError("Server script must be a .py or .js file")

        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )

        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )

        await self.session.initialize()
        tools = (await self.session.list_tools()).tools
        print("Connected to server with tools:", [t.name for t in tools])

    async def call_tool(self, tool_name: str, arguments: dict):
        """Call a tool on the MCP server"""
        if not self.session:
            raise RuntimeError("Not connected to server")
        return (await self.session.call_tool(tool_name, arguments)).content

    async def close(self):
        """Close the MCP client connection"""
        await self.exit_stack.aclose()
56
 
57
+
58
  # Global MCP client instance
59
  mcp_client = MCPClient()
60
 
61
+
62
+ # Helper to run async functions
63
def run_async(coro):
    """Drive *coro* to completion from synchronous code and return its result.

    Handles both situations Gradio callbacks can be in:
    - no event loop in this thread: create one and run the coroutine;
    - a loop is already running: run_until_complete() would raise
      "This event loop is already running", so execute the coroutine
      on a fresh loop in a worker thread and block for the result.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop here: make one and drive the coroutine directly.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)
    # Already inside a running loop — delegate to a worker thread.
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
74
+
75
 
76
+ # Initialize MCP connection on startup
77
def initialize_mcp_connection():
    """Connect the global MCP client at startup; True on success, False otherwise."""
    try:
        run_async(mcp_client.connect_to_server())
    except Exception as e:
        print(f"Failed to connect to MCP server: {e}")
        return False
    else:
        print("Successfully connected to MCP server")
        return True
85
 
86
+
87
+ # Wrapper functions remain unchanged but now call real DB-backed MCP tools
88
def get_models_from_db():
    """Fetch all model records via the MCP 'get_all_models' tool.

    Returns an empty list when the call fails or yields a non-list.
    """
    try:
        result = run_async(mcp_client.call_tool("get_all_models", {}))
    except Exception as e:
        print(f"Error getting models: {e}")
        return []
    return result if isinstance(result, list) else []
95
+
 
 
 
 
 
 
96
 
97
def get_available_model_names():
    """Return just the model names (dropdown choices)."""
    return [entry["name"] for entry in get_models_from_db()]
99
+
 
100
 
101
def search_models_in_db(search_term: str):
    """Search models via the MCP 'search_models' tool.

    On failure, falls back to filtering the full model list locally.
    The fallback now mirrors the server-side handler: case-insensitive
    match on name OR description, and tolerates a None/empty term
    (the previous fallback crashed on None and ignored descriptions).
    """
    try:
        result = run_async(mcp_client.call_tool("search_models", {"search_term": search_term}))
        return result if isinstance(result, list) else []
    except Exception as e:
        print(f"Error searching models: {e}")
        term = (search_term or "").lower()
        return [
            m for m in get_models_from_db()
            if term in m["name"].lower() or term in m.get("description", "").lower()
        ]
 
 
 
 
 
 
108
  def format_dropdown_items(models):
109
  """Format dropdown items to show model name, creation date, and description preview"""
110
  formatted_items = []
database_module/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .db import init_db
2
+ from .mcp_tools import get_all_models_handler, search_models_handler
database_module/db.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from sqlalchemy import create_engine
# declarative_base lives in sqlalchemy.orm since 1.4; the old
# sqlalchemy.ext.declarative import is deprecated.
from sqlalchemy.orm import declarative_base, sessionmaker

# Configure your database URL (env var or fallback to SQLite file)
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "sqlite:///./drift_detector.sqlite3"
)

# SQLite connections refuse cross-thread use by default; Gradio/MCP may
# touch the session from worker threads, so disable the check there only.
connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
engine = create_engine(
    DATABASE_URL,
    connect_args=connect_args,
    pool_pre_ping=True  # revalidate pooled connections before handing them out
)

SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine
)
Base = declarative_base()


def init_db():
    """
    Create tables if they don't exist.
    Call this once at application startup.
    """
    # Import the model classes so their tables are registered on
    # Base.metadata before create_all runs; without this, calling
    # init_db() standalone would create no tables at all.
    from . import models  # noqa: F401
    Base.metadata.create_all(bind=engine)
database_module/mcp_tools.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # database_module/mcp_tools.py
2
+ from sqlalchemy.orm import Session
3
+ from sqlalchemy import or_
4
+ from .db import SessionLocal
5
+ from .models import ModelEntry, DriftEntry
6
+ from datetime import datetime
7
+ from typing import Any, Dict, List
8
+
9
+
10
def get_all_models_handler(_: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    MCP tool: list every stored model.

    Returns a list of {name, created (ISO date string), description} dicts;
    the (unused) argument is the MCP params dict.
    """
    with SessionLocal() as session:
        return [
            {
                "name": entry.name,
                "created": entry.created.isoformat(),
                "description": entry.description or "",
            }
            for entry in session.query(ModelEntry).all()
        ]
21
+
22
+
23
def search_models_handler(params: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    MCP tool: search models by name or description substring
    (case-insensitive). An empty term returns everything.

    params: {search_term: str}
    """
    term = params.get("search_term", "").strip().lower()
    with SessionLocal() as session:
        query = session.query(ModelEntry)
        if term:
            pattern = f"%{term}%"
            query = query.filter(
                or_(
                    ModelEntry.name.ilike(pattern),
                    ModelEntry.description.ilike(pattern)
                )
            )
        return [
            {
                "name": row.name,
                "created": row.created.isoformat(),
                "description": row.description or "",
            }
            for row in query.all()
        ]
44
+
45
+
46
def get_model_details_handler(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    MCP tool: return one model's details.

    params: {model_name: str}
    Returns {name, system_prompt, description}. The system_prompt is a
    fixed placeholder — there is no prompt column in the models table yet.
    Unknown models get a stub dict (empty description) instead of an error.
    """
    model_name = params.get("model_name")
    with SessionLocal() as session:
        e = session.query(ModelEntry).filter_by(name=model_name).first()
        if not e:
            # Unknown model: return a default payload rather than raising.
            return {"name": model_name, "system_prompt": "You are a helpful AI assistant.", "description": ""}
        # You can store system_prompt as a column if desired; here placeholder
        return {"name": e.name, "system_prompt": "You are a helpful AI assistant.", "description": e.description or ""}
58
+
59
+
60
def save_model_handler(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    MCP tool: create the model row if it does not exist yet.

    params: {model_name: str, system_prompt: str}

    NOTE(review): system_prompt is accepted but NOT persisted — the
    models table has no prompt column; add one (or a side table) before
    relying on this to save prompts.
    """
    name = params.get("model_name")
    if not name:
        # Fail fast instead of hitting the NOT NULL constraint on insert.
        return {"message": "model_name is required."}
    prompt = params.get("system_prompt", "")  # currently unused, see NOTE above
    with SessionLocal() as session:
        entry = session.query(ModelEntry).filter_by(name=name).first()
        if not entry:
            # First time we see this model; stamp it with today's (UTC) date.
            entry = ModelEntry(
                name=name,
                created=datetime.utcnow().date(),
                description=""
            )
            session.add(entry)
        # Optionally store prompt in another table or JSON field
        session.commit()
        return {"message": f"Model '{name}' saved."}
80
+
81
+
82
def calculate_drift_handler(params: Dict[str, Any]) -> Dict[str, Any]:
    """
    MCP tool: record a drift measurement for a model.

    params: {model_name: str}

    Placeholder implementation: the "score" is a uniform random float in
    [0, 1] rounded to 3 decimals, stored with today's (UTC) date in
    drift_history. Swap in a real drift computation later.
    """
    import random
    name = params.get("model_name")
    score = round(random.uniform(0, 1), 3)
    today = datetime.utcnow().date()
    with SessionLocal() as session:
        entry = DriftEntry(
            model_name=name,
            date=today,
            drift_score=score
        )
        session.add(entry)
        session.commit()
    return {"drift_score": score, "message": f"Drift recorded for '{name}'."}
100
+
101
+
102
def get_drift_history_handler(params: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    MCP tool: chronological drift history for one model.

    params: {model_name: str}
    Returns [{date: ISO string, drift_score: float}, ...] ordered by date.
    """
    name = params.get("model_name")
    with SessionLocal() as session:
        history = (
            session.query(DriftEntry)
            .filter_by(model_name=name)
            .order_by(DriftEntry.date)
            .all()
        )
        return [
            {"date": item.date.isoformat(), "drift_score": item.drift_score}
            for item in history
        ]
database_module/models.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# database_module/models.py
from sqlalchemy import Column, String, Date, Integer, Float, Text
from .db import Base

class ModelEntry(Base):
    """A registered LLM shown in the app's model dropdown."""
    __tablename__ = "models"

    id = Column(Integer, primary_key=True, index=True)
    # Unique human-readable identifier, e.g. "llama-3.1-8b-instant".
    name = Column(String, unique=True, nullable=False, index=True)
    # Creation date of the entry (ISO date in the handlers' output).
    created = Column(Date, nullable=False)
    description = Column(Text, nullable=True)

class DriftEntry(Base):
    """One drift measurement for a model on a given date."""
    __tablename__ = "drift_history"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): stores the model's name, not a ForeignKey to models.id —
    # consider an FK if model renames ever become possible.
    model_name = Column(String, nullable=False, index=True)
    date = Column(Date, nullable=False)
    drift_score = Column(Float, nullable=False)
drift_detector.sqlite3 ADDED
Binary file (28.7 kB). View file
 
ourllm.py CHANGED
@@ -1,7 +1,47 @@
 
 
 
1
 
2
- def genratequestionnaire(model, capabilities):
3
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
- def gradeanswers(old_answers, new_answers):
7
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import difflib
2
+ from typing import List
3
+ import mcp.types as types
4
 
5
def genratequestionnaire(model: str, capabilities: str) -> List[types.SamplingMessage]:
    """
    Build the baseline diagnostic questionnaire for *model*.

    Returns one user-role SamplingMessage per diagnostic question.
    (Function name spelling kept as-is: it is the public interface.)
    """
    questions = [
        f"Model Name: {model}\nPlease confirm your model name.",
        f"Capabilities Overview:\n{capabilities}\nPlease summarize your key capabilities.",
        "Describe a typical use-case scenario that demonstrates these capabilities.",
    ]
    messages = []
    for question in questions:
        messages.append(
            types.SamplingMessage(
                role="user",
                content=types.TextContent(type="text", text=question),
            )
        )
    return messages
22
 
23
 
24
def gradeanswers(old_answers: List[str], new_answers: List[str]) -> List[types.SamplingMessage]:
    """
    Compute a drift percentage between two answer sets.

    A pair is "similar" when its difflib SequenceMatcher ratio is >= 0.8;
    drift is the percentage of pairs that are NOT similar, rounded to 2
    decimals. Unlike the zip-based version, unequal-length lists are no
    longer silently truncated: missing answers on either side are treated
    as empty strings and therefore count as drifted. Equal-length inputs
    behave exactly as before.

    Returns a single assistant-role SamplingMessage whose text is the
    drift percentage (e.g. "33.33").
    """
    from itertools import zip_longest

    total = max(len(old_answers), len(new_answers))
    if total == 0:
        drift_pct = 0.0
    else:
        # Count how many answers are sufficiently similar
        similar_count = 0
        for old, new in zip_longest(old_answers, new_answers, fillvalue=""):
            ratio = difflib.SequenceMatcher(None, old, new).ratio()
            if ratio >= 0.8:
                similar_count += 1
        drift_pct = round((1 - (similar_count / total)) * 100, 2)

    drift_text = f"{drift_pct}"
    return [
        types.SamplingMessage(
            role="assistant",
            content=types.TextContent(type="text", text=drift_text)
        )
    ]
requirements.txt CHANGED
@@ -2,4 +2,6 @@ gradio
2
  mcp
3
  plotly
4
  asyncio
5
- typing
 
 
 
2
  mcp
3
  plotly
4
  asyncio
5
+ typing
6
+ sqlalchemy
7
+ psycopg2