Sars6 committed on
Commit a467728 · 1 Parent(s): ec83815

DB integrated + Working on the README.md

.idea/dataSources.xml ADDED
@@ -0,0 +1,20 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="DataSourceManagerImpl" format="xml" multifile-model="true">
+     <data-source source="LOCAL" name="drift_detector" uuid="88b3f45d-5575-4285-8690-21bff284a779">
+       <driver-ref>sqlite.xerial</driver-ref>
+       <synchronize>true</synchronize>
+       <jdbc-driver>org.sqlite.JDBC</jdbc-driver>
+       <jdbc-url>jdbc:sqlite:$PROJECT_DIR$/drift_detector.sqlite3</jdbc-url>
+       <working-dir>$ProjectFileDir$</working-dir>
+       <libraries>
+         <library>
+           <url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/xerial/sqlite-jdbc/3.45.1.0/sqlite-jdbc-3.45.1.0.jar</url>
+         </library>
+         <library>
+           <url>file://$APPLICATION_CONFIG_DIR$/jdbc-drivers/Xerial SQLiteJDBC/3.45.1/org/slf4j/slf4j-api/1.7.36/slf4j-api-1.7.36.jar</url>
+         </library>
+       </libraries>
+     </data-source>
+   </component>
+ </project>
README.md CHANGED
@@ -11,8 +11,14 @@ license: mit
  ---
 
  # Drift Detector
- Drift Detector is an MCP server, designed to detect drift in LLM performance over time.
- This implementation is intended as a proof of concept and is not intended for production use.
+ Drift Detector is an MCP server designed to detect drift in LLM performance over time by using the power of the **sampling** functionality of MCP.
+ This implementation is intended as a **proof of concept** and is **NOT intended** for production use without significant changes.
+
+ ## The Idea
+
+ The drift detector is a server that can be connected to any LLM client that supports the MCP sampling functionality.
+ It allows you to monitor the performance of your LLM models over time, detecting any drift in their behavior.
+ This is particularly useful for applications where the model's performance may change due to various factors, such as changes in the data distribution, model updates, or other external influences.
 
  ## How to run
 
@@ -64,5 +70,48 @@ The gradio interface in [app.py](app.py) is an example dashboard which allows us
 
  ### Drift Detector Server
 
- The Drift Detector server is implemented using the MCP python SDK
-
+ The Drift Detector server is implemented using the MCP Python SDK.
+ It exposes the following tools:
+
+ 1. **run_initial_diagnostics**
+    - **Purpose**: Establishes a baseline for model behavior using adaptive sampling techniques
+    - **Parameters**:
+      - `model`: The name of the model to run diagnostics on
+      - `model_capabilities`: Full description of the model's capabilities and special features
+    - **Sampling Process**:
+      - First generates a tailored questionnaire based on model-specific capabilities
+      - Collects responses by sampling the target model with controlled parameters (temperature=0.7)
+      - Each question is processed individually to ensure proper context isolation
+      - Baseline samples are stored as paired question-answer JSON records for future comparison
+    - **Output**: Confirmation message indicating successful baseline creation
+
+ 2. **check_drift**
+    - **Purpose**: Measures potential drift by comparative sampling against the baseline
+    - **Parameters**:
+      - `model`: The name of the model to check for drift
+    - **Sampling Process**:
+      - Retrieves the original questions from the baseline
+      - Re-samples the model with identical questions using the same sampling parameters
+      - Maintains consistent context conditions to ensure a fair comparison
+      - Uses differential analysis to compare semantic and functional differences between sample sets
+    - **Drift Evaluation**:
+      - Calculates a numerical drift score based on answer divergence
+      - Provides threshold-based alerts when drift exceeds acceptable limits (score > 50)
+      - Stores the latest sample responses for audit and trend analysis
+
+ ## Flow
+
+ The intended flow is as follows:
+ 1. When the client contacts the server for the first time, it will run the `run_initial_diagnostics` tool.
+ 2. The server will generate a tailored questionnaire based on the model's capabilities.
+ 3. This questionnaire will be used to collect responses from the model, establishing a baseline for future comparisons.
+ 4. Once the baseline is established, the server will store the paired question-answer JSON records.
+ 5. The client can then use the `check_drift` tool to measure potential drift in the model's performance.
+ 6. The server will retrieve the original questions from the baseline and re-sample the model with identical questions.
+ 7. The server will maintain consistent context conditions to ensure a fair comparison.
+ 8. If significant drift is detected (score > 50), the server will provide an alert and store the latest sample responses for audit and trend analysis.
+ 9. The client can visualize the drift data through the Gradio interface, allowing users to track changes in model performance over time.
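The README stops short of showing client code, so here is a minimal sketch (not part of this commit) of how the two tools could be driven over stdio with the MCP Python SDK. The model name, capabilities string, and the canned sampling reply are placeholders; a real client would forward `params.messages` to its LLM inside the sampling callback.

```python
import asyncio
import mcp.types as types
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.shared.context import RequestContext


async def handle_sampling(
    context: RequestContext, params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
    # Placeholder: a real client forwards params.messages to its LLM here.
    return types.CreateMessageResult(
        role="assistant",
        content=types.TextContent(type="text", text="(model answer)"),
        model="my-model",
        stopReason="endTurn",
    )


async def main():
    server = StdioServerParameters(command="python", args=["server.py"])
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write, sampling_callback=handle_sampling) as session:
            await session.initialize()
            # Establish the baseline once per model...
            await session.call_tool(
                "run_initial_diagnostics",
                {"model": "my-model", "model_capabilities": "general chat and code"},
            )
            # ...then check for drift on later runs.
            result = await session.call_tool("check_drift", {"model": "my-model"})
            print(result.content)


asyncio.run(main())
```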
database_module/__init__.py CHANGED
@@ -1,3 +1,10 @@
  #database_module/__init__.py
  from .db import init_db
- from .mcp_tools import get_all_models_handler, search_models_handler
+ from .mcp_tools import (
+     get_all_models_handler,
+     search_models_handler,
+     save_diagnostic_data,
+     get_baseline_diagnostics,
+     save_drift_score,
+     register_model_with_capabilities
+ )
database_module/db.py CHANGED
@@ -1,6 +1,6 @@
  #database_module/db.py
  import os
- from sqlalchemy import create_engine
+ from sqlalchemy import create_engine, inspect, text
  from sqlalchemy.ext.declarative import declarative_base
  from sqlalchemy.orm import sessionmaker
 
@@ -26,9 +26,25 @@ SessionLocal = sessionmaker(
  )
  Base = declarative_base()
 
+ def apply_migrations():
+     """
+     Apply any necessary migrations to existing tables.
+     """
+     with engine.connect() as conn:
+         # Check if the models table exists and has the capabilities column
+         inspector = inspect(engine)
+         if "models" in inspector.get_table_names():
+             columns = [col['name'] for col in inspector.get_columns('models')]
+             if "capabilities" not in columns:
+                 # Add capabilities column to models table
+                 conn.execute(text("ALTER TABLE models ADD COLUMN capabilities TEXT"))
+                 conn.commit()
+                 print("Migration: Added capabilities column to models table")
+
  def init_db():
      """
      Create tables if they don't exist.
      Call this once at application startup.
      """
      Base.metadata.create_all(bind=engine)
+     apply_migrations()
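The `ALTER TABLE` check above is a minimal hand-rolled migration. A sketch of how the same inspector-based pattern could be extended to future columns; the `PENDING` registry and the `apply_pending` helper are hypothetical, not part of this commit:

```python
from sqlalchemy import inspect, text

# Hypothetical registry of ad-hoc migrations: (table, column, DDL to run
# when that column is missing). Only the first entry exists in this commit.
PENDING = [
    ("models", "capabilities", "ALTER TABLE models ADD COLUMN capabilities TEXT"),
]

def apply_pending(engine):
    inspector = inspect(engine)
    with engine.connect() as conn:
        for table, column, ddl in PENDING:
            if table in inspector.get_table_names():
                existing = [c["name"] for c in inspector.get_columns(table)]
                if column not in existing:
                    conn.execute(text(ddl))
                    conn.commit()
```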
database_module/mcp_tools.py CHANGED
@@ -2,9 +2,10 @@
  from sqlalchemy.orm import Session
  from sqlalchemy import or_
  from .db import SessionLocal
- from .models import ModelEntry, DriftEntry
+ from .models import ModelEntry, DriftEntry, DiagnosticData
  from datetime import datetime
- from typing import Any, Dict, List
+ from typing import Any, Dict, List, Optional
+ import json
 
 
  def get_all_models_handler(_: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -111,3 +112,98 @@ def get_drift_history_handler(params: Dict[str, Any]) -> List[Dict[str, Any]]:
      {"date": e.date.isoformat(), "drift_score": e.drift_score}
      for e in entries
  ]
+
+
+ # === New functions for drift detection database operations ===
+
+ def save_diagnostic_data(
+     model_name: str,
+     questions: list,
+     answers: list,
+     is_baseline: bool = False
+ ) -> None:
+     """
+     Save diagnostic questions and answers to the database
+     """
+     with SessionLocal() as session:
+         # Check if model exists, create if not
+         model = session.query(ModelEntry).filter_by(name=model_name).first()
+         if not model:
+             model = ModelEntry(
+                 name=model_name,
+                 created=datetime.utcnow().date(),
+                 description=""
+             )
+             session.add(model)
+
+         # Create new diagnostic entry
+         diagnostic = DiagnosticData(
+             model_name=model_name,
+             is_baseline=1 if is_baseline else 0,
+             questions=questions,
+             answers=answers,
+             created=datetime.utcnow()
+         )
+         session.add(diagnostic)
+         session.commit()
+
+
+ def get_baseline_diagnostics(model_name: str) -> Optional[Dict[str, Any]]:
+     """
+     Retrieve baseline diagnostics for a model
+     """
+     with SessionLocal() as session:
+         baseline = session.query(DiagnosticData)\
+             .filter_by(model_name=model_name, is_baseline=1)\
+             .order_by(DiagnosticData.created.desc())\
+             .first()
+
+         if not baseline:
+             return None
+
+         return {
+             "questions": baseline.questions,
+             "answers": baseline.answers,
+             "created": baseline.created.isoformat()
+         }
+
+
+ def save_drift_score(model_name: str, drift_score: str) -> None:
+     """
+     Save drift score to database
+     """
+     # Try to convert score to float if possible
+     try:
+         score_float = float(drift_score)
+     except ValueError:
+         score_float = None
+
+     with SessionLocal() as session:
+         entry = DriftEntry(
+             model_name=model_name,
+             date=datetime.utcnow(),
+             drift_score=score_float
+         )
+         session.add(entry)
+         session.commit()
+
+
+ def register_model_with_capabilities(model_name: str, capabilities: str) -> None:
+     """
+     Register a model with capabilities or update if already exists
+     """
+     with SessionLocal() as session:
+         model = session.query(ModelEntry).filter_by(name=model_name).first()
+
+         if model:
+             model.capabilities = capabilities
+         else:
+             model = ModelEntry(
+                 name=model_name,
+                 created=datetime.utcnow().date(),
+                 capabilities=capabilities,
+                 description=""
+             )
+             session.add(model)
+
+         session.commit()
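Taken together, the new helpers form a small persistence API. A quick round-trip sketch using only names from this commit; the model name, questions, and scores are made up for illustration:

```python
from database_module import (
    init_db,
    register_model_with_capabilities,
    save_diagnostic_data,
    get_baseline_diagnostics,
    save_drift_score,
)

init_db()  # creates tables and applies the capabilities migration

register_model_with_capabilities("my-model", "general chat and code")
save_diagnostic_data(
    model_name="my-model",
    questions=["What is 2 + 2?"],
    answers=["4"],
    is_baseline=True,
)

baseline = get_baseline_diagnostics("my-model")
assert baseline is not None and baseline["questions"] == ["What is 2 + 2?"]

save_drift_score("my-model", "12.5")  # stored as the float 12.5
save_drift_score("my-model", "N/A")   # non-numeric, stored as NULL
```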
database_module/models.py CHANGED
@@ -1,5 +1,6 @@
  # database_module/models.py
- from sqlalchemy import Column, String, Date, Integer, Float, Text
+ from sqlalchemy import Column, String, Date, Integer, Float, Text, JSON, DateTime
+ from datetime import datetime
  from .db import Base
 
  class ModelEntry(Base):
@@ -9,11 +10,22 @@ class ModelEntry(Base):
      name = Column(String, unique=True, nullable=False, index=True)
      created = Column(Date, nullable=False)
      description = Column(Text, nullable=True)
+     capabilities = Column(Text, nullable=True)  # Added to store model_capabilities
 
  class DriftEntry(Base):
      __tablename__ = "drift_history"
 
      id = Column(Integer, primary_key=True, index=True)
      model_name = Column(String, nullable=False, index=True)
-     date = Column(Date, nullable=False)
-     drift_score = Column(Float, nullable=False)
+     date = Column(DateTime, nullable=False, default=datetime.utcnow)
+     drift_score = Column(Float, nullable=True)
+
+ class DiagnosticData(Base):
+     __tablename__ = "diagnostic_data"
+
+     id = Column(Integer, primary_key=True, index=True)
+     model_name = Column(String, nullable=False, index=True)
+     created = Column(DateTime, nullable=False, default=datetime.utcnow)
+     is_baseline = Column(Integer, nullable=False, default=0)  # 0=latest, 1=baseline
+     questions = Column(JSON, nullable=True)
+     answers = Column(JSON, nullable=True)
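With the new `DiagnosticData` table, the latest non-baseline run can be pulled back the same way `get_baseline_diagnostics` fetches the baseline. A sketch; the model name is a placeholder:

```python
from database_module.db import SessionLocal
from database_module.models import DiagnosticData

with SessionLocal() as session:
    latest = (
        session.query(DiagnosticData)
        .filter_by(model_name="my-model", is_baseline=0)
        .order_by(DiagnosticData.created.desc())
        .first()
    )
    if latest:
        print(latest.created, latest.answers)  # JSON columns round-trip as Python lists
```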
drift_detector.sqlite3 CHANGED
Binary files a/drift_detector.sqlite3 and b/drift_detector.sqlite3 differ
 
requirements.txt CHANGED
@@ -4,5 +4,5 @@ plotly
  asyncio
  typing
  sqlalchemy
- psycopg2
+ psycopg2-binary
  fast-agent-mcp
server.py CHANGED
@@ -10,7 +10,14 @@ from mcp.server.stdio import stdio_server
 
  from ourllm import genratequestionnaire, gradeanswers
  from database_module import init_db
- from database_module import get_all_models_handler, search_models_handler
+ from database_module import (
+     get_all_models_handler,
+     search_models_handler,
+     save_diagnostic_data,
+     get_baseline_diagnostics,
+     save_drift_score,
+     register_model_with_capabilities
+ )
 
  # Initialize data directory and database
  DATA_DIR = "data"
@@ -19,25 +26,6 @@ init_db()
 
  app = Server("mcp-drift-server")
 
-
- # === Sampling Helper ===
- async def sample(messages: List[types.SamplingMessage], max_tokens: int = 300) -> CreateMessageResult:
-     return await app.request_context.session.create_message(
-         messages=messages,
-         max_tokens=max_tokens,
-         temperature=0.7,
-     )
-
-
- # === Baseline File Helpers ===
- def get_baseline_path(model_name: str) -> str:
-     return os.path.join(DATA_DIR, f"{model_name}_baseline.json")
-
-
- def get_response_path(model_name: str) -> str:
-     return os.path.join(DATA_DIR, f"{model_name}_latest.json")
-
-
  # === Tool Manifest ===
  @app.list_tools()
  async def list_tools() -> List[types.Tool]:
@@ -79,8 +67,6 @@ async def list_tools() -> List[types.Tool]:
      ),
  ]
 
-
-
  # === Sampling Wrapper ===
  async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> CreateMessageResult:
      return await app.request_context.session.create_message(
@@ -89,101 +75,70 @@ async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> Creat
          temperature=0.7
      )
 
-
- # === Baseline File Paths ===
- def get_baseline_path(model_name):
-     return os.path.join(DATA_DIR, f"{model_name}_baseline.json")
-
-
- def get_response_path(model_name):
-     return os.path.join(DATA_DIR, f"{model_name}_latest.json")
-
-
-
  # === Core Logic ===
  async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
      model = arguments["model"]
-     caps = arguments["model_capabilities"]
-
-     # 1. Generate questionnaire
-     questions = await genratequestionnaire(model, caps)
-
-     # 2. Ask the target LLM (client)
-     answers = await sample(questions)
-
-
-     # 3. Persist baseline
+     capabilities = arguments["model_capabilities"]
 
      # 1. Ask the server's internal LLM to generate a questionnaire
-
-     questions = genratequestionnaire(model, arguments["model_capabilities"])  # Server-side trusted LLM
+     questions = genratequestionnaire(model, capabilities)  # Server-side trusted LLM
      answers = []
      for q in questions:
          a = await sample([q])
          answers.append(a)
 
-     # 3. Save Q/A pair
-
-     with open(get_baseline_path(model), "w") as f:
-         json.dump({
-             "questions": [m.content.text for m in questions],
-             "answers": [m.content.text for m in answers]
-         }, f, indent=2)
+     # 2. Save the model capabilities and questions/answers to database
+     register_model_with_capabilities(model, capabilities)
+     save_diagnostic_data(
+         model_name=model,
+         questions=[m.content.text for m in questions],
+         answers=[m.content.text for m in answers],
+         is_baseline=True
+     )
 
      return [types.TextContent(type="text", text=f"✅ Baseline stored for model: {model}")]
 
-
  async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
-     model = arguments["model"]
-     base_path = get_baseline_path(model)
+     model = arguments["model"]
+
+     # Get baseline from database
+     baseline = get_baseline_diagnostics(model)
 
      # Ensure baseline exists
-     if not os.path.exists(base_path):
+     if not baseline:
          return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]
 
-     # Load questions + old answers
-     with open(base_path) as f:
-         data = json.load(f)
-     questions = [
+     # Convert questions to sampling messages
+     questions = [
          types.SamplingMessage(role="user", content=types.TextContent(type="text", text=q))
-         for q in data["questions"]
+         for q in baseline["questions"]
      ]
-     old_answers = data["answers"]
+     old_answers = baseline["answers"]
 
-
-     # 1. Get fresh answers
-     new_msgs = await sample(questions)
-     new_answers = [m.content.text for m in new_msgs]
-
-     # 1. Ask the model again
+     # Ask the model again
      new_answers_msgs = []
      for q in questions:
          a = await sample([q])
         new_answers_msgs.append(a)
      new_answers = [m.content.text for m in new_answers_msgs]
 
-
-     # 2. Grade for drift
-     grading = await gradeanswers(old_answers, new_answers)
-     drift_score = grading[0].content.text.strip()
-
-
-     # 3. Save latest
+     # Grade the answers and get a drift score
      grading_response = gradeanswers(old_answers, new_answers)
      drift_score = grading_response[0].content.text.strip()
 
-     # 3. Save the response
-
-     with open(get_response_path(model), "w") as f:
-         json.dump({
-             "new_answers": new_answers,
-             "drift_score": drift_score
-         }, f, indent=2)
+     # Save the latest responses and drift score to database
+     save_diagnostic_data(
+         model_name=model,
+         questions=baseline["questions"],
+         answers=new_answers,
+         is_baseline=False
+     )
+     save_drift_score(model, drift_score)
 
-     # 4. Alert threshold
+     # Alert threshold
      try:
          score_val = float(drift_score)
-         alert = "🚨 Significant drift!" if score_val > 50 else "✅ Drift OK"
+         alert = "🚨 Significant drift!" if score_val > 50 else "✅ Drift OK"
      except ValueError:
          alert = "⚠️ Drift score not numeric"
 
@@ -192,26 +147,52 @@ async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
          types.TextContent(type="text", text=alert)
      ]
 
+ # Database tool handlers
+ async def get_all_models_handler_async(_: Dict[str, Any]) -> List[types.TextContent]:
+     models = get_all_models_handler({})
+     if not models:
+         return [types.TextContent(type="text", text="No models registered.")]
+
+     model_list = "\n".join([f"• {m['name']} - {m['description']}" for m in models])
+     return [types.TextContent(
+         type="text",
+         text=f"Registered models:\n{model_list}"
+     )]
+
+ async def search_models_handler_async(arguments: Dict[str, Any]) -> List[types.TextContent]:
+     query = arguments.get("query", "")
+     models = search_models_handler({"search_term": query})
+
+     if not models:
+         return [types.TextContent(
+             type="text",
+             text=f"No models found matching '{query}'."
+         )]
+
+     model_list = "\n".join([f"• {m['name']} - {m['description']}" for m in models])
+     return [types.TextContent(
+         type="text",
+         text=f"Models matching '{query}':\n{model_list}"
+     )]
 
  # === Dispatcher ===
  @app.call_tool()
  async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
      if name == "run_initial_diagnostics":
-         return await run_initial_diagnostics(arguments or {})
-     if name == "check_drift":
-         return await check_drift(arguments or {})
-     if name == "get_all_models":
-         return await get_all_models_handler()
-     if name == "search_models":
-         return await search_models_handler(arguments or {})
-     raise ValueError(f"Unknown tool: {name}")
-
+         return await run_initial_diagnostics(arguments)
+     elif name == "check_drift":
+         return await check_drift(arguments)
+     elif name == "get_all_models":
+         return await get_all_models_handler_async(arguments or {})
+     elif name == "search_models":
+         return await search_models_handler_async(arguments or {})
+     else:
+         raise ValueError(f"Unknown tool: {name}")
 
  # === Entrypoint ===
  async def main():
      async with stdio_server() as (reader, writer):
          await app.run(reader, writer, app.create_initialization_options())
 
-
  if __name__ == "__main__":
      asyncio.run(main())
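`check_drift` relies on exactly one property of `ourllm.gradeanswers`: element 0 of its return value carries a numeric score as text. A hypothetical stand-in that honors that contract; the change-ratio heuristic is invented for illustration and is not the project's grading logic:

```python
import mcp.types as types


def gradeanswers(old_answers, new_answers):
    # Toy heuristic: percentage of paired answers that differ verbatim.
    changed = sum(o != n for o, n in zip(old_answers, new_answers))
    score = 100.0 * changed / max(len(old_answers), 1)
    # check_drift reads result[0].content.text and converts it with float().
    return [
        types.SamplingMessage(
            role="assistant",
            content=types.TextContent(type="text", text=f"{score:.1f}"),
        )
    ]
```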