yipengsun committed on
Commit
c8dea05
·
1 Parent(s): c0fff99

Refactor bias detection and output parsing; update requirements

Browse files

- Replaced regex-based sign extraction with MedGemma integration in bias_detector.py.
- Enhanced output_parser.py to fall back to llm_output_parser for JSON extraction when json_repair returns a plain string.
- Updated prompts.py to clarify the bias source options ("doctor | AI | both" is now "choose from HUMAN | AI | BOTH").
- Added llm-output-parser dependency in requirements.txt.

agents/bias_detector.py CHANGED
@@ -4,7 +4,7 @@ Runs MedSigLIP sign verification on imaging findings mentioned by the Diagnostic
4
  Outputs structured JSON.
5
  """
6
 
7
- import re
8
  import logging
9
  from agents.state import PipelineState
10
  from agents.prompts import BIAS_DETECTOR_SYSTEM, BIAS_DETECTOR_USER
@@ -13,22 +13,21 @@ from models import medgemma_client, medsiglip_client
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
- # Common imaging signs that SigLIP can meaningfully evaluate on chest X-ray.
17
- # These are visual patterns, not abstract diagnoses.
18
- _KNOWN_SIGNS = [
19
- "pleural effusion", "consolidation", "infiltrates", "pneumothorax",
20
- "widened mediastinum", "cardiomegaly", "pulmonary edema", "atelectasis",
21
- "rib fracture", "subcutaneous emphysema", "hilar enlargement",
22
- "hyperinflation", "pleural thickening", "lung opacity", "air bronchogram",
23
- "mediastinal shift", "tracheal deviation", "cephalization",
24
- ]
25
 
 
 
26
 
27
- def _extract_signs(findings: object) -> list[str]:
28
- """Extract imaging signs mentioned in the Diagnostician's findings.
29
 
30
- Matches against known radiological signs rather than parsing diagnoses.
31
- """
32
  if isinstance(findings, list):
33
  chunks: list[str] = []
34
  for item in findings:
@@ -41,31 +40,20 @@ def _extract_signs(findings: object) -> list[str]:
41
  else:
42
  findings_text = str(findings)
43
 
44
- findings_lower = findings_text.lower()
45
- found = []
46
- for sign in _KNOWN_SIGNS:
47
- if sign in findings_lower:
48
- found.append(sign)
49
-
50
- # Also extract any explicit "abnormal" findings with simple patterns
51
- # e.g., "visible pleural line", "blunted costophrenic angle"
52
- extra_patterns = [
53
- r'(?:visible|subtle|small|large|bilateral|unilateral|left|right)\s+([\w\s]{5,30}?)(?:\.|,|;|\n)',
54
- ]
55
- for pat in extra_patterns:
56
- for m in re.findall(pat, findings_lower):
57
- cleaned = m.strip()
58
- if cleaned not in found and len(cleaned) > 5:
59
- found.append(cleaned)
60
-
61
- # Deduplicate, limit to 8
62
- seen = set()
63
- unique = []
64
- for s in found:
65
- if s not in seen:
66
- seen.add(s)
67
- unique.append(s)
68
- return unique[:8]
69
 
70
 
71
  def run(state: PipelineState) -> PipelineState:
 
4
  Outputs structured JSON.
5
  """
6
 
7
+ import json
8
  import logging
9
  from agents.state import PipelineState
10
  from agents.prompts import BIAS_DETECTOR_SYSTEM, BIAS_DETECTOR_USER
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ _SIGN_EXTRACTION_PROMPT = """\
17
+ Extract radiological signs (imaging abnormalities) from the following diagnostic findings.
18
+ Return ONLY a JSON array of short sign names.
19
+ Rules:
20
+ - Only include imaging abnormalities that could be visually verified on a medical image.
21
+ - Do NOT include normal anatomical structures, abstract diagnoses, clinical impressions, or treatment recommendations.
22
+ - Maximum 8 signs. If more exist, keep the most clinically significant ones.
23
+ - Return an empty array [] if no signs are found.
 
24
 
25
+ Findings:
26
+ {findings_text}"""
27
 
 
 
