import asyncio
import json
import os
from typing import Any

import mcp.types as types
from mcp import CreateMessageResult
from mcp.server import Server
from mcp.server.stdio import stdio_server

from ourllm import genratequestionnaire, gradeanswers

DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

app = Server("mcp-drift-server")
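
# === Model Registry ===
# Simple in-memory registry of models this server has seen; save_model() also
# persists it to data/models.json so the registry survives restarts.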
registered_models = {}

def get_all_models():
    """Retrieve all registered models."""
    return list(registered_models.keys())

def search_models(query: str):
    """Search registered models by name."""
    return [model for model in registered_models if query.lower() in model.lower()]

def get_model_details(model_name: str):
    """Get details of a specific model."""
    return registered_models.get(model_name)

def save_model(model_name: str, model_details: dict):
    """Save a new model or update an existing one."""
    registered_models[model_name] = model_details
    with open(os.path.join(DATA_DIR, "models.json"), "w") as f:
        json.dump(registered_models, f, indent=2)
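
# === Tool Definitions ===
# Two tools are exposed to clients: one captures a baseline questionnaire for a
# model, the other re-runs it and scores drift against that baseline.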
@app.list_tools()
async def list_tools() -> list[types.Tool]:
    return [
        types.Tool(
            name="run_initial_diagnostics",
            description="Generate and store baseline diagnostics for a connected LLM.",
            inputSchema={
                "type": "object",
                "properties": {
                    "model": {
                        "type": "string",
                        "description": "The name of the model to run diagnostics on",
                    },
                    "model_capabilities": {
                        "type": "string",
                        "description": "Full description of the model's capabilities, including any special features",
                    },
                },
                "required": ["model", "model_capabilities"],
            },
        ),
        types.Tool(
            name="check_drift",
            description="Re-run diagnostics and compare to baseline for drift scoring.",
            inputSchema={
                "type": "object",
                "properties": {
                    "model": {
                        "type": "string",
                        "description": "The name of the model to run diagnostics on",
                    },
                },
                "required": ["model"],
            },
        ),
    ]

# === Sampling Wrapper ===
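# MCP sampling runs in the reverse direction: the server asks the *client* to
# produce a completion from whatever model it is connected to, which is how
# this server interrogates the client-side LLM.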
async def sample(messages: list[types.SamplingMessage], max_tokens=300) -> CreateMessageResult:
    return await app.request_context.session.create_message(
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.7,
    )

# === Baseline File Paths ===
def get_baseline_path(model_name):
    return os.path.join(DATA_DIR, f"{model_name}_baseline.json")

def get_response_path(model_name):
    return os.path.join(DATA_DIR, f"{model_name}_latest.json")

# === Core Logic ===
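# run_initial_diagnostics captures a baseline Q/A transcript for a model;
# check_drift replays the same questions later and grades how far the new
# answers have drifted from that baseline.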
async def run_initial_diagnostics(arguments: dict[str, Any]) -> list[types.TextContent]:
    if not arguments or "model" not in arguments or "model_capabilities" not in arguments:
        raise ValueError("Both 'model' and 'model_capabilities' are required")
    model = arguments["model"]

    # 1. Ask the server's internal LLM to generate a questionnaire
    questions = await genratequestionnaire(model, arguments["model_capabilities"])  # Server-side trusted LLM

    # 2. Send the questionnaire to the target LLM (i.e., the client), one question
    #    per request: each create_message call returns a single CreateMessageResult.
    answers = [(await sample([q])).content.text for q in questions]

    # 3. Save the Q/A pairs as the baseline
    with open(get_baseline_path(model), "w") as f:
        json.dump({
            "questions": [m.content.text for m in questions],
            "answers": answers,
        }, f, indent=2)
    return [types.TextContent(type="text", text="Baseline stored for model: " + model)]

async def check_drift(arguments: dict[str, str]) -> list[types.TextContent]:
    if not arguments or "model" not in arguments:
        raise ValueError("The 'model' argument is required")
    model = arguments["model"]

    baseline_path = get_baseline_path(model)
    if not os.path.exists(baseline_path):
        return [types.TextContent(type="text", text="No baseline exists for model: " + model)]
    with open(baseline_path) as f:
        data = json.load(f)
    questions = [
        types.SamplingMessage(role="user", content=types.TextContent(type="text", text=q))
        for q in data["questions"]
    ]
    old_answers = data["answers"]

    # 1. Ask the model the same questions again
    new_answers = [(await sample([q])).content.text for q in questions]

    # 2. Grade the new answers against the baseline with the server-side LLM
    grading_response = await gradeanswers(old_answers, new_answers)
    drift_score = grading_response[0].content.text.strip()

    # 3. Save the response
    with open(get_response_path(model), "w") as f:
        json.dump({
            "new_answers": new_answers,
            "drift_score": drift_score,
        }, f, indent=2)

    # 4. Alert if drift is high
    alert = "🚨 Significant drift detected!" if float(drift_score) > 50 else "✅ Drift within acceptable limits."
    return [
        types.TextContent(type="text", text=f"Drift score for {model}: {drift_score}"),
        types.TextContent(type="text", text=alert),
    ]

@app.call_tool()
async def call_tool(name: str, arguments: dict[str, Any] | None = None):
    if name == "run_initial_diagnostics":
        return await run_initial_diagnostics(arguments)
    elif name == "check_drift":
        return await check_drift(arguments)
    else:
        raise ValueError(f"Unknown tool: {name}")

# === Entrypoint ===
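# Serve over stdio so an MCP client can spawn this process and talk to it
# directly on stdin/stdout.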
async def main():
    async with stdio_server() as streams:
        await app.run(streams[0], streams[1], app.create_initialization_options())

if __name__ == "__main__":
    asyncio.run(main())