NeerajAhire committed on
Commit 3b3be44 · verified · 1 parent: 57be78b

Upload test.py

Files changed (1)
  1. test.py +500 -0
test.py ADDED
@@ -0,0 +1,500 @@
# test.py — Agentic logic using OpenAI + MCP tools (langchain_core for parsing)

import os
import json
from typing import Any, Dict, Optional, List, Literal, Type

from pydantic import BaseModel
from openai import OpenAI
from langchain_core.output_parsers import PydanticOutputParser  # ← requested parser

# -------------------- OpenAI setup --------------------
OAI_MODEL = os.getenv("OAI_MODEL", "gpt-4o-mini")
# Read the key from the environment; never hard-code API keys in source.
client_oai = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

def _format_history_for_context(
    conversation: List[Dict[str, str]],
    max_turns: int = 8,
) -> str:
    """
    Convert the last N messages from the session into a compact context string.
    Expected item format: {"role": "user"|"assistant", "content": "..."}.
    """
    if not conversation:
        return ""
    window = conversation[-max_turns:]
    lines = []
    for m in window:
        role = m.get("role", "user")
        content = m.get("content", "").strip()
        if not content:
            continue
        if role == "user":
            lines.append(f"User: {content}")
        else:
            lines.append(f"Assistant: {content}")
    return "\n".join(lines)

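# A minimal illustration (hypothetical messages, not from a real session) of how
# the history is flattened before being handed to the prompts below.
def _demo_history_context() -> str:
    history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! Ask me about JOSAA counseling."},
        {"role": "user", "content": "   "},  # blank messages are skipped
    ]
    return _format_history_for_context(history)
    # -> "User: Hi\nAssistant: Hello! Ask me about JOSAA counseling."
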
def llm_invoke(
    prompt: str,
    system: str = "You are a helpful assistant. Return JSON when requested.",
    temperature: float = 0.0,
) -> str:
    """
    Invoke OpenAI Chat Completions for planning/intent classification (low temperature).
    """
    resp = client_oai.chat.completions.create(
        model=OAI_MODEL,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
    )
    return resp.choices[0].message.content

# -------------------- Pydantic models --------------------
class IntentSpec(BaseModel):
    in_scope: bool
    intent: Literal["in_scope", "out_of_scope", "chit_chat"]
    reason: Optional[str] = None

class SubQuery(BaseModel):
    id: str
    query: str
    tool_name: Literal["ask_excel", "ask_pdf", "ask_link"]
    required_params: Dict[str, Any]
    depends_on: List[str] = []

class PlanResponse(BaseModel):
    subqueries: List[SubQuery]

class ContextEnhancer(BaseModel):
    answer_found: bool
    needs_enhancement: bool
    enhanced_query: Optional[str] = None
    cached_answer: Optional[str] = None
    reason: Optional[str] = None

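# Illustrative check (hypothetical values) that a planner payload validates
# against the schema above; in production the plan comes from the LLM.
def _demo_plan() -> PlanResponse:
    return PlanResponse(subqueries=[
        SubQuery(
            id="sq1",
            query="Closing rank for CSE at IIT Bombay in 2023",
            tool_name="ask_excel",
            required_params={"iit_name": "IIT Bombay", "branch": "CSE", "year": "2023"},
        ),
    ])
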
# -------------------- JSON parsing via langchain_core --------------------
def _safe_json(text: str) -> str:
    """
    Heuristic sanitizer: strip code fences and extract the main JSON block
    to help PydanticOutputParser if the model adds extra text.
    """
    t = text.strip()
    if t.startswith("```"):
        # Remove triple-backtick fences; allow an optional 'json' language hint.
        t = t.strip("`").strip()
        if t.lower().startswith("json"):
            t = t[4:].strip()
    # Try direct JSON first.
    try:
        json.loads(t)
        return t
    except Exception:
        pass
    # Fallback: take the span from the first '{' to the last '}'.
    start = t.find("{")
    end = t.rfind("}")
    if start != -1 and end != -1 and end > start:
        return t[start : end + 1]
    return text

def parse_response(text: str, model_spec: Type[BaseModel]) -> BaseModel:
    """
    Parse into a Pydantic model using langchain_core's PydanticOutputParser,
    with a robust fallback to standard json+pydantic if needed.
    """
    parser = PydanticOutputParser(pydantic_object=model_spec)
    # First, try parser.parse() directly.
    try:
        return parser.parse(text)
    except Exception:
        pass
    # Next, sanitize and try again.
    try:
        return parser.parse(_safe_json(text))
    except Exception:
        # Last resort: manual pydantic construction.
        data = json.loads(_safe_json(text))
        return model_spec(**data)

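# Illustrative round-trip (hypothetical model reply): the sanitizer lets the
# parser cope with code fences and surrounding prose.
def _demo_parse() -> IntentSpec:
    reply = '```json\n{"in_scope": true, "intent": "in_scope"}\n```'
    return parse_response(reply, IntentSpec)  # -> IntentSpec(in_scope=True, ...)
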
# -------------------- Prompts (intent + planning) --------------------

def intent_prompt(
    query: str,
    available_iits: Optional[List[str]] = None,
    available_branches: Optional[List[str]] = None,
    years: Optional[List[str]] = None,
    conversation_context: str = "",  # recent turns; may be empty
) -> str:
    parser = PydanticOutputParser(pydantic_object=IntentSpec)
    fmt = parser.get_format_instructions()

    convo = f"\n\nRecent conversation:\n{conversation_context}\n\n" if conversation_context else "\n\n"
    return f"""You are an intent classifier for a JOSAA Counseling Assistant.

Supported IITs: {', '.join(available_iits or [])}
Supported Branches: {', '.join(available_branches or [])}
Available Data: opening/closing ranks ({', '.join(years or [])}), curriculum, NIRF, placements/faculty/research/facilities.{convo}
Classify the user's message into EXACTLY ONE of:
- "chit_chat"
- "in_scope"
- "out_of_scope"

Rules:
- "chit_chat" for greetings/small talk (hi/hello/how are you/what can you do).
- "in_scope" for queries about SUPPORTED IITs/branches, counseling, ranks/cutoffs, courses, curriculum, NIRF, placements, faculty, research, alumni/distinguished alumni, and campus facilities.
- "out_of_scope" otherwise.

Return ONLY a JSON object following these instructions:
{fmt}

User query: "{query}"
""".strip()

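# Illustrative reply (hypothetical values) that the classifier is expected to
# emit and that validates against IntentSpec:
#   {"in_scope": false, "intent": "chit_chat", "reason": "greeting / small talk"}
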
def planning_prompt(
    query: str,
    available_iits: Optional[List[str]] = None,
    available_branches: Optional[List[str]] = None,
    years: Optional[List[str]] = None,
    conversation_context: str = "",  # recent turns; may be empty
) -> str:
    parser = PydanticOutputParser(pydantic_object=PlanResponse)
    fmt = parser.get_format_instructions()

    convo = f"\n\nRecent conversation:\n{conversation_context}\n\n" if conversation_context else "\n\n"
    return f"""You are a query planner for a JEE counseling assistant.

Supported IITs: {', '.join(available_iits or [])}
Supported Branches: {', '.join(available_branches or [])}

AVAILABLE TOOLS:
- ask_excel — ranks/cutoffs ({', '.join(years or [])}); params may include iit_name, branch, year
- ask_pdf — curriculum/NIRF; params may include iit_name, branch
- ask_link — placements/faculty/research/facilities; params may include iit_name, branch, or a URL{convo}
Break the user query into specific subqueries targeting ONE tool each.
Use ONLY supported IIT names and branch names when present.

Return ONLY a JSON object following these instructions:
{fmt}

User Query: "{query}"
""".strip()

# -------------------- Intent detection & planning --------------------

def intent_detect(
    user_q: str,
    available_iits: List[str],
    available_branches: List[str],
    years: List[str],
    conversation_context: str,
) -> IntentSpec:
    response = llm_invoke(
        intent_prompt(user_q, available_iits, available_branches, years, conversation_context),
        temperature=0.0,
    )
    try:
        return parse_response(response, IntentSpec)
    except Exception as e:
        # Default to out_of_scope if the model's reply cannot be parsed.
        return IntentSpec(in_scope=False, intent="out_of_scope", reason=f"Parse error: {e}")

def make_query_plan(
    user_q: str,
    available_iits: List[str],
    available_branches: List[str],
    years: List[str],
    conversation_context: str,
) -> PlanResponse:
    response = llm_invoke(
        planning_prompt(user_q, available_iits, available_branches, years, conversation_context),
        temperature=0.0,
    )
    return parse_response(response, PlanResponse)

# -------------------- MCP tool registry (real calls) --------------------

def make_tool_registry(mcp_client, conversation_context: str) -> Dict[str, Any]:
    """
    Return callables that invoke the actual MCP tools via your client,
    threading the recent conversation into each tool call.
    """
    def _build_query_text(query: str, params: Dict[str, Any], context: str) -> str:
        """Compose a single question string from the planner's query, params, and context."""
        parts = [query.strip()]
        if params:
            parts.append("Parameters: " + "; ".join(f"{k}: {v}" for k, v in params.items()))
        if context:
            parts.append("Conversation context:\n" + context)
        return "\n".join(parts)

    def call_ask_excel(query: str, required_params: Dict[str, Any], temperature: float = 0.1, top_k: int = 5) -> str:
        q_text = _build_query_text(query, required_params, conversation_context)
        return mcp_client.ask_excel(
            question=q_text,
            top_k=top_k,
            sheet=required_params.get("sheet", 0),
            temperature=temperature,
        )

    def call_ask_pdf(query: str, required_params: Dict[str, Any], temperature: float = 0.1, top_k: int = 5) -> str:
        q_text = _build_query_text(query, required_params, conversation_context)
        return mcp_client.ask_pdf(question=q_text, top_k=top_k, temperature=temperature)

    def call_ask_link(query: str, required_params: Dict[str, Any], temperature: float = 0.1, top_k: int = 5) -> str:
        # The conversation goes into subquery_context instead of the query text.
        q_text = _build_query_text(query, required_params, "")
        subctx = conversation_context or required_params.get("subquery_context")
        # IMPORTANT: align the parameter name with your server (query vs. question).
        return mcp_client.ask_link(
            query=q_text,  # use question=q_text if the server expects 'question'
            temperature=temperature,
            subquery_context=subctx,
            top_k=top_k,
        )

    return {
        "ask_excel": call_ask_excel,
        "ask_pdf": call_ask_pdf,
        "ask_link": call_ask_link,
    }

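# Illustrative sketch only: a stub standing in for the real MCP client, exposing
# the three methods the registry calls with matching keyword arguments. Useful
# for exercising the agent locally; the real server's signatures may differ.
class _StubMCPClient:
    def ask_excel(self, question: str, top_k: int = 5, sheet: int = 0, temperature: float = 0.1) -> str:
        return f"(stub excel answer) {question}"

    def ask_pdf(self, question: str, top_k: int = 5, temperature: float = 0.1) -> str:
        return f"(stub pdf answer) {question}"

    def ask_link(self, query: str, temperature: float = 0.1, subquery_context: Any = None, top_k: int = 5) -> str:
        return f"(stub link answer) {query}"
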
# -------------------- Execute subqueries & synthesize final --------------------
def build_execution_order(subqueries: List[SubQuery]) -> List[List[str]]:
    """
    Create batches of IDs whose dependencies are satisfied (simple topological batching).
    """
    if not subqueries:
        return []
    remaining = {sq.id: sq for sq in subqueries}
    completed = set()
    order: List[List[str]] = []
    while remaining:
        ready = [sq_id for sq_id, sq in remaining.items() if all(dep in completed for dep in sq.depends_on)]
        if not ready:
            raise ValueError(f"Circular or unsatisfiable dependencies: {list(remaining.keys())}")
        order.append(ready)
        for sq_id in ready:
            completed.add(sq_id)
            del remaining[sq_id]
    return order

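# Illustrative batching (hypothetical subquery ids): sq2 depends on sq1, so it
# lands in a later batch than the two independent subqueries.
def _demo_execution_order() -> List[List[str]]:
    plan = [
        SubQuery(id="sq1", query="ranks", tool_name="ask_excel", required_params={}),
        SubQuery(id="sq2", query="compare", tool_name="ask_pdf", required_params={}, depends_on=["sq1"]),
        SubQuery(id="sq3", query="placements", tool_name="ask_link", required_params={}),
    ]
    return build_execution_order(plan)  # -> [["sq1", "sq3"], ["sq2"]]
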
def execute_plan(
    user_q: str,
    plan: PlanResponse,
    mcp_client,
    conversation_context: str,
    temperature: float = 0.1,
    top_k: int = 5,
) -> Dict[str, Any]:
    """
    Execute subqueries in dependency batches; returns the execution order
    plus a dict of {sq_id: {tool, answer}}.
    """
    registry = make_tool_registry(mcp_client, conversation_context)
    subqs = plan.subqueries
    exec_order = build_execution_order(subqs)
    results: Dict[str, Any] = {}

    for batch in exec_order:
        for sq_id in batch:
            sq = next(s for s in subqs if s.id == sq_id)
            tool_fn = registry.get(sq.tool_name)
            if not tool_fn:
                results[sq_id] = {"tool": sq.tool_name, "answer": f"❌ Unknown tool '{sq.tool_name}'"}
                continue
            try:
                ans = tool_fn(sq.query, sq.required_params, temperature=temperature, top_k=top_k)
                results[sq_id] = {"tool": sq.tool_name, "answer": ans}
            except Exception as e:
                results[sq_id] = {"tool": sq.tool_name, "answer": f"❌ Error calling tool: {e}"}

    return {"execution_order": exec_order, "results": results}

def synthesize_answer(user_q: str, exec_result: Dict[str, Any], conversation_context: str) -> str:
    """
    Use OpenAI to write a concise final answer from all tool outputs,
    grounded in the recent conversation.
    """
    tool_outputs = []
    for batch in exec_result.get("execution_order", []):
        for sq_id in batch:
            entry = exec_result["results"].get(sq_id, {})
            tool_outputs.append(f"[{sq_id} • {entry.get('tool')}] {entry.get('answer', '')}")
    context = "\n".join(tool_outputs) if tool_outputs else "(no tool outputs)"

    prompt = f"""You are a helpful assistant for JEE/JOSAA counseling.

Recent conversation:
{conversation_context or "(none)"}

User Question:
{user_q}

Tool Results:
{context}

Write a concise, accurate final answer grounded in the tool results and the recent conversation.
If the available context is insufficient, state that clearly.
Avoid bracketed tags and metadata like [sq1].
"""
    return llm_invoke(prompt, system="You are a helpful assistant. Use only provided context.", temperature=0.2)

# -------------------- Public entry point used by chat_app --------------------

def run_agent(
    user_q: str,
    mcp_client,
    available_iits: List[str],
    available_branches: List[str],
    years: List[str],
    conversation: List[Dict[str, str]],
    top_k: int = 5,
    temperature: float = 0.1,
) -> str:
    conversation_context = _format_history_for_context(conversation, max_turns=8)

    intent = intent_detect(user_q, available_iits, available_branches, years, conversation_context)
    print("The intent response is", f"{intent}")

    if intent.intent == "chit_chat":
        return (
            f"Hi! I’m your JOSAA Counseling Assistant.\n"
            f"Ask about branches, opening/closing ranks, or options for your rank.\n"
            f"Supported IITs: {', '.join(available_iits)}; branches: {', '.join(available_branches)}."
        )
    if not intent.in_scope or intent.intent == "out_of_scope":
        return (
            "This assistant only supports JEE/JOSAA counseling.\n"
            f"Supported IITs: {', '.join(available_iits)}; branches: {', '.join(available_branches)}.\n"
            "Please refine your query accordingly."
        )

    # In scope → plan → execute → synthesize.
    plan = make_query_plan(user_q, available_iits, available_branches, years, conversation_context)
    print(plan)
    exec_result = execute_plan(user_q, plan, mcp_client, conversation_context, temperature=temperature, top_k=top_k)
    final = synthesize_answer(user_q, exec_result, conversation_context)
    return final.strip()

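# Illustrative wiring (hypothetical data; _StubMCPClient is the stub sketched
# above, not a real MCP client). Requires OPENAI_API_KEY in the environment.
if __name__ == "__main__":
    demo_conversation = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! Ask me about JOSAA counseling."},
    ]
    print(run_agent(
        user_q="What was the closing rank for CSE at IIT Bombay in 2023?",
        mcp_client=_StubMCPClient(),
        available_iits=["IIT Bombay", "IIT Delhi"],
        available_branches=["CSE", "EE"],
        years=["2022", "2023"],
        conversation=demo_conversation,
    ))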