HarshBhati committed on
Commit ec83815 · 2 parents: 2dd4294, c8dd6f4

integrated db.py with server.py

Files changed (2):
  1. README.md +56 -1
  2. server.py +47 -0
README.md CHANGED
@@ -10,4 +10,59 @@ pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Drift Detector
+Drift Detector is an MCP server designed to detect drift in LLM performance over time.
+This implementation is a proof of concept and is not intended for production use.
+
+## How to run
+
+To run the Drift Detector, you need to have Python installed on your machine. Follow these steps:
+
+1. Clone the repository:
+```bash
+git clone https://github.com/saranshhalwai/drift-detector
+cd drift-detector
+```
+2. Install the required dependencies:
+```bash
+pip install -r requirements.txt
+```
+3. Start the server:
+```bash
+gradio app.py
+```
+4. Open your web browser and navigate to `http://localhost:7860` to access the Drift Detector interface.
+
+## Interface
+
+The interface consists of the following components:
+- **Model Selection** - A panel allowing you to:
+  - Select models from a dropdown list
+  - Search for models by name or description
+  - Create new models with custom system prompts
+  - Enhance prompts with AI assistance
+
+- **Model Operations** - A tabbed interface with:
+  - **Chatbot** - Interact with the selected model through a conversational interface
+  - **Drift Analysis** - Analyze and visualize model drift over time, including:
+    - Calculate new drift scores for the selected model
+    - View historical drift data in JSON format
+    - Visualize drift trends through interactive charts
+
+The drift detection functionality allows you to track changes in model performance over time, which is essential for monitoring and maintaining model quality.
+
+## Under the Hood
+
+Our GitHub repo consists of two main components:
+
+- **Drift Detector Server**
+  A low-level MCP server that detects drift in the LLM performance of the connected client.
+- **Target Client**
+  A client implemented using the fast-agent library, which connects to the Drift Detector server and demonstrates its functionality.
+
+The Gradio interface in [app.py](app.py) is an example dashboard which allows users to interact with the Drift Detector server and visualize drift data.
+
+### Drift Detector Server
+
+The Drift Detector server is implemented using the MCP Python SDK.
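
As background for the server diff below: a minimal sketch of what a low-level MCP server skeleton looks like, assuming the official `mcp` Python SDK. The tool name and input schema here are illustrative stand-ins, not the repo's actual definitions.

```python
# Minimal low-level MCP server sketch (assumes the official `mcp` Python SDK).
# Tool name and schema are illustrative; server.py defines its own tools.
import asyncio

import mcp.types as types
from mcp.server.lowlevel import Server

app = Server("drift-detector")


@app.list_tools()
async def list_tools() -> list[types.Tool]:
    # Advertise one hypothetical tool; the real server lists its own set.
    return [
        types.Tool(
            name="check_drift",
            description="Compare fresh answers against a stored baseline.",
            inputSchema={
                "type": "object",
                "properties": {"model": {"type": "string"}},
                "required": ["model"],
            },
        )
    ]


async def main() -> None:
    # Serve over stdio, the usual transport for local MCP clients.
    from mcp.server.stdio import stdio_server

    async with stdio_server() as (read, write):
        await app.run(read, write, app.create_initialization_options())


if __name__ == "__main__":
    asyncio.run(main())
```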
server.py CHANGED
@@ -80,6 +80,26 @@ async def list_tools() -> List[types.Tool]:
     ]
 
 
+
+# === Sampling Wrapper ===
+async def sample(messages: list[types.SamplingMessage], max_tokens=600) -> CreateMessageResult:
+    return await app.request_context.session.create_message(
+        messages=messages,
+        max_tokens=max_tokens,
+        temperature=0.7
+    )
+
+
+# === Baseline File Paths ===
+def get_baseline_path(model_name):
+    return os.path.join(DATA_DIR, f"{model_name}_baseline.json")
+
+
+def get_response_path(model_name):
+    return os.path.join(DATA_DIR, f"{model_name}_latest.json")
+
+
+
 # === Core Logic ===
 async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
     model = arguments["model"]
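
The `sample()` wrapper added above forwards questions to the connected client through MCP sampling (`create_message`), so the client's LLM, not the server, produces the answers. Purely as an illustration (the question text is a stand-in), the `SamplingMessage` values it accepts can be built like this:

```python
# Illustration only: building the SamplingMessage list that sample() accepts.
# The question text is a stand-in; real questionnaires come from the server's
# trusted LLM.
import mcp.types as types

question = types.SamplingMessage(
    role="user",
    content=types.TextContent(type="text", text="What is the capital of France?"),
)
# Inside a tool handler: answer = await sample([question])
```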
@@ -91,7 +111,19 @@ async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
     # 2. Ask the target LLM (client)
     answers = await sample(questions)
 
+
     # 3. Persist baseline
+
+    # 1. Ask the server's internal LLM to generate a questionnaire
+
+    questions = genratequestionnaire(model, arguments["model_capabilities"]) # Server-side trusted LLM
+    answers = []
+    for q in questions:
+        a = await sample([q])
+        answers.append(a)
+
+    # 3. Save Q/A pair
+
     with open(get_baseline_path(model), "w") as f:
         json.dump({
             "questions": [m.content.text for m in questions],
@@ -118,15 +150,30 @@ async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
     ]
     old_answers = data["answers"]
 
+
     # 1. Get fresh answers
     new_msgs = await sample(questions)
     new_answers = [m.content.text for m in new_msgs]
 
+    # 1. Ask the model again
+    new_answers_msgs = []
+    for q in questions:
+        a = await sample([q])
+        new_answers_msgs.append(a)
+    new_answers = [m.content.text for m in new_answers_msgs]
+
+
     # 2. Grade for drift
     grading = await gradeanswers(old_answers, new_answers)
     drift_score = grading[0].content.text.strip()
 
+
     # 3. Save latest
+    grading_response = gradeanswers(old_answers, new_answers)
+    drift_score = grading_response[0].content.text.strip()
+
+    # 3. Save the response
+
     with open(get_response_path(model), "w") as f:
         json.dump({
             "new_answers": new_answers,
 