28
 
29
+ def _extract_signs(findings: object) -> list[str]:
30
+ """Extract signs from findings using MedGemma."""
31
  if isinstance(findings, list):
32
  chunks: list[str] = []
33
  for item in findings:
 
40
  else:
41
  findings_text = str(findings)
42
 
43
+ if not findings_text.strip():
44
+ return []
45
+
46
+ try:
47
+ raw = medgemma_client.generate_text(
48
+ _SIGN_EXTRACTION_PROMPT.format(findings_text=findings_text),
49
+ )
50
+ parsed = json.loads(raw.strip().strip("`").removeprefix("json").strip())
51
+ if isinstance(parsed, list):
52
+ return [str(s).strip().lower() for s in parsed if isinstance(s, str)][:8]
53
+ except (json.JSONDecodeError, Exception) as e:
54
+ logger.warning("LLM sign extraction failed, raw output: %s — %s", raw, e)
55
+
56
+ return []
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def run(state: PipelineState) -> PipelineState:
agents/output_parser.py CHANGED
@@ -1,11 +1,13 @@
1
  """
2
  JSON output parser for LLM responses.
3
- Uses json_repair to handle malformed JSON (missing commas, truncation, extra text, etc.).
 
4
  """
5
 
6
  import logging
7
  from collections.abc import Mapping
8
  from json_repair import repair_json
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
@@ -48,6 +50,16 @@ def parse_json_response(text: str) -> dict:
48
  if isinstance(result, list):
49
  return _coerce_list_root(result)
50
 
 
 
 
 
 
 
 
 
 
 
51
  raise ValueError(
52
  f"Could not parse JSON from LLM output (got {type(result).__name__}, length={len(text)})"
53
  )
 
1
  """
2
  JSON output parser for LLM responses.
3
+ Uses json_repair for malformed JSON, and llm_output_parser as fallback
4
+ to extract JSON from mixed text/markdown LLM output.
5
  """
6
 
7
  import logging
8
  from collections.abc import Mapping
9
  from json_repair import repair_json
10
+ from llm_output_parser import parse_json as extract_json
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
50
  if isinstance(result, list):
51
  return _coerce_list_root(result)
52
 
53
+ # Fallback: json_repair returned a plain string (model output natural language).
54
+ # Use llm_output_parser to extract JSON from mixed text/markdown.
55
+ if isinstance(result, str):
56
+ logger.warning("json_repair returned str, trying llm_output_parser extraction")
57
+ extracted = extract_json(text, allow_incomplete=True, strict=False)
58
+ if isinstance(extracted, Mapping):
59
+ return dict(extracted)
60
+ if isinstance(extracted, list):
61
+ return _coerce_list_root(extracted)
62
+
63
  raise ValueError(
64
  f"Could not parse JSON from LLM output (got {type(result).__name__}, length={len(text)})"
65
  )
agents/prompts.py CHANGED
@@ -56,7 +56,7 @@ Compare both assessments objectively. Neither is assumed correct. Respond with J
56
  "discrepancy_summary": "how the two assessments differ — note which points are uncertain",
57
  "identified_biases": [
58
  {{
59
- "source": "doctor | AI | both",
60
  "type": "bias type",
61
  "evidence": "why you suspect this bias",
62
  "severity": "choose from LOW | MEDIUM | HIGH"
 
56
  "discrepancy_summary": "how the two assessments differ — note which points are uncertain",
57
  "identified_biases": [
58
  {{
59
+ "source": "choose from HUMAN | AI | BOTH",
60
  "type": "bias type",
61
  "evidence": "why you suspect this bias",
62
  "severity": "choose from LOW | MEDIUM | HIGH"
requirements.txt CHANGED
@@ -8,3 +8,4 @@ Pillow>=10.0.0
8
  numpy>=1.24.0
9
  scipy>=1.10.0
10
  json-repair>=0.30.0
 
 
8
  numpy>=1.24.0
9
  scipy>=1.10.0
10
  json-repair>=0.30.0
11
+ llm-output-parser>=0.3.0