import asyncio
import json
import os
from typing import Any, Dict, List

import mcp.types as types
from mcp import CreateMessageResult
from mcp.server import Server
from mcp.server.stdio import stdio_server

from ourllm import genratequestionnaire, gradeanswers
from database_module import init_db, get_all_models_handler, search_models_handler
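# Assumed contract for the ourllm helpers (they live outside this file): both
# are async; genratequestionnaire(model, capabilities) is expected to return a
# list of types.SamplingMessage questions, and gradeanswers(old, new) a
# sequence whose first element carries the numeric drift score as text
# (grading[0].content.text).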
# Initialize data directory and database
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)
init_db()

app = Server("mcp-drift-server")
# === Sampling Helper ===
async def sample(messages: List[types.SamplingMessage], max_tokens: int = 600) -> CreateMessageResult:
    """Proxy a request to the connected client's LLM via MCP sampling."""
    return await app.request_context.session.create_message(
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.7,
    )
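# Illustrative round-trip through the client's LLM (the question text is made up):
#   reply = await sample([types.SamplingMessage(
#       role="user",
#       content=types.TextContent(type="text", text="What is 17 * 24?"))])
#   print(reply.content.text)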
# === Baseline File Helpers ===
def get_baseline_path(model_name: str) -> str:
    return os.path.join(DATA_DIR, f"{model_name}_baseline.json")

def get_response_path(model_name: str) -> str:
    return os.path.join(DATA_DIR, f"{model_name}_latest.json")
# === Tool Manifest ===
@app.list_tools()
async def list_tools() -> List[types.Tool]:
    return [
        types.Tool(
            name="run_initial_diagnostics",
            description="Generate and store baseline diagnostics for a connected LLM.",
            inputSchema={
                "type": "object",
                "properties": {
                    "model": {"type": "string", "description": "The name of the model to run diagnostics on"},
                    "model_capabilities": {"type": "string", "description": "Full description of the model's capabilities"}
                },
                "required": ["model", "model_capabilities"]
            },
        ),
        types.Tool(
            name="check_drift",
            description="Re-run diagnostics and compare to baseline for drift scoring.",
            inputSchema={
                "type": "object",
                "properties": {"model": {"type": "string", "description": "The name of the model to run diagnostics on"}},
                "required": ["model"]
            },
        ),
        types.Tool(
            name="get_all_models",
            description="Retrieve all registered models from the database.",
            inputSchema={"type": "object", "properties": {}, "required": []}
        ),
        types.Tool(
            name="search_models",
            description="Search registered models by name.",
            inputSchema={
                "type": "object",
                "properties": {"query": {"type": "string", "description": "Substring to match model names against"}},
                "required": ["query"]
            }
        ),
    ]
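# Example tool-call arguments for run_initial_diagnostics (values illustrative):
# {"model": "demo-model",
#  "model_capabilities": "General reasoning, code generation, long-context QA"}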
# === Core Logic ===
async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
    model = arguments["model"]
    caps = arguments["model_capabilities"]
    # 1. Ask the server's trusted LLM to generate a questionnaire
    questions = await genratequestionnaire(model, caps)
    # 2. Ask the target LLM (the connected client) each question via sampling
    answers = []
    for q in questions:
        a = await sample([q])
        answers.append(a)
    # 3. Persist the question/answer baseline
    with open(get_baseline_path(model), "w") as f:
        json.dump({
            "questions": [m.content.text for m in questions],
            "answers": [m.content.text for m in answers]
        }, f, indent=2)
    return [types.TextContent(type="text", text=f"✅ Baseline stored for model: {model}")]
async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
    model = arguments["model"]
    base_path = get_baseline_path(model)

    # Ensure a baseline exists
    if not os.path.exists(base_path):
        return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]

    # Load questions + old answers
    with open(base_path) as f:
        data = json.load(f)
    questions = [
        types.SamplingMessage(role="user", content=types.TextContent(type="text", text=q))
        for q in data["questions"]
    ]
    old_answers = data["answers"]

    # 1. Ask the model the same questions again
    new_answers_msgs = []
    for q in questions:
        a = await sample([q])
        new_answers_msgs.append(a)
    new_answers = [m.content.text for m in new_answers_msgs]

    # 2. Grade the new answers against the baseline for drift
    grading = await gradeanswers(old_answers, new_answers)
    drift_score = grading[0].content.text.strip()

    # 3. Save the latest response
    with open(get_response_path(model), "w") as f:
        json.dump({
            "new_answers": new_answers,
            "drift_score": drift_score
        }, f, indent=2)

    # 4. Alert threshold (the drift score is expected on a 0-100 scale)
    try:
        score_val = float(drift_score)
        alert = "🚨 Significant drift!" if score_val > 50 else "✅ Drift OK"
    except ValueError:
        alert = "⚠️ Drift score not numeric"

    return [
        types.TextContent(type="text", text=f"Drift score for {model}: {drift_score}"),
        types.TextContent(type="text", text=alert)
    ]
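# The resulting data/<model>_latest.json looks like this (values illustrative):
# {
#   "new_answers": ["408", "..."],
#   "drift_score": "12"
# }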
# === Dispatcher ===
@app.call_tool()
async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
    if name == "run_initial_diagnostics":
        return await run_initial_diagnostics(arguments or {})
    if name == "check_drift":
        return await check_drift(arguments or {})
    if name == "get_all_models":
        return await get_all_models_handler()
    if name == "search_models":
        return await search_models_handler(arguments or {})
    raise ValueError(f"Unknown tool: {name}")
# === Entrypoint ===
async def main():
    async with stdio_server() as (reader, writer):
        await app.run(reader, writer, app.create_initialization_options())

if __name__ == "__main__":
    asyncio.run(main())
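# --- Illustrative client usage (kept commented out; a sketch assuming the
# --- standard mcp client API, with "server.py" standing in for this file's
# --- name; note a real client must also pass a sampling_callback to
# --- ClientSession so create_message can actually reach an LLM):
#
# from mcp import ClientSession, StdioServerParameters
# from mcp.client.stdio import stdio_client
#
# async def demo():
#     params = StdioServerParameters(command="python", args=["server.py"])
#     async with stdio_client(params) as (read, write):
#         async with ClientSession(read, write) as session:
#             await session.initialize()
#             await session.call_tool("run_initial_diagnostics", {
#                 "model": "demo-model",
#                 "model_capabilities": "General reasoning and code generation",
#             })
#             result = await session.call_tool("check_drift", {"model": "demo-model"})
#             print(result)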