Adibvafa committed on
Commit eaca108 · 0 Parent(s)

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. .gitignore +174 -0
  3. .vscode/launch.json +15 -0
  4. LICENSE +201 -0
  5. README.md +83 -0
  6. assets/medrax_logo.jpg +3 -0
  7. assets/medrax_logo.png +3 -0
  8. benchmark/__init__.py +0 -0
  9. benchmark/create_benchmark.py +352 -0
  10. benchmark/llm.py +42 -0
  11. benchmark/utils.py +78 -0
  12. data/eurorad_metadata.json +0 -0
  13. data/figures.py +74 -0
  14. data/get_cases.py +51 -0
  15. data/stats/age_distribution.png +3 -0
  16. data/stats/area_of_interest_distribution.png +3 -0
  17. data/stats/gender_distribution.png +3 -0
  18. demo/chest/LIDC.dcm +3 -0
  19. demo/chest/Pseudo.dcm +3 -0
  20. demo/chest/RIDER.dcm +3 -0
  21. demo/chest/TCGAA.dcm +3 -0
  22. demo/chest/__init__.py +0 -0
  23. demo/chest/effusion1.png +3 -0
  24. demo/chest/normal1.jpg +3 -0
  25. demo/chest/normal2.jpg +3 -0
  26. demo/chest/normal3.jpg +3 -0
  27. demo/chest/normal4.jpg +3 -0
  28. demo/chest/normal5.jpg +3 -0
  29. demo/chest/normal6.jpg +3 -0
  30. demo/chest/pneumonia1.jpg +3 -0
  31. demo/chest/pneumonia2.jpg +3 -0
  32. demo/chest/pneumonia3.jpg +3 -0
  33. demo/chest/pneumonia4.jpg +3 -0
  34. demo/chest/pneumonia5.jpg +3 -0
  35. experiments/README.md +63 -0
  36. experiments/analyze_axes.py +385 -0
  37. experiments/benchmark_chexagent.py +316 -0
  38. experiments/benchmark_gpt4o.py +331 -0
  39. experiments/benchmark_llama.py +443 -0
  40. experiments/benchmark_llavamed.py +541 -0
  41. experiments/benchmark_medrax.ipynb +374 -0
  42. experiments/chexbench_gpt4.py +405 -0
  43. experiments/compare_runs.py +290 -0
  44. experiments/inspect_logs.py +210 -0
  45. experiments/validate_logs.py +162 -0
  46. interface.py +259 -0
  47. main.py +63 -0
  48. medrax/__init__.py +0 -0
  49. medrax/agent/__init__.py +1 -0
  50. medrax/agent/agent.py +193 -0
.gitattributes ADDED
@@ -0,0 +1,7 @@
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.dcm filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,174 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # ruff
+ ruff-cache/
+ .ruff_cache/
+
+ afallah/
+
+ logs/
+
+ temp/
+
+ .gradio/
.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
+ {
+     // Use IntelliSense to learn about possible attributes.
+     // Hover to view descriptions of existing attributes.
+     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Python Debugger: main.py",
+             "type": "debugpy",
+             "request": "launch",
+             "program": "main.py",
+             "console": "integratedTerminal"
+         }
+     ]
+ }
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,83 @@
+ <h1 align="center">
+ 🤖 MedRAX: Medical Reasoning Agent for Chest X-ray 🏥
+ </h1>
+ <br>
+
+ ## Problem
+ Medical professionals face significant challenges when using traditional Large Language Models (LLMs) for X-ray analysis. Standard LLMs often hallucinate, lack specialized medical imaging capabilities, and can miss critical diagnostic details. While separate tools exist for various aspects of X-ray analysis, the current fragmented approach requires doctors to juggle multiple systems, leading to inefficient workflows and potential oversights in patient care.
+ <br>
+ <br>
+
+ ## Our Solution
+ MedRAX is an intelligent medical assistant that seamlessly integrates an LLM with specialized X-ray analysis tools, providing a unified interface for comprehensive X-ray analysis. Through natural conversation, medical professionals can leverage powerful tools while the system intelligently coordinates their usage behind the scenes.
+
+ Our comprehensive toolset includes:
+ - **ChestXRayReportGenerator**: Generates detailed, accurate medical reports from X-ray images
+ - **ChestXRayClassifier**: Analyzes images for 18 different pathologies, providing probability scores for each condition
+ - **ChestXRaySegmentation**: Precisely segments anatomical structures
+ - **MedicalVisualQA**: Answers complex visual medical queries
+ - **XRayPhraseGrounding**: Locates and visualizes specific medical findings in X-rays with bounding box precision
+ - **ImageVisualizer**: Enhances and displays X-ray images for optimal viewing
+ - **ChestXRayGenerator**: Generates synthetic chest X-rays for educational purposes
+ - **DicomProcessor**: Handles DICOM file processing and analysis
+ <br>
+
+ ## Technical Implementation
+ MedRAX is built on a robust technical foundation:
+ - **Core Architecture**: Leverages LangChain and LangGraph for sophisticated agent orchestration
+ - **Language Model**: Powered by OpenAI's API for natural language understanding and generation
+ - **Specialized Tools**: Integrates medical-domain fine-tuned models for various analysis tasks
+ - **Interface**: Built with Gradio for an intuitive, chat-based user experience
+ - **Modular Design**: Allows easy integration of additional specialized medical tools
+ <br>
+
+ ## Potential Impact
+ - Accelerates X-ray analysis while maintaining high accuracy
+ - Reduces the likelihood of missed diagnoses through multi-tool verification
+ - Provides valuable educational support for medical students and residents
+ - Offers a scalable solution for facilities with limited specialist availability
+ - Improves patient outcomes through comprehensive analysis
+ - Streamlines workflow for medical professionals
+ <br>
+
+ ## Setup and Usage
+
+ ### Prerequisites
+ - GPU required for optimal performance
+ - Python 3.8+
+ - OpenAI API key
+
+ ### Installation
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/yourusername/MedRAX.git
+ cd MedRAX
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -e .
+ ```
+
+ 3. Set up environment variables:
+ ```bash
+ echo "OPENAI_API_KEY=your_key_here" > .env
+ ```
+
+ ### Running the Application
+ Start the application:
+ ```bash
+ python main.py
+ ```
+ <br>
+
+ ## Developers
+ - Adibvafa Fallahpour
+ - Jun Ma
+ - Hongwei Lyu
+ <br>
+
+ ---
+ <p align="center">
+ Made with ❤️ in Toronto
+ </p>
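To make the README's Technical Implementation section concrete, here is a minimal, illustrative sketch of how an LLM and one specialized X-ray tool can be wired together with LangChain and LangGraph. The tool body, model name, and image path are placeholders for illustration only; MedRAX's actual agent lives in `medrax/agent/agent.py`, added later in this commit.

```python
# Illustrative sketch only: an LLM orchestrating a specialized tool via LangGraph.
# The tool below is a stub, not MedRAX's real ChestXRayClassifier.
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent


@tool
def classify_chest_xray(image_path: str) -> dict:
    """Return pathology probabilities for a chest X-ray image."""
    # A real tool would load the image and run a fine-tuned classifier here.
    return {"effusion": 0.07, "pneumonia": 0.81}


llm = ChatOpenAI(model="gpt-4o")  # model choice is an assumption
agent = create_react_agent(llm, tools=[classify_chest_xray])

result = agent.invoke(
    {"messages": [("user", "Does demo/chest/pneumonia1.jpg show signs of pneumonia?")]}
)
print(result["messages"][-1].content)
```

The point of the pattern is that the LLM decides when to call the tool and how to fold its output back into the conversation, which is the coordination behavior the README describes.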
assets/medrax_logo.jpg ADDED

Git LFS Details

  • SHA256: 306aa20d47067df102e4ba26d637f22a7d95f449a5969d320ceeca03b71da1d1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
assets/medrax_logo.png ADDED

Git LFS Details

  • SHA256: 5af3f42308022abe028b670e6716152e714c1f25ebbe6375532775a557b66b2c
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
benchmark/__init__.py ADDED
File without changes
benchmark/create_benchmark.py ADDED
@@ -0,0 +1,352 @@
+ #!/usr/bin/env python3
+ """
+ Medical X-ray Question Generation Benchmark aka ChestAgentBench
+
+ This script generates clinical questions from X-ray case data of the Eurorad dataset using GPT-4o.
+ It structures questions across different analytical categories and saves them as JSON.
+ """
+
+ import os
+ import re
+ import json
+ from typing import *
+ from pprint import pprint
+
+ import openai
+ import numpy as np
+ from scipy import stats
+ import plotly.graph_objects as go
+ from tqdm import tqdm
+
+ from benchmark.utils import load_eurorad_dataset
+ from benchmark.llm import get_llm_response
+
+ # Constants
+ DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
+ DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")
+
+ SYSTEM_PROMPT = """
+ You are an expert medical benchmark creation assistant.
+ Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
+ """.strip()
+
+ CATEGORIES_META = {
+     "detection": "Identify and locate specific findings in the chest X-ray.",
+     "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
+     "enumeration": "Count the number of target findings in the chest X-ray.",
+     "localization": "Locate a given finding in the chest X-ray.",
+     "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
+     "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
+     "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
+     "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
+     "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
+ }
+ CATEGORIES = list(CATEGORIES_META.keys())
+
+ CATEGORY_COMBINATIONS = [
+     ["detection", "localization", "characterization", "reasoning"],  # Detailed Finding Analysis
+     ["detection", "classification", "relationship", "reasoning"],  # Pattern Recognition & Relations
+     ["localization", "comparison", "relationship", "reasoning"],  # Spatial Understanding
+     ["classification", "comparison", "diagnosis", "reasoning"],  # Clinical Decision Making
+     ["classification", "characterization", "diagnosis", "reasoning"],  # Diagnostic Characterization
+ ]
+
+ DEFAULT_SECTIONS = [
+     "history",
+     "image_finding",
+     "discussion",
+     "differential_diagnosis",
+     "diagnosis",
+     "figures",
+ ]
+
+
+ class Question:
+     """A class to generate clinical questions from case data.
+
+     This class handles creating structured clinical questions by combining case data with
+     specified categories and difficulty levels.
+
+     Attributes:
+         type (str): The type of question (e.g. multiple choice)
+         difficulty (str): Difficulty level of the question
+         case_data (Dict[str, Any]): Dictionary containing the clinical case data
+         case_content (str): Formatted case data from selected sections
+         case_id (str): Unique identifier for the case
+         categories (List[str]): List of analytical categories this question tests
+         sections (List[str]): Case sections to include in question
+         raw_content (Optional[str]): Raw LLM response to the question prompt
+         content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
+     """
+
+     def __init__(
+         self,
+         type: str,
+         difficulty: str,
+         case_data: Dict[str, Any],
+         categories: List[str],
+         sections: List[str] = [
+             "history",
+             "image_finding",
+             "discussion",
+             "differential_diagnosis",
+             "diagnosis",
+             "figures",
+         ],
+         system_prompt: str = "You are an expert medical benchmark creation assistant.",
+     ) -> None:
+         self.type = type
+         self.difficulty = difficulty
+         self.case_data = case_data
+         self.case_id = case_data["case_id"]
+         self.categories = categories
+         self.sections = sections
+         self.system_prompt = system_prompt
+         self.case_content = self.select_case_sections()
+         self.raw_content: Optional[str] = None
+         self.content: Optional[Dict[str, str]] = None
+
+     def create_question_prompt(self) -> str:
+         """Creates a formatted prompt for generating a clinical question.
+
+         Returns:
+             str: A structured prompt containing the question parameters and clinical data
+         """
+         category_descriptions = "\n".join(
+             f"{category}: {desc}"
+             for category, desc in CATEGORIES_META.items()
+             if category in self.categories
+         )
+
+         return f"""
+         You must follow these guidelines:
+         1. Questions must be answerable using only context and chest X-rays.
+             - Questions must explicitly mention the referenced figures
+             - Questions can only reference the chest X-ray figures
+
+         2. Questions must have unambiguous, verifiable answers, and should:
+             - Challenge the agent's analytical capabilities
+             - Require multi-step reasoning
+             - Test ability to make precise observations
+             - Evaluate capability to derive insights and findings from the chest X-ray
+
+         3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex to require the use of such tools.
+
+
+         Create a {self.difficulty} {self.type} clinical question that integrates the following:
+
+         {category_descriptions}
+
+         based on the following clinical case:
+
+         {self.case_content}
+
+         Do not use any information derived from the CT and MRI images. Do not provide any information and findings about the chest X-rays.
+         Your question should require the agent to derive insights and findings from the chest X-ray by itself.
+         Your answer should be verifiable directly in the context of the case.
+         You can only use the image findings that come from the chest X-ray figures.
+
+         Your response must follow this exact format:
+         THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question]
+         QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
+         FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
+         EXPLANATION: [short explanation of why your answer is verifiable in the case]
+         ANSWER: [correct answer e.g. "A"]
+         """.strip().replace(
+             "    ", ""
+         )  # strip the prompt's leading indentation
+
+     def select_case_sections(self) -> str:
+         """Extract and format selected sections from case data into paragraphs.
+
+         Returns:
+             str: Formatted string with case sections and content
+         """
+         section_mapping = {
+             "history": ("history", "No history provided."),
+             "image_finding": ("image_finding", "No findings provided."),
+             "discussion": ("discussion", "No discussion provided."),
+             "differential_diagnosis": (
+                 "differential_diagnosis",
+                 "No differential diagnosis provided.",
+             ),
+             "diagnosis": ("diagnosis", "No diagnosis provided."),
+             "figures": ("figures", "No figures provided."),
+         }
+
+         formatted = []
+         for section in self.sections:
+             if section in section_mapping:
+                 key, default = section_mapping[section]
+                 content = self.case_data.get(key, default)
+
+                 if key == "figures":
+                     figures_text = []
+                     for figure in content:
+                         for subfig in figure["subfigures"]:
+                             figures_text.append(f"{subfig['number']}: {subfig['caption']}")
+                     content = "\n".join(figures_text)
+
+                 formatted.append(f"{section}:\n{content}")
+
+         return "\n\n".join(formatted)
+
+     def create_question(
+         self,
+         client: openai.OpenAI,
+         temperature: float = 0.7,
+         top_p: float = 0.95,
+         max_tokens: int = 500,
+         model: str = "gpt-4o",
+     ) -> str:
+         """Create a clinical question using LLM.
+
+         Args:
+             client (openai.OpenAI): OpenAI client instance
+             temperature (float): Controls randomness in responses. Defaults to 0.7.
+             top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
+             max_tokens (int): Max tokens in model response. Defaults to 500.
+             model (str): OpenAI model to use. Defaults to "gpt-4o".
+
+         Returns:
+             str: LLM response containing formatted question components
+         """
+         self.raw_content = get_llm_response(
+             client=client,
+             prompt=self.create_question_prompt(),
+             system_prompt=self.system_prompt,
+             temperature=temperature,
+             top_p=top_p,
+             max_tokens=max_tokens,
+             model=model,
+         )
+         self.content = self.extract_content()
+
+         return self.raw_content
+
+     def extract_content(self) -> Dict[str, str]:
+         """Extract sections from raw LLM response using regex patterns.
+
+         Returns:
+             Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer
+         """
+         keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]
+
+         content = {}
+         for kw in keywords:
+             pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
+             match = re.search(pattern, self.raw_content, re.DOTALL)
+             content[kw.lower()] = match.group(1).strip() if match else None
+
+         return content
+
+     def save(self, output_path: str) -> Dict[str, Any]:
+         """Save question content and metadata as a JSON file.
+
+         Args:
+             output_path (str): Directory path where the JSON file will be saved
+
+         Returns:
+             Dict[str, Any]: Question data including content (thoughts, question, figures, options,
+                 explanation, answer) and metadata (type, difficulty, categories, etc.)
+         """
+         question_metadata = self.content.copy()
+
+         # Add metadata
+         question_metadata["metadata"] = {
+             "case_id": self.case_id,
+             "type": self.type,
+             "difficulty": self.difficulty,
+             "categories": self.categories,
+             "sections": self.sections,
+         }
+
+         # Create a directory for the case
+         case_dir = os.path.join(output_path, str(self.case_id))
+         os.makedirs(case_dir, exist_ok=True)
+
+         # Save the question metadata to a JSON file
+         output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
+         with open(output_file, "w") as f:
+             json.dump(question_metadata, f, indent=2)
+
+         return question_metadata
+
+
+ def generate_questions(
+     dataset: Dict[str, Any],
+     client: openai.OpenAI,
+     output_dir: str,
+     skip_first: int = 100,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     max_tokens: int = 1200,
+     model: str = "gpt-4o",
+ ) -> None:
+     """Generate questions for each case and category combination.
+
+     Args:
+         dataset: Dictionary of case data
+         client: OpenAI client instance
+         output_dir: Directory to save generated questions
+         skip_first: Number of initial cases to skip
+         temperature: LLM temperature parameter
+         top_p: LLM top_p parameter
+         max_tokens: Maximum tokens for LLM response
+         model: LLM model name
+     """
+     target_cases = sorted(list(dataset.keys()), key=int)[-len(dataset) : -skip_first]
+
+     for case_id in tqdm(target_cases, desc="Processing cases"):
+         case_data = dataset[case_id]
+
+         for category in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
+             question = Question(
+                 type="multiple choice (A/B/C/D/E/F)",
+                 difficulty="complex",
+                 case_data=case_data,
+                 categories=category,
+                 sections=DEFAULT_SECTIONS,
+                 system_prompt=SYSTEM_PROMPT,
+             )
+
+             response = question.create_question(
+                 client=client,
+                 temperature=temperature,
+                 top_p=top_p,
+                 max_tokens=max_tokens,
+                 model=model,
+             )
+             question.save(output_dir)
+
+
+ def main():
+     """Main execution function."""
+     client = openai.OpenAI()
+
+     # Load and verify dataset
+     dataset = load_eurorad_dataset(
+         DATASET_PATH,
+         section="Chest Imaging",
+         as_dict=True,
+         filter_by_caption=[
+             "xray",
+             "x-ray",
+             "x ray",
+             "ray",
+             "xr",
+             "radiograph",
+         ],
+     )
+     print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")
+
+     # Optional: Print sample case for verification
+     case_data = dataset["16798"]
+     pprint(case_data, sort_dicts=False)
+
+     # Generate questions
+     generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")
+
+
+ if __name__ == "__main__":
+     main()
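For orientation, a single question can also be generated outside of `generate_questions`. The sketch below uses only names defined in the file above; the dataset path is illustrative and an `OPENAI_API_KEY` must be available in the environment.

```python
import openai

from benchmark.create_benchmark import DEFAULT_SECTIONS, SYSTEM_PROMPT, Question
from benchmark.utils import load_eurorad_dataset

client = openai.OpenAI()
dataset = load_eurorad_dataset(
    "data/eurorad_metadata.json", section="Chest Imaging", as_dict=True  # illustrative path
)
case_data = next(iter(dataset.values()))  # pick any chest imaging case

question = Question(
    type="multiple choice (A/B/C/D/E/F)",
    difficulty="complex",
    case_data=case_data,
    categories=["detection", "localization", "characterization", "reasoning"],
    sections=DEFAULT_SECTIONS,
    system_prompt=SYSTEM_PROMPT,
)
question.create_question(client=client)  # calls the LLM and parses its response
print(question.content["question"])      # the extracted QUESTION section
question.save("benchmark/questions")     # writes benchmark/questions/<case_id>/<case_id>_<hash>.json
```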
benchmark/llm.py ADDED
@@ -0,0 +1,42 @@
+ import openai
+ from typing import List
+
+
+ def get_llm_response(
+     client: openai.OpenAI,
+     prompt: str,
+     system_prompt: str = "You are a helpful assistant.",
+     model: str = "gpt-4o-mini",
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+     max_tokens: int = 500,
+ ) -> str:
+     """
+     Get response from OpenAI language model.
+
+     Args:
+         client (openai.OpenAI): OpenAI client
+         prompt (str): The user prompt/question to send to the model
+         system_prompt (str, optional): System prompt to set model behavior.
+         model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
+         temperature (float, optional): Controls randomness in responses. Defaults to 0.7.
+         top_p (float, optional): Controls diversity via nucleus sampling. Defaults to 0.95.
+         max_tokens (int, optional): Max tokens in model response. Defaults to 500.
+
+     Returns:
+         str: The model's response text
+     """
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": prompt},
+     ]
+
+     response = client.chat.completions.create(
+         model=model,
+         messages=messages,
+         temperature=temperature,
+         top_p=top_p,
+         max_tokens=max_tokens,
+     )
+
+     return response.choices[0].message.content
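A minimal usage sketch of the helper above (the prompt is illustrative; `OPENAI_API_KEY` must be set for the client to authenticate):

```python
import openai

from benchmark.llm import get_llm_response

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
reply = get_llm_response(
    client=client,
    prompt="List three common causes of a unilateral pleural effusion.",
    model="gpt-4o-mini",
    max_tokens=200,
)
print(reply)
```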
benchmark/utils.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import json
+ from typing import Dict, List
+
+
+ def load_eurorad_dataset(
+     dataset_path: str,
+     section: str = "any",
+     as_dict: bool = False,
+     filter_by_caption: List[str] = [
+         "xray",
+         "x-ray",
+         "x ray",
+         "ray",
+         "xr",
+         "radiograph",
+         "radiogram",
+         "plain film",
+     ],
+ ) -> List[Dict] | Dict[str, Dict]:
+     """
+     Load a dataset from a JSON file.
+
+     Args:
+         dataset_path (str): Path to the JSON dataset file.
+         section (str, optional): Section of the dataset to load. Defaults to "any".
+         as_dict (bool, optional): Whether to return data as dict. Defaults to False.
+         filter_by_caption (List[str], optional): List of strings to filter cases by caption content.
+             Defaults to common X-ray related terms.
+
+     Returns:
+         List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.
+
+     Raises:
+         FileNotFoundError: If dataset_path does not exist
+         json.JSONDecodeError: If file is not valid JSON
+     """
+
+     with open(dataset_path, "r", encoding="utf-8") as file:
+         data = json.load(file)
+
+     if filter_by_caption:
+         filtered_data = {}
+         for case_id, case in data.items():
+             if any(
+                 any(x in subfig["caption"].lower() for x in filter_by_caption)
+                 for figure in case["figures"]
+                 for subfig in figure["subfigures"]
+             ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
+                 filtered_data[case_id] = case
+         data = filtered_data
+
+     if section != "any":
+         section = section.strip().lower()
+         if not as_dict:
+             data = [
+                 item for item in data.values() if item.get("section", "").strip().lower() == section
+             ]
+         else:
+             data = {
+                 k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
+             }
+
+     elif not as_dict:
+         data = list(data.values())
+
+     return data
+
+
+ def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
+     """
+     Save a dataset to a JSON file.
+
+     Args:
+         dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
+         dataset_path (str): Path where the JSON dataset file will be saved.
+     """
+     with open(dataset_path, "w", encoding="utf-8") as file:
+         json.dump(dataset, file)
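A short sketch of how these two helpers fit together (file paths are illustrative):

```python
from benchmark.utils import load_eurorad_dataset, save_dataset

# Keep only "Chest Imaging" cases whose figure captions or image findings mention an X-ray.
cases = load_eurorad_dataset(
    "data/eurorad_metadata.json",  # illustrative path
    section="Chest Imaging",
    as_dict=True,
)
print(f"{len(cases)} chest X-ray cases loaded")

# Persist the filtered subset for later question generation.
save_dataset(cases, "data/eurorad_chest_xray_cases.json")  # illustrative output path
```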
data/eurorad_metadata.json ADDED
The diff for this file is too large to render.
 
data/figures.py ADDED
@@ -0,0 +1,74 @@
+ import json
+ import os
+ from pathlib import Path
+ import requests
+ from tqdm import tqdm
+
+
+ def download_eurorad_figures(metadata_path: str, output_dir: str) -> None:
+     """
+     Download figures from the Eurorad dataset and save them organized by case_id.
+
+     Args:
+         metadata_path: Path to the eurorad_metadata.json file
+         output_dir: Base directory where figures will be saved
+
+     The figures will be saved as:
+         {output_dir}/{case_id}/{figure_number}.jpg
+     Example:
+         figures/189/figure_1a.jpg
+     """
+     # Create output directory if it doesn't exist
+     output_path = Path(output_dir)
+     output_path.mkdir(exist_ok=True)
+
+     # Load metadata
+     with open(metadata_path) as f:
+         metadata = json.load(f)
+
+     # Iterate through all cases with progress bar
+     for case_id in tqdm(metadata, desc="Downloading cases", unit="case"):
+         case = metadata[case_id]
+         case_dir = output_path / str(case["case_id"])
+         case_dir.mkdir(exist_ok=True)
+
+         # Process all figures and their subfigures
+         for figure in case["figures"]:
+             for subfig in figure["subfigures"]:
+
+                 # Remove leading and trailing whitespace and convert to lowercase
+                 subfig_name = f"{subfig['number'].strip().replace(' ', '_').lower()}.jpg"
+                 subfig_path = Path(case_dir) / subfig_name
+
+                 save_figure(
+                     url=subfig["url"],
+                     output_path=subfig_path,
+                 )
+
+
+ def save_figure(url: str, output_path: Path) -> None:
+     """
+     Download and save a single figure.
+
+     Args:
+         url: URL of the figure to download
+         output_path: Path where the figure should be saved
+     """
+     if output_path.exists():
+         return
+
+     try:
+         response = requests.get(url, timeout=10)
+         response.raise_for_status()
+         with open(output_path, "wb") as f:
+             f.write(response.content)
+     except Exception as e:
+         print(f"Error downloading {url}: {e}")
+
+
+ if __name__ == "__main__":
+     root = os.path.dirname(os.path.abspath(__file__))
+     download_eurorad_figures(
+         metadata_path=os.path.join(root, "eurorad_metadata.json"),
+         output_dir=os.path.join(root, "figures"),
+     )
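The downloader can also be called from other code rather than run as a script; a minimal sketch (paths are illustrative, and it assumes the `data/` directory is importable from the project root):

```python
from data.figures import download_eurorad_figures  # import path assumes repo root on sys.path

# Saves each subfigure to <output_dir>/<case_id>/<figure_number>.jpg and
# skips files that already exist, so the call is safe to re-run.
download_eurorad_figures(
    metadata_path="data/eurorad_metadata.json",
    output_dir="data/figures",
)
```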
data/get_cases.py ADDED
@@ -0,0 +1,51 @@
+ import requests
+ from bs4 import BeautifulSoup
+ import time
+ import json
+ from tqdm import tqdm
+
+
+ def get_response(url):
+     headers = {
+         "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
+     }
+     return requests.get(url, headers=headers)
+
+ def get_case_numbers_from_page(page):
+     url = f"https://www.eurorad.org/advanced-search?sort_by=published_at&sort_order=ASC&page={page}&filter%5B0%5D=section%3A40"
+
+     # Request the page directly (going through a proxy tends to trigger the site's bot protection)
+     response = get_response(url)
+     print(response.text)
+
+     soup = BeautifulSoup(response.text, "html.parser")
+     spans = soup.find_all("span", class_="case__number small")
+
+     # Remove '#' from the span text and strip extra whitespace
+     numbers = [span.text.strip().replace("#", "").strip() for span in spans]
+     return numbers
+
+
+ def main():
+     total_pages = 107  # Pages 0 through 106
+     all_numbers = []
+
+     for page in tqdm(range(total_pages)):
+         numbers = get_case_numbers_from_page(page)
+         all_numbers.extend(numbers)
+
+         if page != total_pages - 1 and len(numbers) != 9:
+             print(f"Warning: Page {page} returned {len(numbers)} cases instead of 9")
+
+         # Be kind to the server – avoid hitting it too fast
+         time.sleep(1)
+
+     with open('case_numbers.json', 'w') as f:
+         json.dump(all_numbers, f)
+
+     print(f"Saved {len(all_numbers)} case numbers to case_numbers.json")
+
+
+ if __name__ == "__main__":
+     main()
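A small sketch of using the scraper for a couple of result pages without writing the JSON file (page numbers are illustrative, and the import assumes the repo root is on the import path):

```python
from data.get_cases import get_case_numbers_from_page  # assumes repo root on sys.path

# Each Eurorad advanced-search page in the chest section lists up to 9 cases.
for page in range(2):
    case_numbers = get_case_numbers_from_page(page)
    print(f"page {page}: {case_numbers}")
```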
data/stats/age_distribution.png ADDED

Git LFS Details

  • SHA256: 0409ec03f305ccd8fdee1c097dede52b7cf0f84f05b99fbd18727fb8e67238ad
  • Pointer size: 132 Bytes
  • Size of remote file: 2.71 MB
data/stats/area_of_interest_distribution.png ADDED

Git LFS Details

  • SHA256: 2a80d9aa1bf9b025b8aaa2b1c0d4807e36afc175747ba71b500ef1ceaf542081
  • Pointer size: 132 Bytes
  • Size of remote file: 2.91 MB
data/stats/gender_distribution.png ADDED

Git LFS Details

  • SHA256: a4cfd37f71fc91a848d990f6e2ff6c9611f555e09e435885680ffbbb85458838
  • Pointer size: 132 Bytes
  • Size of remote file: 1.96 MB
demo/chest/LIDC.dcm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11d25b1d34dff083057de994fef7da3dcef75bd7b334823ec6cb9c16b3ba0338
+ size 17071804
demo/chest/Pseudo.dcm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b35ae460fb5f62eb6d6c4c5117f6683100ad92c5fb6ba1a3c36da39703c4652
+ size 7535280
demo/chest/RIDER.dcm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc15f7afa5434991e1359f596433870ad611b42227db87d484d31976545de7fd
+ size 7534066
demo/chest/TCGAA.dcm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e8137290ac823d3da3c00ce3e18120123eaa62a786934c7afc52a989b0b64cf
+ size 7535274
demo/chest/__init__.py ADDED
File without changes
demo/chest/effusion1.png ADDED

Git LFS Details

  • SHA256: ba5af84601f11ab44142e5dfaf578b49d76de45633470e606c7edc4b1c77ba07
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB
demo/chest/normal1.jpg ADDED

Git LFS Details

  • SHA256: 785419c9ec7d0235fe056c254cd3be785d6052b558ae32c595ad558be57062dd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
demo/chest/normal2.jpg ADDED

Git LFS Details

  • SHA256: cecf56a8b90e9ccb3c54641beb40652e72a2bdcb311efc696a331fe4de7efbf0
  • Pointer size: 131 Bytes
  • Size of remote file: 798 kB
demo/chest/normal3.jpg ADDED

Git LFS Details

  • SHA256: 3f721831529e9604c99e3bd999483321e0e0648c5987351570fe45e48c190948
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
demo/chest/normal4.jpg ADDED

Git LFS Details

  • SHA256: ed84d75328f1eb80c6554e3c6ba8dcd573e733914b2934bfce399ae6e8f38ec4
  • Pointer size: 131 Bytes
  • Size of remote file: 566 kB
demo/chest/normal5.jpg ADDED

Git LFS Details

  • SHA256: 9e7c4251d9b300f9256c6fe72ef1c3167beeecca747e6b9c8b80ee3260ea9ac8
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
demo/chest/normal6.jpg ADDED

Git LFS Details

  • SHA256: 4b47dd1665b828ab3610d1a60ec08c37083579f834b2dd5891570c8a105825a5
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
demo/chest/pneumonia1.jpg ADDED

Git LFS Details

  • SHA256: 92d1c1e3334b1dd8f1d5eea56681adeb38dc5b7c8dd17536fb0e47fc701c5ae1
  • Pointer size: 130 Bytes
  • Size of remote file: 35.6 kB
demo/chest/pneumonia2.jpg ADDED

Git LFS Details

  • SHA256: eb17ab7b6f63d0f0078c378a1cc1debbffffd6331cb2723f7169410a738287fa
  • Pointer size: 130 Bytes
  • Size of remote file: 56.7 kB
demo/chest/pneumonia3.jpg ADDED

Git LFS Details

  • SHA256: dd41c787362b60e03037b6658f3824068ea268d83915904efb09aae95e10bd72
  • Pointer size: 130 Bytes
  • Size of remote file: 81.7 kB
demo/chest/pneumonia4.jpg ADDED

Git LFS Details

  • SHA256: 8223cf57d33d1528782f83b62d3d62d2f41fe9bf34053553a86e609c2b2ba94b
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
demo/chest/pneumonia5.jpg ADDED

Git LFS Details

  • SHA256: 59bee7e6a36e7629a320e1c74d65dd0683c8310dbbb2489f5d32054419a3a667
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
experiments/README.md ADDED
@@ -0,0 +1,63 @@
+ # Experiments
+ Below are the instructions for running experiments using our novel ChestAgentBench and the previous SoTA CheXbench. ChestAgentBench is a comprehensive benchmark containing over 2,500 complex medical queries across 8 diverse categories.
+
+ ### ChestAgentBench
+
+ To run GPT-4o on ChestAgentBench, enter the `experiments` directory and run the following script:
+ ```bash
+ python benchmark_gpt4o.py
+ ```
+
+ To run Llama 3.2 Vision 90B on ChestAgentBench, run the following:
+ ```bash
+ python benchmark_llama.py
+ ```
+
+ To run CheXagent on ChestAgentBench, run the following:
+ ```bash
+ python benchmark_chexagent.py
+ ```
+
+ To run LLaVA-Med on ChestAgentBench, you'll need to clone their repo, follow their setup instructions, and then copy the following script into it:
+ ```bash
+ mv benchmark_llavamed.py ~/LLaVA-Med/llava/serve
+ python -m llava.serve.benchmark_llavamed --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
+ ```
+
+ If you want to inspect the logs, you can run the following. It will select the most recent log file by default.
+ ```bash
+ python inspect_logs.py [optional: log-file] -n [num-logs]
+ ```
+
+ Finally, to analyze results, run:
+ ```bash
+ python analyze_axes.py results/[logfile].json ../benchmark/questions/ --model [gpt4|llama|chexagent|llava-med] --max-questions [optional:int]
+ ```
+
+ ### CheXbench
+
+ To run the models on CheXbench, you can use `chexbench_gpt4.py` as a reference. You'll need to download the dataset files locally and upload them for each request. Rad-ReStruct and Open-I use the same set of images, so you can download the `NLMCXR.zip` file just once and copy the images to both directories.
+
+ You can find the datasets here:
+ 1. [SLAKE: A Semantically-Labeled Knowledge-Enhanced Dataset for Medical Visual Question Answering](https://www.med-vqa.com/slake/). Save this to `MedMAX/data/slake`.
+ 2. [Rad-ReStruct: A Novel VQA Benchmark and Method for Structured Radiology Reporting](https://github.com/ChantalMP/Rad-ReStruct). Save the images to `MedMAX/data/rad-restruct/images`.
+ 3. [Open-I Service of the National Library of Medicine](https://openi.nlm.nih.gov/faq). Save the images to `MedMAX/data/openi/images`.
+
+ Once you're finished, you'll want to fix the paths in the `chexbench.json` file to your local paths using the `MedMax/data/fix_chexbench.py` script.
+
+ ### Compare Runs
+ Analyze a single file based on overall accuracy and along different axes:
+ ```
+ python compare_runs.py results/medmax.json
+ ```
+
+ For a direct evaluation comparing **2** models on the exact same questions:
+ ```
+ python compare_runs.py results/medmax.json results/gpt4o.json
+ ```
+
+ For a direct evaluation comparing **ALL** models on the exact same questions (add as many model log files as you want):
+ ```
+ python compare_runs.py results/medmax.json results/gpt4o.json results/llama.json results/chexagent.json results/llavamed.json
+ ```
experiments/analyze_axes.py ADDED
@@ -0,0 +1,385 @@
+ from typing import Dict, List, Optional, Tuple, Union, Any
+ import json
+ import os
+ import sys
+ import argparse
+ from collections import defaultdict
+ from tqdm import tqdm
+
+ QUESTION_TYPES = {
+     "Detailed Finding Analysis": ["detection", "localization", "characterization"],
+     "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
+     "Spatial Understanding": ["localization", "comparison", "relationship"],
+     "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
+     "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
+ }
+
+
+ def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
+     """
+     Extract just the letter from various answer formats.
+
+     Args:
+         answer: The answer text to extract letter from
+
+     Returns:
+         Optional[str]: The extracted letter in uppercase, or None if no letter found
+     """
+     if not answer:
+         return None
+
+     # Convert to string and clean
+     answer = str(answer).strip()
+
+     # If it's just a single letter, return it
+     if len(answer) == 1 and answer.isalpha():
+         return answer.upper()
+
+     # Try to extract letter from format like "A)" or "A."
+     if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
+         return answer[0].upper()
+
+     # Try to extract letter from format like "A) Some text"
+     if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
+         return answer[0].upper()
+
+     return None
+
+
+ def analyze_gpt4_results(
+     results_file: str, max_questions: Optional[int] = None
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
+     """
+     Analyze results in GPT-4 format.
+
+     Args:
+         results_file: Path to results file
+         max_questions: Maximum number of questions to analyze
+
+     Returns:
+         Tuple containing:
+             - overall_accuracy (float)
+             - category_accuracies (Dict)
+             - question_type_stats (Dict)
+             - correct_ids (List[str])
+             - incorrect_ids (List[str])
+     """
+     category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
+     all_questions = 0
+     all_correct = 0
+     correct_ids = []
+     incorrect_ids = []
+
+     with open(results_file, "r") as f:
+         lines = f.readlines()
+
+     processed_questions = 0
+
+     for line in tqdm(lines, desc="Analyzing Benchmark Results"):
+         # Check if we've hit the maximum questions
+         if max_questions is not None and processed_questions >= max_questions:
+             break
+         if line.startswith("HTTP Request:"):
+             continue
+
+         try:
+             entry = json.loads(line)
+             metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
+             question_id = entry.get("question_id")
+
+             model_letter = extract_answer_letter(entry.get("model_answer"))
+             correct_letter = extract_answer_letter(entry.get("correct_answer"))
+
+             if model_letter and correct_letter:
+                 all_questions += 1
+                 processed_questions += 1
+                 is_correct = model_letter == correct_letter
+
+                 if is_correct:
+                     all_correct += 1
+                     correct_ids.append(question_id)
+                 else:
+                     incorrect_ids.append(question_id)
+
+                 for category in metadata.get("categories", []):
+                     category_performance[category]["total"] += 1
+                     if is_correct:
+                         category_performance[category]["correct"] += 1
+
+         except json.JSONDecodeError:
+             continue
+
+     return process_results(
+         category_performance, all_questions, all_correct, correct_ids, incorrect_ids
+     )
+
+
+ def analyze_llama_results(
+     results_file: str, max_questions: Optional[int] = None
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
+     """
+     Analyze results in Llama format.
+
+     Args:
+         results_file: Path to results file
+         max_questions: Maximum number of questions to analyze
+
+     Returns:
+         Tuple containing:
+             - overall_accuracy (float)
+             - category_accuracies (Dict)
+             - question_type_stats (Dict)
+             - correct_ids (List[str])
+             - incorrect_ids (List[str])
+     """
+     category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
+     all_questions = 0
+     all_correct = 0
+     correct_ids = []
+     incorrect_ids = []
+
+     with open(results_file, "r") as f:
+         lines = f.readlines()
+
+     # If max_questions is set, limit the number of lines processed
+     if max_questions is not None:
+         lines = lines[:max_questions]
+
+     for line in tqdm(lines, desc="Analyzing Benchmark Results"):
+         if line.startswith("HTTP Request:"):
+             continue
+
+         try:
+             entry = json.loads(line)
+             metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
+             question_id = entry.get("question_id")
+
+             model_letter = extract_answer_letter(entry.get("model_answer"))
+             correct_letter = extract_answer_letter(entry.get("correct_answer"))
+
+             if model_letter and correct_letter:
+                 all_questions += 1
+                 is_correct = model_letter == correct_letter
+
+                 if is_correct:
+                     all_correct += 1
+                     correct_ids.append(question_id)
+                 else:
+                     incorrect_ids.append(question_id)
+
+                 for category in metadata.get("categories", []):
+                     category_performance[category]["total"] += 1
+                     if is_correct:
+                         category_performance[category]["correct"] += 1
+
+         except json.JSONDecodeError:
+             continue
+
+     return process_results(
+         category_performance, all_questions, all_correct, correct_ids, incorrect_ids
+     )
+
+
+ def analyze_chexagent_results(
+     results_file: str, max_questions: Optional[int] = None
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
+     """
+     Analyze results in CheXagent format.
+
+     Args:
+         results_file: Path to results file
+         max_questions: Maximum number of questions to analyze
+
+     Returns:
+         Tuple containing:
+             - overall_accuracy (float)
+             - category_accuracies (Dict)
+             - question_type_stats (Dict)
+             - correct_ids (List[str])
+             - incorrect_ids (List[str])
+     """
+     category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
+     all_questions = 0
+     all_correct = 0
+     correct_ids = []
+     incorrect_ids = []
+
+     with open(results_file, "r") as f:
+         lines = f.readlines()
+
+     # If max_questions is set, limit the number of lines processed
+     if max_questions is not None:
+         lines = lines[:max_questions]
+
+     for line in tqdm(lines, desc="Analyzing Benchmark Results"):
+         try:
+             entry = json.loads(line)
+             metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
+             question_id = entry.get("question_id")
+
+             model_letter = extract_answer_letter(entry.get("model_answer"))
+             correct_letter = extract_answer_letter(entry.get("correct_answer"))
+
+             if model_letter and correct_letter:
+                 all_questions += 1
+                 is_correct = model_letter == correct_letter
+
+                 if is_correct:
+                     all_correct += 1
+                     correct_ids.append(question_id)
+                 else:
+                     incorrect_ids.append(question_id)
+
+                 for category in metadata.get("categories", []):
+                     category_performance[category]["total"] += 1
+                     if is_correct:
+                         category_performance[category]["correct"] += 1
237
+
238
+ except json.JSONDecodeError:
239
+ continue
240
+
241
+ return process_results(
242
+ category_performance, all_questions, all_correct, correct_ids, incorrect_ids
243
+ )
244
+
245
+
246
+ def process_results(
247
+ category_performance: Dict,
248
+ all_questions: int,
249
+ all_correct: int,
250
+ correct_ids: Optional[List[str]] = None,
251
+ incorrect_ids: Optional[List[str]] = None,
252
+ ) -> Tuple[float, Dict, Dict, List[str], List[str]]:
253
+ """
254
+ Process raw results into final statistics.
255
+
256
+ Args:
257
+ category_performance: Dict containing performance by category
258
+ all_questions: Total number of questions
259
+ all_correct: Total number of correct answers
260
+ correct_ids: List of IDs for correctly answered questions
261
+ incorrect_ids: List of IDs for incorrectly answered questions
262
+
263
+ Returns:
264
+ Tuple containing:
265
+ - overall_accuracy (float)
266
+ - category_accuracies (Dict)
267
+ - question_type_stats (Dict)
268
+ - correct_ids (List[str])
269
+ - incorrect_ids (List[str])
270
+ """
271
+ category_accuracies = {
272
+ category: {
273
+ "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
274
+ "total": stats["total"],
275
+ "correct": stats["correct"],
276
+ }
277
+ for category, stats in category_performance.items()
278
+ }
279
+
280
+ question_type_stats = {}
281
+ for qtype, categories in QUESTION_TYPES.items():
282
+ total = sum(
283
+ category_performance[cat]["total"] for cat in categories if cat in category_performance
284
+ )
285
+ correct = sum(
286
+ category_performance[cat]["correct"]
287
+ for cat in categories
288
+ if cat in category_performance
289
+ )
290
+
291
+ question_type_stats[qtype] = {
292
+ "accuracy": (correct / total * 100) if total > 0 else 0,
293
+ "total": total,
294
+ "correct": correct,
295
+ }
296
+
297
+ overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0
298
+
299
+ return (
300
+ overall_accuracy,
301
+ category_accuracies,
302
+ question_type_stats,
303
+ correct_ids or [],
304
+ incorrect_ids or [],
305
+ )
306
+
307
+
308
+ def print_analysis(
309
+ overall_accuracy: float,
310
+ category_accuracies: Dict,
311
+ question_type_stats: Dict,
312
+ correct_ids: List[str],
313
+ incorrect_ids: List[str],
314
+ model_name: str,
315
+ ) -> None:
316
+ """
317
+ Print analysis results.
318
+
319
+ Args:
320
+ overall_accuracy: Overall accuracy percentage
321
+ category_accuracies: Dict containing accuracy metrics by category
322
+ question_type_stats: Dict containing stats by question type
323
+ correct_ids: List of IDs for correctly answered questions
324
+ incorrect_ids: List of IDs for incorrectly answered questions
325
+ model_name: Name of the model being analyzed
326
+ """
327
+ total_questions = len(correct_ids) + len(incorrect_ids)
328
+ print(
329
+ f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
330
+ )
331
+
332
+ print("\nCategory Performance:")
333
+ sorted_categories = sorted(
334
+ category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
335
+ )
336
+ for category, metrics in sorted_categories:
337
+ print(f"{category}:")
338
+ print(f" Accuracy: {metrics['accuracy']:.2f}%")
339
+ print(f" Total Questions: {metrics['total']}")
340
+ print(f" Correct Questions: {metrics['correct']}")
341
+
342
+ print("\nQuestion Type Performance:")
343
+ sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
344
+ for qtype, metrics in sorted_types:
345
+ print(f"\n{qtype}:")
346
+ print(f" Accuracy: {metrics['accuracy']:.2f}%")
347
+ print(f" Total Questions: {metrics['total']}")
348
+ print(f" Correct Questions: {metrics['correct']}")
349
+ print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")
350
+
351
+ # Save question IDs to JSON
352
+ question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
353
+
354
+ output_filename = f"{model_name}_question_ids.json"
355
+ with open(output_filename, "w") as f:
356
+ json.dump(question_ids, f, indent=2)
357
+
358
+ print(f"\nQuestion IDs have been saved to {output_filename}")
359
+
360
+
361
+ if __name__ == "__main__":
362
+ parser = argparse.ArgumentParser(description="Analyze benchmark results")
363
+ parser.add_argument("results_file", help="Path to results file")
364
+ parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
365
+ parser.add_argument(
366
+ "--model",
367
+ choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
368
+ default="gpt4",
369
+ help="Specify model format (default: gpt4)",
370
+ )
371
+ parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
372
+ args = parser.parse_args()
373
+
374
+ if args.model == "gpt4":
375
+ results = analyze_gpt4_results(args.results_file, args.max_questions)
376
+ elif args.model == "llama":
377
+ results = analyze_llama_results(args.results_file, args.max_questions)
378
+ elif args.model == "chexagent":
379
+ results = analyze_chexagent_results(args.results_file, args.max_questions)
380
+ elif args.model == "medrax":
381
+ results = analyze_gpt4_results(args.results_file, args.max_questions)
382
+ else:
383
+ parser.error(f"Unsupported model: {args.model}")
384
+
385
+ print_analysis(*results, args.model)
experiments/benchmark_chexagent.py ADDED
@@ -0,0 +1,316 @@
1
+ import re
2
+ import json
3
+ import os
4
+ import glob
5
+ import time
6
+ import logging
7
+ from datetime import datetime
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from tqdm import tqdm
12
+
13
+ # Configure model settings
14
+ MODEL_NAME = "StanfordAIMI/CheXagent-2-3b"
15
+ DTYPE = torch.bfloat16
16
+ DEVICE = "cuda"
17
+
18
+ # Configure logging
19
+ log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
20
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
21
+
22
+
23
+ def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
24
+ """Initialize the CheXagent model and tokenizer.
25
+
26
+ Returns:
27
+ tuple containing:
28
+ - AutoModelForCausalLM: The initialized CheXagent model
29
+ - AutoTokenizer: The initialized tokenizer
30
+ """
31
+ print("Loading model and tokenizer...")
32
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ MODEL_NAME, device_map="auto", trust_remote_code=True
35
+ )
36
+ model = model.to(DTYPE)
37
+ model.eval()
38
+ return model, tokenizer
39
+
40
+
41
+ def create_inference_request(
42
+ question_data: dict,
43
+ case_details: dict,
44
+ case_id: str,
45
+ question_id: str,
46
+ model: AutoModelForCausalLM,
47
+ tokenizer: AutoTokenizer,
48
+ ) -> str | None:
49
+ """Create and execute an inference request for the CheXagent model.
50
+
51
+ Args:
52
+ question_data: Dictionary containing question details and metadata
53
+ case_details: Dictionary containing case information and image paths
54
+ case_id: Unique identifier for the medical case
55
+ question_id: Unique identifier for the question
56
+ model: The initialized CheXagent model
57
+ tokenizer: The initialized tokenizer
58
+
59
+ Returns:
60
+ str | None: Single letter answer (A-F) if successful, None if failed
61
+ """
62
+ system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
63
+ Rules:
64
+ 1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
65
+ 2. Do not add periods, explanations, or any other text
66
+ 3. Do not use markdown or formatting
67
+ 4. Do not restate the question
68
+ 5. Do not explain your reasoning
69
+
70
+ Examples of valid responses:
71
+ A
72
+ B
73
+ C
74
+
75
+ Examples of invalid responses:
76
+ "A."
77
+ "Answer: B"
78
+ "C) This shows..."
79
+ "The answer is D"
80
+ """
81
+
82
+ prompt = f"""Given the following medical case:
83
+ Please answer this multiple choice question:
84
+ {question_data['question']}
85
+ Base your answer only on the provided images and case information."""
86
+
87
+ # Parse required figures
88
+ try:
89
+ if isinstance(question_data["figures"], str):
90
+ try:
91
+ required_figures = json.loads(question_data["figures"])
92
+ except json.JSONDecodeError:
93
+ required_figures = [question_data["figures"]]
94
+ elif isinstance(question_data["figures"], list):
95
+ required_figures = question_data["figures"]
96
+ else:
97
+ required_figures = [str(question_data["figures"])]
98
+ except Exception as e:
99
+ print(f"Error parsing figures: {e}")
100
+ required_figures = []
101
+
102
+ required_figures = [
103
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
104
+ ]
105
+
106
+ # Get image paths
107
+ image_paths = []
108
+ for figure in required_figures:
109
+ base_figure_num = "".join(filter(str.isdigit, figure))
110
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
111
+
112
+ matching_figures = [
113
+ case_figure
114
+ for case_figure in case_details.get("figures", [])
115
+ if case_figure["number"] == f"Figure {base_figure_num}"
116
+ ]
117
+
118
+ for case_figure in matching_figures:
119
+ subfigures = []
120
+ if figure_letter:
121
+ subfigures = [
122
+ subfig
123
+ for subfig in case_figure.get("subfigures", [])
124
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
125
+ or subfig.get("label", "").lower() == figure_letter.lower()
126
+ ]
127
+ else:
128
+ subfigures = case_figure.get("subfigures", [])
129
+
130
+ for subfig in subfigures:
131
+ if "local_path" in subfig:
132
+ image_paths.append("medrax/data/" + subfig["local_path"])
133
+
134
+ if not image_paths:
135
+ print(f"No local images found for case {case_id}, question {question_id}")
136
+ return None
137
+
138
+ try:
139
+ start_time = time.time()
140
+
141
+ # Prepare input for the model
142
+ query = tokenizer.from_list_format(
143
+ [*[{"image": path} for path in image_paths], {"text": prompt}]
144
+ )
145
+ conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}]
146
+ input_ids = tokenizer.apply_chat_template(
147
+ conv, add_generation_prompt=True, return_tensors="pt"
148
+ )
149
+
150
+ # Generate response
151
+ with torch.no_grad():
152
+ output = model.generate(
153
+ input_ids.to(DEVICE),
154
+ do_sample=False,
155
+ num_beams=1,
156
+ temperature=1.0,
157
+ top_p=1.0,
158
+ use_cache=True,
159
+ max_new_tokens=512,
160
+ )[0]
161
+
162
+ response = tokenizer.decode(output[input_ids.size(1) : -1])
163
+ duration = time.time() - start_time
164
+
165
+ # Clean response
166
+ clean_answer = validate_answer(response)
167
+
168
+ # Log response
169
+ log_entry = {
170
+ "case_id": case_id,
171
+ "question_id": question_id,
172
+ "timestamp": datetime.now().isoformat(),
173
+ "model": MODEL_NAME,
174
+ "duration": round(duration, 2),
175
+ "model_answer": clean_answer,
176
+ "correct_answer": question_data["answer"],
177
+ "input": {
178
+ "question_data": {
179
+ "question": question_data["question"],
180
+ "explanation": question_data["explanation"],
181
+ "metadata": question_data.get("metadata", {}),
182
+ "figures": question_data["figures"],
183
+ },
184
+ "image_paths": image_paths,
185
+ },
186
+ }
187
+ logging.info(json.dumps(log_entry))
188
+ return clean_answer
189
+
190
+ except Exception as e:
191
+ print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
192
+ log_entry = {
193
+ "case_id": case_id,
194
+ "question_id": question_id,
195
+ "timestamp": datetime.now().isoformat(),
196
+ "model": MODEL_NAME,
197
+ "status": "error",
198
+ "error": str(e),
199
+ "input": {
200
+ "question_data": {
201
+ "question": question_data["question"],
202
+ "explanation": question_data["explanation"],
203
+ "metadata": question_data.get("metadata", {}),
204
+ "figures": question_data["figures"],
205
+ },
206
+ "image_paths": image_paths,
207
+ },
208
+ }
209
+ logging.info(json.dumps(log_entry))
210
+ return None
211
+
212
+
213
+ def validate_answer(response_text: str) -> str | None:
214
+ """Enforce strict single-letter response format.
215
+
216
+ Args:
217
+ response_text: Raw response text from the model
218
+
219
+ Returns:
220
+ str | None: Single uppercase letter (A-F) if valid, None if invalid
221
+ """
222
+ if not response_text:
223
+ return None
224
+
225
+ # Remove all whitespace and convert to uppercase
226
+ cleaned = response_text.strip().upper()
227
+
228
+ # Check if it's exactly one valid letter
229
+ if len(cleaned) == 1 and cleaned in "ABCDEF":
230
+ return cleaned
231
+
232
+ # If not, try to extract just the letter
233
+ match = re.search(r"([A-F])", cleaned)
234
+ return match.group(1) if match else None
235
+
236
+
237
+ def load_benchmark_questions(case_id: str) -> list[str]:
238
+ """Find all question files for a given case ID.
239
+
240
+ Args:
241
+ case_id: Unique identifier for the medical case
242
+
243
+ Returns:
244
+ list[str]: List of paths to question JSON files
245
+ """
246
+ benchmark_dir = "../benchmark/questions"
247
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
248
+
249
+
250
+ def count_total_questions() -> tuple[int, int]:
251
+ """Count total number of cases and questions in benchmark.
252
+
253
+ Returns:
254
+ tuple containing:
255
+ - int: Total number of cases
256
+ - int: Total number of questions
257
+ """
258
+ total_cases = len(glob.glob("../benchmark/questions/*"))
259
+ total_questions = sum(
260
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
261
+ for case_id in os.listdir("../benchmark/questions")
262
+ )
263
+ return total_cases, total_questions
264
+
265
+
266
+ def main():
267
+ # Load the cases with local paths
268
+ with open("medrax/data/updated_cases.json", "r") as file:
269
+ data = json.load(file)
270
+
271
+ # Initialize model and tokenizer
272
+ model, tokenizer = initialize_model()
273
+
274
+ total_cases, total_questions = count_total_questions()
275
+ cases_processed = 0
276
+ questions_processed = 0
277
+ skipped_questions = 0
278
+
279
+ print(f"\nBeginning inference with {MODEL_NAME}")
280
+ print(f"Found {total_cases} cases with {total_questions} total questions")
281
+
282
+ # Process each case with progress bar
283
+ for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
284
+ question_files = load_benchmark_questions(case_id)
285
+ if not question_files:
286
+ continue
287
+
288
+ cases_processed += 1
289
+ for question_file in tqdm(
290
+ question_files, desc=f"Processing questions for case {case_id}", leave=False
291
+ ):
292
+ with open(question_file, "r") as file:
293
+ question_data = json.load(file)
294
+ question_id = os.path.basename(question_file).split(".")[0]
295
+
296
+ questions_processed += 1
297
+ answer = create_inference_request(
298
+ question_data, case_details, case_id, question_id, model, tokenizer
299
+ )
300
+
301
+ if answer is None:
302
+ skipped_questions += 1
303
+ continue
304
+
305
+ print(f"\nCase {case_id}, Question {question_id}")
306
+ print(f"Model Answer: {answer}")
307
+ print(f"Correct Answer: {question_data['answer']}")
308
+
309
+ print(f"\nInference Summary:")
310
+ print(f"Total Cases Processed: {cases_processed}")
311
+ print(f"Total Questions Processed: {questions_processed}")
312
+ print(f"Total Questions Skipped: {skipped_questions}")
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
experiments/benchmark_gpt4o.py ADDED
@@ -0,0 +1,331 @@
1
+ import json
2
+ import openai
3
+ import os
4
+ import glob
5
+ import time
6
+ import logging
7
+ from datetime import datetime
8
+ from tenacity import retry, wait_exponential, stop_after_attempt
9
+
10
+ model_name = "chatgpt-4o-latest"
11
+ temperature = 0.2
12
+ log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
13
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
14
+
15
+
16
+ def calculate_cost(
17
+ prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
18
+ ) -> float:
19
+ """Calculate the cost of API usage based on token counts.
20
+
21
+ Args:
22
+ prompt_tokens: Number of tokens in the prompt
23
+ completion_tokens: Number of tokens in the completion
24
+ model: Model name to use for pricing, defaults to chatgpt-4o-latest
25
+
26
+ Returns:
27
+ float: Cost in USD
28
+ """
29
+ pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
30
+ rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
31
+ return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000
32
+
33
+
34
+ @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
35
+ def create_multimodal_request(
36
+ question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
37
+ ) -> openai.types.chat.ChatCompletion:
38
+ """Create and send a multimodal request to the OpenAI API.
39
+
40
+ Args:
41
+ question_data: Dictionary containing question details and figures
42
+ case_details: Dictionary containing case information and figures
43
+ case_id: Identifier for the medical case
44
+ question_id: Identifier for the specific question
45
+ client: OpenAI client instance
46
+
47
+ Returns:
48
+ openai.types.chat.ChatCompletion: API response object, or None if request fails
49
+ """
50
+ prompt = f"""Given the following medical case:
51
+ Please answer this multiple choice question:
52
+ {question_data['question']}
53
+ Base your answer only on the provided images and case information."""
54
+
55
+ content = [{"type": "text", "text": prompt}]
56
+
57
+ # Parse required figures
58
+ try:
59
+ # Try multiple ways of parsing figures
60
+ if isinstance(question_data["figures"], str):
61
+ try:
62
+ required_figures = json.loads(question_data["figures"])
63
+ except json.JSONDecodeError:
64
+ required_figures = [question_data["figures"]]
65
+ elif isinstance(question_data["figures"], list):
66
+ required_figures = question_data["figures"]
67
+ else:
68
+ required_figures = [str(question_data["figures"])]
69
+ except Exception as e:
70
+ print(f"Error parsing figures: {e}")
71
+ required_figures = []
72
+
73
+ # Ensure each figure starts with "Figure "
74
+ required_figures = [
75
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
76
+ ]
77
+
78
+ subfigures = []
79
+ for figure in required_figures:
80
+ # Handle both regular figures and those with letter suffixes
81
+ base_figure_num = "".join(filter(str.isdigit, figure))
82
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
83
+
84
+ # Find matching figures in case details
85
+ matching_figures = [
86
+ case_figure
87
+ for case_figure in case_details.get("figures", [])
88
+ if case_figure["number"] == f"Figure {base_figure_num}"
89
+ ]
90
+
91
+ if not matching_figures:
92
+ print(f"No matching figure found for {figure} in case {case_id}")
93
+ continue
94
+
95
+ for case_figure in matching_figures:
96
+ # If a specific letter is specified, filter subfigures
97
+ if figure_letter:
98
+ matching_subfigures = [
99
+ subfig
100
+ for subfig in case_figure.get("subfigures", [])
101
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
102
+ or subfig.get("label", "").lower() == figure_letter.lower()
103
+ ]
104
+ subfigures.extend(matching_subfigures)
105
+ else:
106
+ # If no letter specified, add all subfigures
107
+ subfigures.extend(case_figure.get("subfigures", []))
108
+
109
+ # Add images to content
110
+ for subfig in subfigures:
111
+ if "url" in subfig:
112
+ content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
113
+ else:
114
+ print(f"Subfigure missing URL: {subfig}")
115
+
116
+ # If no images were found, fall through to the logged skip check below so the
117
+ # missing-image case is recorded in the log file before returning None.
118
+
119
+
120
+
121
+ messages = [
122
+ {
123
+ "role": "system",
124
+ "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
125
+ },
126
+ {"role": "user", "content": content},
127
+ ]
128
+
129
+ if len(content) == 1: # Only the text prompt exists
130
+ print(f"No images found for case {case_id}, question {question_id}")
131
+ log_entry = {
132
+ "case_id": case_id,
133
+ "question_id": question_id,
134
+ "timestamp": datetime.now().isoformat(),
135
+ "model": model_name,
136
+ "temperature": temperature,
137
+ "status": "skipped",
138
+ "reason": "no_images",
139
+ "cost": 0,
140
+ "input": {
141
+ "messages": messages,
142
+ "question_data": {
143
+ "question": question_data["question"],
144
+ "explanation": question_data["explanation"],
145
+ "metadata": question_data.get("metadata", {}),
146
+ "figures": question_data["figures"],
147
+ },
148
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
149
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
150
+ },
151
+ }
152
+ logging.info(json.dumps(log_entry))
153
+ return None
154
+
155
+ try:
156
+ start_time = time.time()
157
+
158
+ response = client.chat.completions.create(
159
+ model=model_name, messages=messages, max_tokens=50, temperature=temperature
160
+ )
161
+ duration = time.time() - start_time
162
+
163
+ log_entry = {
164
+ "case_id": case_id,
165
+ "question_id": question_id,
166
+ "timestamp": datetime.now().isoformat(),
167
+ "model": model_name,
168
+ "temperature": temperature,
169
+ "duration": round(duration, 2),
170
+ "usage": {
171
+ "prompt_tokens": response.usage.prompt_tokens,
172
+ "completion_tokens": response.usage.completion_tokens,
173
+ "total_tokens": response.usage.total_tokens,
174
+ },
175
+ "cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
176
+ "model_answer": response.choices[0].message.content,
177
+ "correct_answer": question_data["answer"],
178
+ "input": {
179
+ "messages": messages,
180
+ "question_data": {
181
+ "question": question_data["question"],
182
+ "explanation": question_data["explanation"],
183
+ "metadata": question_data.get("metadata", {}),
184
+ "figures": question_data["figures"],
185
+ },
186
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
187
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
188
+ },
189
+ }
190
+ logging.info(json.dumps(log_entry))
191
+ return response
192
+
193
+ except openai.RateLimitError:
194
+ log_entry = {
195
+ "case_id": case_id,
196
+ "question_id": question_id,
197
+ "timestamp": datetime.now().isoformat(),
198
+ "model": model_name,
199
+ "temperature": temperature,
200
+ "status": "error",
201
+ "reason": "rate_limit",
202
+ "cost": 0,
203
+ "input": {
204
+ "messages": messages,
205
+ "question_data": {
206
+ "question": question_data["question"],
207
+ "explanation": question_data["explanation"],
208
+ "metadata": question_data.get("metadata", {}),
209
+ "figures": question_data["figures"],
210
+ },
211
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
212
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
213
+ },
214
+ }
215
+ logging.info(json.dumps(log_entry))
216
+ print(
217
+ f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
218
+ flush=True,
219
+ )
220
+ time.sleep(20)
221
+ raise
222
+ except Exception as e:
223
+ log_entry = {
224
+ "case_id": case_id,
225
+ "question_id": question_id,
226
+ "timestamp": datetime.now().isoformat(),
227
+ "model": model_name,
228
+ "temperature": temperature,
229
+ "status": "error",
230
+ "error": str(e),
231
+ "cost": 0,
232
+ "input": {
233
+ "messages": messages,
234
+ "question_data": {
235
+ "question": question_data["question"],
236
+ "explanation": question_data["explanation"],
237
+ "metadata": question_data.get("metadata", {}),
238
+ "figures": question_data["figures"],
239
+ },
240
+ "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
241
+ "image_captions": [subfig.get("caption", "") for subfig in subfigures],
242
+ },
243
+ }
244
+ logging.info(json.dumps(log_entry))
245
+ print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
246
+ raise
247
+
248
+
249
+ def load_benchmark_questions(case_id: str) -> list:
250
+ """Load benchmark questions for a given case.
251
+
252
+ Args:
253
+ case_id: Identifier for the medical case
254
+
255
+ Returns:
256
+ list: List of paths to question files
257
+ """
258
+ benchmark_dir = "../benchmark/questions"
259
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
260
+
261
+
262
+ def count_total_questions() -> tuple[int, int]:
263
+ """Count total number of cases and questions in benchmark.
264
+
265
+ Returns:
266
+ tuple: (total_cases, total_questions)
267
+ """
268
+ total_cases = len(glob.glob("../benchmark/questions/*"))
269
+ total_questions = sum(
270
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
271
+ for case_id in os.listdir("../benchmark/questions")
272
+ )
273
+ return total_cases, total_questions
274
+
275
+
276
+ def main() -> None:
277
+ """Main function to run the benchmark evaluation."""
278
+ with open("../data/eurorad_metadata.json", "r") as file:
279
+ data = json.load(file)
280
+
281
+ api_key = os.getenv("OPENAI_API_KEY")
282
+ if not api_key:
283
+ raise ValueError("OPENAI_API_KEY environment variable is not set.")
284
+ global client
285
+ client = openai.OpenAI(api_key=api_key)
286
+
287
+ total_cases, total_questions = count_total_questions()
288
+ cases_processed = 0
289
+ questions_processed = 0
290
+ skipped_questions = 0
291
+
292
+ print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")
293
+
294
+ for case_id, case_details in data.items():
295
+ question_files = load_benchmark_questions(case_id)
296
+ if not question_files:
297
+ continue
298
+
299
+ cases_processed += 1
300
+ for question_file in question_files:
301
+ with open(question_file, "r") as file:
302
+ question_data = json.load(file)
303
+ question_id = os.path.basename(question_file).split(".")[0]
304
+
305
+ questions_processed += 1
306
+ response = create_multimodal_request(
307
+ question_data, case_details, case_id, question_id, client
308
+ )
309
+
310
+ # Handle cases where response is None
311
+ if response is None:
312
+ skipped_questions += 1
313
+ print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
314
+ continue
315
+
316
+ print(
317
+ f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
318
+ )
319
+ print(f"Case ID: {case_id}")
320
+ print(f"Question ID: {question_id}")
321
+ print(f"Model Answer: {response.choices[0].message.content}")
322
+ print(f"Correct Answer: {question_data['answer']}\n")
323
+
324
+ print(f"\nBenchmark Summary:")
325
+ print(f"Total Cases Processed: {cases_processed}")
326
+ print(f"Total Questions Processed: {questions_processed}")
327
+ print(f"Total Questions Skipped: {skipped_questions}")
328
+
329
+
330
+ if __name__ == "__main__":
331
+ main()
experiments/benchmark_llama.py ADDED
@@ -0,0 +1,443 @@
1
+ from typing import Dict, List, Optional, Any, Union, Tuple
2
+ import re
3
+ import json
4
+ import os
5
+ import glob
6
+ import time
7
+ import logging
8
+ import socket
9
+ import requests
10
+ import httpx
11
+ import backoff
12
+ from datetime import datetime
13
+ from tenacity import retry, wait_exponential, stop_after_attempt
14
+ from openai import OpenAI
15
+
16
+ # Configure model settings
17
+ MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
18
+ temperature = 0.2
19
+
20
+ # Configure logging
21
+ log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
22
+ logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
23
+
24
+
25
+ def verify_dns() -> bool:
26
+ """Verify DNS resolution and connectivity.
27
+
28
+ Returns:
29
+ bool: True if DNS resolution succeeds, False otherwise
30
+ """
31
+ try:
32
+ # Try to resolve openrouter.ai
33
+ socket.gethostbyname("openrouter.ai")
34
+ return True
35
+ except socket.gaierror:
36
+ print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
37
+ # Modify resolv.conf to use Google DNS
38
+ try:
39
+ with open("/etc/resolv.conf", "w") as f:
40
+ f.write("nameserver 8.8.8.8\n")
41
+ return True
42
+ except Exception as e:
43
+ print(f"Failed to update DNS settings: {e}")
44
+ return False
45
+
46
+
47
+ def verify_connection() -> bool:
48
+ """Verify connection to OpenRouter API.
49
+
50
+ Returns:
51
+ bool: True if connection succeeds, False otherwise
52
+ """
53
+ try:
54
+ response = requests.get("https://openrouter.ai/api/v1/status", timeout=10)
55
+ return response.status_code == 200
56
+ except Exception as e:
57
+ print(f"Connection test failed: {e}")
58
+ return False
59
+
60
+
61
+ def initialize_client() -> OpenAI:
62
+ """Initialize the OpenRouter client with proper timeout settings and connection verification.
63
+
64
+ Returns:
65
+ OpenAI: Configured OpenAI client for OpenRouter
66
+
67
+ Raises:
68
+ ValueError: If OPENROUTER_API_KEY environment variable is not set
69
+ ConnectionError: If DNS verification or connection test fails
70
+ """
71
+ api_key = os.getenv("OPENROUTER_API_KEY")
72
+ if not api_key:
73
+ raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
74
+
75
+ # Configure timeout settings for the client
76
+ timeout_settings = 120 # Increased timeout for large images/responses
77
+
78
+ # Verify DNS and connection
79
+ if not verify_dns():
80
+ raise ConnectionError("DNS verification failed. Please check your network settings.")
81
+
82
+ if not verify_connection():
83
+ raise ConnectionError(
84
+ "Cannot connect to OpenRouter. Please check your internet connection."
85
+ )
86
+
87
+ # Set up client with retry and timeout settings
88
+ return OpenAI(
89
+ base_url="https://openrouter.ai/api/v1",
90
+ api_key=api_key,
91
+ timeout=timeout_settings,
92
+ http_client=httpx.Client(
93
+ timeout=timeout_settings, transport=httpx.HTTPTransport(retries=3)
94
+ ),
95
+ )
96
+
97
+
98
+ @backoff.on_exception(
99
+ backoff.expo,
100
+ (ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
101
+ max_tries=5,
102
+ max_time=300, # Maximum total time to try in seconds
103
+ )
104
+ def create_multimodal_request(
105
+ question_data: Dict[str, Any],
106
+ case_details: Dict[str, Any],
107
+ case_id: str,
108
+ question_id: str,
109
+ client: OpenAI,
110
+ ) -> Optional[Any]:
111
+ """Create and send a multimodal request to the model.
112
+
113
+ Args:
114
+ question_data: Dictionary containing question details
115
+ case_details: Dictionary containing case information
116
+ case_id: ID of the medical case
117
+ question_id: ID of the specific question
118
+ client: OpenAI client instance
119
+
120
+ Returns:
121
+ Optional[Any]: Model response if successful, None if skipped
122
+
123
+ Raises:
124
+ ConnectionError: If connection fails
125
+ TimeoutError: If request times out
126
+ Exception: For other errors
127
+ """
128
+
129
+ system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
130
+ Rules:
131
+ 1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
132
+ 2. Do not add periods, explanations, or any other text
133
+ 3. Do not use markdown or formatting
134
+ 4. Do not restate the question
135
+ 5. Do not explain your reasoning
136
+
137
+ Examples of valid responses:
138
+ A
139
+ B
140
+ C
141
+
142
+ Examples of invalid responses:
143
+ "A."
144
+ "Answer: B"
145
+ "C) This shows..."
146
+ "The answer is D"
147
+ """
148
+
149
+ prompt = f"""Given the following medical case:
150
+ Please answer this multiple choice question:
151
+ {question_data['question']}
152
+ Base your answer only on the provided images and case information."""
153
+
154
+ # Parse required figures
155
+ try:
156
+ if isinstance(question_data["figures"], str):
157
+ try:
158
+ required_figures = json.loads(question_data["figures"])
159
+ except json.JSONDecodeError:
160
+ required_figures = [question_data["figures"]]
161
+ elif isinstance(question_data["figures"], list):
162
+ required_figures = question_data["figures"]
163
+ else:
164
+ required_figures = [str(question_data["figures"])]
165
+ except Exception as e:
166
+ print(f"Error parsing figures: {e}")
167
+ required_figures = []
168
+
169
+ required_figures = [
170
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
171
+ ]
172
+
173
+ # Process subfigures and prepare content
174
+ content = [{"type": "text", "text": prompt}]
175
+ image_urls = []
176
+ image_captions = []
177
+
178
+ for figure in required_figures:
179
+ base_figure_num = "".join(filter(str.isdigit, figure))
180
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
181
+
182
+ matching_figures = [
183
+ case_figure
184
+ for case_figure in case_details.get("figures", [])
185
+ if case_figure["number"] == f"Figure {base_figure_num}"
186
+ ]
187
+
188
+ for case_figure in matching_figures:
189
+ subfigures = []
190
+ if figure_letter:
191
+ subfigures = [
192
+ subfig
193
+ for subfig in case_figure.get("subfigures", [])
194
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
195
+ or subfig.get("label", "").lower() == figure_letter.lower()
196
+ ]
197
+ else:
198
+ subfigures = case_figure.get("subfigures", [])
199
+
200
+ for subfig in subfigures:
201
+ if "url" in subfig:
202
+ content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
203
+ image_urls.append(subfig["url"])
204
+ image_captions.append(subfig.get("caption", ""))
205
+
206
+ if len(content) == 1: # Only the text prompt exists
207
+ print(f"No images found for case {case_id}, question {question_id}")
208
+ # Log the skipped question
209
+ log_entry = {
210
+ "case_id": case_id,
211
+ "question_id": question_id,
212
+ "timestamp": datetime.now().isoformat(),
213
+ "model": MODEL_NAME,
214
+ "status": "skipped",
215
+ "reason": "no_images",
216
+ "input": {
217
+ "question_data": {
218
+ "question": question_data["question"],
219
+ "explanation": question_data["explanation"],
220
+ "metadata": question_data.get("metadata", {}),
221
+ "figures": question_data["figures"],
222
+ },
223
+ "image_urls": image_urls,
224
+ },
225
+ }
226
+ logging.info(json.dumps(log_entry))
227
+ return None
228
+
229
+ try:
230
+ start_time = time.time()
231
+
232
+ response = client.chat.completions.create(
233
+ model=MODEL_NAME,
234
+ temperature=temperature,
235
+ messages=[
236
+ {"role": "system", "content": system_prompt},
237
+ {"role": "user", "content": content},
238
+ ],
239
+ )
240
+ duration = time.time() - start_time
241
+
242
+ # Get raw response
243
+ raw_answer = response.choices[0].message.content
244
+
245
+ # Validate and clean
246
+ clean_answer = validate_answer(raw_answer)
247
+
248
+ if not clean_answer:
249
+ print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
250
+ print(f"Raw response: {raw_answer}")
251
+
252
+ # Update response object with cleaned answer
253
+ response.choices[0].message.content = clean_answer
254
+
255
+ # Log response
256
+ log_entry = {
257
+ "case_id": case_id,
258
+ "question_id": question_id,
259
+ "timestamp": datetime.now().isoformat(),
260
+ "model": MODEL_NAME,
261
+ "temperature": temperature,
262
+ "duration": round(duration, 2),
263
+ "usage": {
264
+ "prompt_tokens": response.usage.prompt_tokens,
265
+ "completion_tokens": response.usage.completion_tokens,
266
+ "total_tokens": response.usage.total_tokens,
267
+ },
268
+ "model_answer": response.choices[0].message.content,
269
+ "correct_answer": question_data["answer"],
270
+ "input": {
271
+ "question_data": {
272
+ "question": question_data["question"],
273
+ "explanation": question_data["explanation"],
274
+ "metadata": question_data.get("metadata", {}),
275
+ "figures": question_data["figures"],
276
+ },
277
+ "image_urls": image_urls,
278
+ },
279
+ }
280
+ logging.info(json.dumps(log_entry))
281
+ return response
282
+
283
+ except ConnectionError as e:
284
+ print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
285
+ print("Retrying after a longer delay...")
286
+ time.sleep(30) # Add a longer delay before retry
287
+ raise
288
+ except TimeoutError as e:
289
+ print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
290
+ print("Retrying with increased timeout...")
291
+ raise
292
+ except Exception as e:
293
+ # Log failed requests too
294
+ log_entry = {
295
+ "case_id": case_id,
296
+ "question_id": question_id,
297
+ "timestamp": datetime.now().isoformat(),
298
+ "model": MODEL_NAME,
299
+ "temperature": temperature,
300
+ "status": "error",
301
+ "error": str(e),
302
+ "input": {
303
+ "question_data": {
304
+ "question": question_data["question"],
305
+ "explanation": question_data["explanation"],
306
+ "metadata": question_data.get("metadata", {}),
307
+ "figures": question_data["figures"],
308
+ },
309
+ "image_urls": image_urls,
310
+ },
311
+ }
312
+ logging.info(json.dumps(log_entry))
313
+ raise
314
+
315
+
316
+ def extract_answer(response_text: str) -> Optional[str]:
317
+ """Extract single letter answer from model response.
318
+
319
+ Args:
320
+ response_text: Raw text response from model
321
+
322
+ Returns:
323
+ Optional[str]: Single letter answer if found, None otherwise
324
+ """
325
+ # Convert to uppercase and remove periods
326
+ text = response_text.upper().replace(".", "")
327
+
328
+ # Look for common patterns
329
+ patterns = [
330
+ r"ANSWER:\s*([A-F])", # Matches "ANSWER: X"
331
+ r"OPTION\s*([A-F])", # Matches "OPTION X"
332
+ r"([A-F])\)", # Matches "X)"
333
+ r"\b([A-F])\b", # Matches single letter
334
+ ]
335
+
336
+ for pattern in patterns:
337
+ matches = re.findall(pattern, text)
338
+ if matches:
339
+ return matches[0]
340
+
341
+ return None
342
+
343
+
344
+ def validate_answer(response_text: str) -> Optional[str]:
345
+ """Enforce strict single-letter response format.
346
+
347
+ Args:
348
+ response_text: Raw text response from model
349
+
350
+ Returns:
351
+ Optional[str]: Valid single letter answer if found, None otherwise
352
+ """
353
+ if not response_text:
354
+ return None
355
+
356
+ # Remove all whitespace and convert to uppercase
357
+ cleaned = response_text.strip().upper()
358
+
359
+ # Check if it's exactly one valid letter
360
+ if len(cleaned) == 1 and cleaned in "ABCDEF":
361
+ return cleaned
362
+
363
+ # If not, try to extract just the letter
364
+ match = re.search(r"([A-F])", cleaned)
365
+ return match.group(1) if match else None
366
+
367
+
368
+ def load_benchmark_questions(case_id: str) -> List[str]:
369
+ """Find all question files for a given case ID.
370
+
371
+ Args:
372
+ case_id: ID of the medical case
373
+
374
+ Returns:
375
+ List[str]: List of paths to question files
376
+ """
377
+ benchmark_dir = "../benchmark/questions"
378
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
379
+
380
+
381
+ def count_total_questions() -> Tuple[int, int]:
382
+ """Count total number of cases and questions.
383
+
384
+ Returns:
385
+ Tuple[int, int]: (total_cases, total_questions)
386
+ """
387
+ total_cases = len(glob.glob("../benchmark/questions/*"))
388
+ total_questions = sum(
389
+ len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
390
+ for case_id in os.listdir("../benchmark/questions")
391
+ )
392
+ return total_cases, total_questions
393
+
394
+
395
+ def main():
396
+ with open("../data/eurorad_metadata.json", "r") as file:
397
+ data = json.load(file)
398
+
399
+ client = initialize_client()
400
+ total_cases, total_questions = count_total_questions()
401
+ cases_processed = 0
402
+ questions_processed = 0
403
+ skipped_questions = 0
404
+
405
+ print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
406
+
407
+ for case_id, case_details in data.items():
408
+ question_files = load_benchmark_questions(case_id)
409
+ if not question_files:
410
+ continue
411
+
412
+ cases_processed += 1
413
+ for question_file in question_files:
414
+ with open(question_file, "r") as file:
415
+ question_data = json.load(file)
416
+ question_id = os.path.basename(question_file).split(".")[0]
417
+
418
+ questions_processed += 1
419
+ response = create_multimodal_request(
420
+ question_data, case_details, case_id, question_id, client
421
+ )
422
+
423
+ if response is None:
424
+ skipped_questions += 1
425
+ print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
426
+ continue
427
+
428
+ print(
429
+ f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
430
+ )
431
+ print(f"Case ID: {case_id}")
432
+ print(f"Question ID: {question_id}")
433
+ print(f"Model Answer: {response.choices[0].message.content}")
434
+ print(f"Correct Answer: {question_data['answer']}\n")
435
+
436
+ print(f"\nBenchmark Summary:")
437
+ print(f"Total Cases Processed: {cases_processed}")
438
+ print(f"Total Questions Processed: {questions_processed}")
439
+ print(f"Total Questions Skipped: {skipped_questions}")
440
+
441
+
442
+ if __name__ == "__main__":
443
+ main()
experiments/benchmark_llavamed.py ADDED
@@ -0,0 +1,541 @@
1
+ import argparse
2
+ import json
3
+ import requests
4
+ import base64
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ from llava.conversation import conv_templates
8
+ import time
9
+ import os
10
+ import glob
11
+ import logging
12
+ from datetime import datetime
13
+ from tqdm import tqdm
14
+ import re
15
+ from typing import Dict, List, Optional, Union, Any, Tuple
16
+
17
+
18
+ def process_image(image_path: str, target_size: int = 640) -> Image.Image:
19
+ """Process and resize an image to match model requirements.
20
+
21
+ Args:
22
+ image_path: Path to the input image file
23
+ target_size: Target size for both width and height in pixels
24
+
25
+ Returns:
26
+ PIL.Image: Processed and padded image with dimensions (target_size, target_size)
27
+ """
28
+ image = Image.open(image_path)
29
+ if image.mode != "RGB":
30
+ image = image.convert("RGB")
31
+
32
+ # Calculate scaling to maintain aspect ratio
33
+ ratio = min(target_size / image.width, target_size / image.height)
34
+ new_size = (int(image.width * ratio), int(image.height * ratio))
35
+
36
+ # Resize image
37
+ image = image.resize(new_size, Image.LANCZOS)
38
+
39
+ # Create new image with padding
40
+ new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
41
+ # Paste resized image in center
42
+ offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
43
+ new_image.paste(image, offset)
44
+
45
+ return new_image
46
+
47
+
48
+ def validate_answer(response_text: str) -> Optional[str]:
49
+ """Extract and validate a single-letter response from the model's output.
50
+ Handles multiple response formats and edge cases.
51
+
52
+ Args:
53
+ response_text: The full text output from the model
54
+
55
+ Returns:
56
+ A single letter answer (A-F) or None if no valid answer found
57
+ """
58
+ if not response_text:
59
+ return None
60
+
61
+ # Clean the response text
62
+ cleaned = response_text.strip()
63
+
64
+ # Comprehensive set of patterns to extract the answer
65
+ extraction_patterns = [
66
+ # Strict format with explicit letter answer
67
+ r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
68
+ # Patterns for extracting from longer descriptions
69
+ r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
70
+ r"\b(?:answer|option)\s*([A-F])[):]\s*",
71
+ # Patterns for extracting from descriptive sentences
72
+ r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
73
+ r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
74
+ # Patterns with contextual words
75
+ r"characteriz[e]?d?\s+by\s+([A-F])\b",
76
+ r"indicat[e]?s?\s+([A-F])\b",
77
+ # Fallback to "Option X" or "Letter X" formats
78
+ r"Option\s*([A-F])\b",
79
+ r"\b([A-F])\)\s*",
80
+ # Fallback to standalone letter
81
+ r"^\s*([A-F])\s*$",
82
+ ]
83
+
84
+ # Try each pattern
85
+ for pattern in extraction_patterns:
86
+ matches = re.findall(pattern, cleaned, re.IGNORECASE)
87
+ for match in matches:
88
+ # Ensure match is a single valid letter
89
+ if isinstance(match, tuple):
90
+ match = match[0] if match[0] in "ABCDEF" else None
91
+ if match and match.upper() in "ABCDEF":
92
+ return match.upper()
93
+
94
+ # Final fallback: look for standalone letters in context
95
+ context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
96
+ context_letters = [m for m in context_matches if m in "ABCDEF"]
97
+ if context_letters:
98
+ return context_letters[0]
99
+
100
+ # No valid answer found
101
+ return None
102
+
103
+
104
+ def load_benchmark_questions(case_id: str) -> List[str]:
105
+ """Find all question files for a given case ID.
106
+
107
+ Args:
108
+ case_id: The ID of the medical case
109
+
110
+ Returns:
111
+ List of paths to question JSON files
112
+ """
113
+ benchmark_dir = "MedMAX/benchmark/questions"
114
+ return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
115
+
116
+
117
+ def count_total_questions() -> Tuple[int, int]:
118
+ """Count total number of cases and questions in benchmark.
119
+
120
+ Returns:
121
+ Tuple containing (total_cases, total_questions)
122
+ """
123
+ total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
124
+ total_questions = sum(
125
+ len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
126
+ for case_id in os.listdir("MedMAX/benchmark/questions")
127
+ )
128
+ return total_cases, total_questions
129
+
130
+
131
+ def create_inference_request(
132
+ question_data: Dict[str, Any],
133
+ case_details: Dict[str, Any],
134
+ case_id: str,
135
+ question_id: str,
136
+ worker_addr: str,
137
+ model_name: str,
138
+ raw_output: bool = False,
139
+ ) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
140
+ """Create and send inference request to worker.
141
+
142
+ Args:
143
+ question_data: Dictionary containing question details and figures
144
+ case_details: Dictionary containing case information and figures
145
+ case_id: Identifier for the medical case
146
+ question_id: Identifier for the specific question
147
+ worker_addr: Address of the worker endpoint
148
+ model_name: Name of the model to use
149
+ raw_output: Whether to return raw model output
150
+
151
+ Returns:
152
+ If raw_output is False: Tuple of (validated_answer, duration)
153
+ If raw_output is True: Dictionary with full inference details
154
+ """
155
+ system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
156
+ """
157
+
158
+ prompt = f"""Given the following medical case:
159
+ Please answer this multiple choice question:
160
+ {question_data['question']}
161
+ Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
162
+
163
+ try:
164
+ # Parse required figures
165
+ if isinstance(question_data["figures"], str):
166
+ try:
167
+ required_figures = json.loads(question_data["figures"])
168
+ except json.JSONDecodeError:
169
+ required_figures = [question_data["figures"]]
170
+ elif isinstance(question_data["figures"], list):
171
+ required_figures = question_data["figures"]
172
+ else:
173
+ required_figures = [str(question_data["figures"])]
174
+ except Exception as e:
175
+ print(f"Error parsing figures: {e}")
176
+ required_figures = []
177
+
178
+ required_figures = [
179
+ fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
180
+ ]
181
+
182
+ # Get image paths
183
+ image_paths = []
184
+ for figure in required_figures:
185
+ base_figure_num = "".join(filter(str.isdigit, figure))
186
+ figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
187
+
188
+ matching_figures = [
189
+ case_figure
190
+ for case_figure in case_details.get("figures", [])
191
+ if case_figure["number"] == f"Figure {base_figure_num}"
192
+ ]
193
+
194
+ for case_figure in matching_figures:
195
+ subfigures = []
196
+ if figure_letter:
197
+ subfigures = [
198
+ subfig
199
+ for subfig in case_figure.get("subfigures", [])
200
+ if subfig.get("number", "").lower().endswith(figure_letter.lower())
201
+ or subfig.get("label", "").lower() == figure_letter.lower()
202
+ ]
203
+ else:
204
+ subfigures = case_figure.get("subfigures", [])
205
+
206
+ for subfig in subfigures:
207
+ if "local_path" in subfig:
208
+ image_paths.append("MedMAX/data/" + subfig["local_path"])
209
+
210
+ if not image_paths:
211
+ print(f"No local images found for case {case_id}, question {question_id}")
212
+ return "skipped", 0.0 # Return a special 'skipped' marker
213
+
214
+ try:
215
+ start_time = time.time()
216
+
217
+ # Process each image
218
+ processed_images = [process_image(path) for path in image_paths]
219
+
220
+ # Create conversation
221
+ conv = conv_templates["mistral_instruct"].copy()
222
+
223
+ # Add image and message
224
+ if "<image>" not in prompt:
225
+ text = prompt + "\n<image>"
226
+ else:
227
+ text = prompt
228
+
229
+ message = (text, processed_images[0], "Default") # Currently handling first image
230
+ conv.append_message(conv.roles[0], message)
231
+ conv.append_message(conv.roles[1], None)
232
+
233
+ prompt = conv.get_prompt()
234
+ headers = {"User-Agent": "LLaVA-Med Client"}
235
+ pload = {
236
+ "model": model_name,
237
+ "prompt": prompt,
238
+ "max_new_tokens": 150, # Reduce this since we only need one letter
239
+ "temperature": 0.5, # Lower temperature for more focused responses
240
+ "stop": conv.sep2,
241
+ "images": conv.get_images(),
242
+ "top_p": 1, # Lower top_p for more focused sampling
243
+ "frequency_penalty": 0.0,
244
+ "presence_penalty": 0.0,
245
+ }
246
+
247
+ max_retries = 3
248
+ retry_delay = 5
249
+ response_text = None
250
+
251
+ for attempt in range(max_retries):
252
+ try:
253
+ response = requests.post(
254
+ worker_addr + "/worker_generate_stream",
255
+ headers=headers,
256
+ json=pload,
257
+ stream=True,
258
+ timeout=30,
259
+ )
260
+
261
+ complete_output = ""
262
+ for chunk in response.iter_lines(
263
+ chunk_size=8192, decode_unicode=False, delimiter=b"\0"
264
+ ):
265
+ if chunk:
266
+ data = json.loads(chunk.decode("utf-8"))
267
+ if data["error_code"] == 0:
268
+ output = data["text"].split("[/INST]")[-1]
269
+ complete_output = output
270
+ else:
271
+ print(f"\nError: {data['text']} (error_code: {data['error_code']})")
272
+ if attempt < max_retries - 1:
273
+ time.sleep(retry_delay)
274
+ break
275
+ return None, None
276
+
277
+ if complete_output:
278
+ response_text = complete_output
279
+ break
280
+
281
+ except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
282
+ if attempt < max_retries - 1:
283
+ print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
284
+ time.sleep(retry_delay)
285
+ else:
286
+ print(f"\nFailed after {max_retries} attempts: {str(e)}")
287
+ return None, None
288
+
289
+ duration = time.time() - start_time
290
+
291
+ if raw_output:
292
+ inference_details = {
293
+ "raw_output": response_text,
294
+ "validated_answer": validate_answer(response_text),
295
+ "duration": duration,
296
+ "prompt": prompt,
297
+ "system_prompt": system_prompt,
298
+ "image_paths": image_paths,
299
+ "payload": pload,
300
+ }
301
+ return inference_details
302
+
303
+ return validate_answer(response_text), duration
304
+
305
+ except Exception as e:
306
+ print(f"Error in inference request: {str(e)}")
307
+ return None, None
308
+
309
+
310
+ def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
311
+ """Remove image-related and large data from the payload to keep the log lean.
312
+
313
+ Args:
314
+ payload: Original request payload dictionary
315
+
316
+ Returns:
317
+ Cleaned payload dictionary with large data removed
318
+ """
319
+ if not payload:
320
+ return None
321
+
322
+ # Create a copy of the payload to avoid modifying the original
323
+ cleaned_payload = payload.copy()
324
+
325
+ # Remove large or sensitive data
326
+ if "images" in cleaned_payload:
327
+ del cleaned_payload["images"]
328
+
329
+ return cleaned_payload
330
+
331
+
332
+ def main():
333
+ parser = argparse.ArgumentParser()
334
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
335
+ parser.add_argument("--worker-address", type=str)
336
+ parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
337
+ parser.add_argument("--output-dir", type=str, default="benchmark_results")
338
+ parser.add_argument(
339
+ "--raw-output", action="store_true", help="Return raw model output without validation"
340
+ )
341
+ parser.add_argument(
342
+ "--num-cases",
343
+ type=int,
344
+ help="Number of cases to process if looking at raw outputs",
345
+ default=2,
346
+ )
347
+ args = parser.parse_args()
348
+
349
+ # Setup output directory
350
+ os.makedirs(args.output_dir, exist_ok=True)
351
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
352
+
353
+ # Setup live logging files
354
+ live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
355
+ final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
356
+
357
+ # Initialize live log file
358
+ with open(live_log_filename, "w") as live_log_file:
359
+ live_log_file.write("[\n") # Start of JSON array
360
+
361
+ # Setup logging
362
+ logging.basicConfig(
363
+ filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
364
+ level=logging.INFO,
365
+ format="%(message)s",
366
+ )
367
+
368
+ # Get worker address
369
+ if args.worker_address:
370
+ worker_addr = args.worker_address
371
+ else:
372
+ try:
373
+ requests.post(args.controller_address + "/refresh_all_workers")
374
+ ret = requests.post(args.controller_address + "/list_models")
375
+ models = ret.json()["models"]
376
+ ret = requests.post(
377
+ args.controller_address + "/get_worker_address", json={"model": args.model_name}
378
+ )
379
+ worker_addr = ret.json()["address"]
380
+ print(f"Worker address: {worker_addr}")
381
+ except requests.exceptions.RequestException as e:
382
+ print(f"Failed to connect to controller: {e}")
383
+ return
384
+
385
+ if worker_addr == "":
386
+ print("No available worker")
387
+ return
388
+
389
+ # Load cases with local paths
390
+ with open("MedMAX/data/updated_cases.json", "r") as file:
391
+ data = json.load(file)
392
+
393
+ total_cases, total_questions = count_total_questions()
394
+ print(f"\nStarting benchmark with {args.model_name}")
395
+ print(f"Found {total_cases} cases with {total_questions} total questions")
396
+
397
+ results = {
398
+ "model": args.model_name,
399
+ "timestamp": datetime.now().isoformat(),
400
+ "total_cases": total_cases,
401
+ "total_questions": total_questions,
402
+ "results": [],
403
+ }
404
+
405
+ cases_processed = 0
406
+ questions_processed = 0
407
+ correct_answers = 0
408
+ skipped_questions = 0
409
+ total_processed_entries = 0
410
+
411
+ # Process each case
412
+ for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
413
+ question_files = load_benchmark_questions(case_id)
414
+ if not question_files:
415
+ continue
416
+
417
+ cases_processed += 1
418
+ for question_file in tqdm(
419
+ question_files, desc=f"Processing questions for case {case_id}", leave=False
420
+ ):
421
+ with open(question_file, "r") as file:
422
+ question_data = json.load(file)
423
+ question_id = os.path.basename(question_file).split(".")[0]
424
+
425
+ questions_processed += 1
426
+
427
+ # Get model's answer
428
+ inference_result = create_inference_request(
429
+ question_data,
430
+ case_details,
431
+ case_id,
432
+ question_id,
433
+ worker_addr,
434
+ args.model_name,
435
+ raw_output=True, # Always use raw output for detailed logging
436
+ )
437
+
438
+ # Handle skipped questions
439
+ if inference_result == ("skipped", 0.0):
440
+ skipped_questions += 1
441
+ print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
442
+
443
+ # Log skipped question
444
+ skipped_entry = {
445
+ "case_id": case_id,
446
+ "question_id": question_id,
447
+ "status": "skipped",
448
+ "reason": "No images found",
449
+ }
450
+ with open(live_log_filename, "a") as live_log_file:
451
+ json.dump(skipped_entry, live_log_file, indent=2)
452
+ live_log_file.write(",\n") # Add comma for next entry
453
+
454
+ continue
455
+
456
+ # Extract information
457
+ answer = inference_result["validated_answer"]
458
+ duration = inference_result["duration"]
459
+
460
+ # Prepare detailed logging entry
461
+ log_entry = {
462
+ "case_id": case_id,
463
+ "question_id": question_id,
464
+ "question": question_data["question"],
465
+ "correct_answer": question_data["answer"],
466
+ "raw_output": inference_result["raw_output"],
467
+ "validated_answer": answer,
468
+ "model_answer": answer,
469
+ "is_correct": answer == question_data["answer"] if answer else False,
470
+ "duration": duration,
471
+ "system_prompt": inference_result["system_prompt"],
472
+ "input_prompt": inference_result["prompt"],
473
+ "image_paths": inference_result["image_paths"],
474
+ "payload": clean_payload(inference_result["payload"]),
475
+ }
476
+
477
+ # Write to live log file
478
+ with open(live_log_filename, "a") as live_log_file:
479
+ json.dump(log_entry, live_log_file, indent=2)
480
+ live_log_file.write(",\n") # Add comma for next entry
481
+
482
+ # Print to console
483
+ print(f"\nCase {case_id}, Question {question_id}")
484
+ print(f"Model Answer: {answer}")
485
+ print(f"Correct Answer: {question_data['answer']}")
486
+ print(f"Time taken: {duration:.2f}s")
487
+
488
+ # Track correct answers
489
+ if answer == question_data["answer"]:
490
+ correct_answers += 1
491
+
492
+ # Append to results
493
+ results["results"].append(log_entry)
494
+ total_processed_entries += 1
495
+
496
+ # Optional: break if reached specified number of cases
497
+ if args.raw_output and cases_processed == args.num_cases:
498
+ break
499
+
500
+ # Optional: break if reached specified number of cases
501
+ if args.raw_output and cases_processed == args.num_cases:
502
+ break
503
+
504
+ # Close live log file
505
+ with open(live_log_filename, "rb+") as live_log_file:
506
+ # Overwrite the trailing ",\n" with "\n]" to close the JSON array ("a" mode ignores seek for writes)
507
+ live_log_file.seek(-2, os.SEEK_END)
508
+ live_log_file.write(b"\n]")
509
+
510
+ # Calculate final statistics
511
+ results["summary"] = {
512
+ "cases_processed": cases_processed,
513
+ "questions_processed": questions_processed,
514
+ "total_processed_entries": total_processed_entries,
515
+ "correct_answers": correct_answers,
516
+ "skipped_questions": skipped_questions,
517
+ "accuracy": (
518
+ correct_answers / (questions_processed - skipped_questions)
519
+ if (questions_processed - skipped_questions) > 0
520
+ else 0
521
+ ),
522
+ }
523
+
524
+ # Save final results
525
+ with open(final_results_filename, "w") as f:
526
+ json.dump(results, f, indent=2)
527
+
528
+ print(f"\nBenchmark Summary:")
529
+ print(f"Total Cases Processed: {cases_processed}")
530
+ print(f"Total Questions Processed: {questions_processed}")
531
+ print(f"Total Processed Entries: {total_processed_entries}")
532
+ print(f"Correct Answers: {correct_answers}")
533
+ print(f"Skipped Questions: {skipped_questions}")
534
+ print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%")
535
+ print(f"\nResults saved to {args.output_dir}")
536
+ print(f"Live log: {live_log_filename}")
537
+ print(f"Final results: {final_results_filename}")
538
+
539
+
540
+ if __name__ == "__main__":
541
+ main()
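
The live log written by main() grows as a JSON array, so a finished run can be sanity-checked offline. A minimal sketch (the filename is a placeholder for the path the script prints at the end):

    import json

    with open("benchmark_results/live_benchmark_log_20250101_000000.json") as f:
        entries = json.load(f)

    # Skipped questions carry a "status" field and no "is_correct" flag
    answered = [e for e in entries if e.get("status") != "skipped"]
    correct = sum(1 for e in answered if e.get("is_correct"))
    print(f"{correct}/{len(answered)} correct "
          f"({100 * correct / max(len(answered), 1):.2f}%)")
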
experiments/benchmark_medrax.ipynb ADDED
@@ -0,0 +1,374 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import operator\n",
10
+ "import warnings\n",
11
+ "from typing import *\n",
12
+ "import traceback\n",
13
+ "\n",
14
+ "import os\n",
15
+ "import torch\n",
16
+ "from dotenv import load_dotenv\n",
17
+ "from IPython.display import Image\n",
18
+ "from langgraph.checkpoint.memory import MemorySaver\n",
19
+ "from langgraph.graph import END, StateGraph\n",
20
+ "from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
21
+ "from langchain_openai import ChatOpenAI\n",
22
+ "from transformers import logging\n",
23
+ "import matplotlib.pyplot as plt\n",
24
+ "import numpy as np\n",
25
+ "import re\n",
26
+ "\n",
27
+ "from medrax.agent import *\n",
28
+ "from medrax.tools import *\n",
29
+ "from medrax.utils import *\n",
30
+ "\n",
31
+ "import json\n",
32
+ "import openai\n",
33
+ "import os\n",
34
+ "import glob\n",
35
+ "import time\n",
36
+ "import logging\n",
37
+ "from datetime import datetime\n",
38
+ "from tenacity import retry, wait_exponential, stop_after_attempt\n",
39
+ "\n",
40
+ "warnings.filterwarnings(\"ignore\")\n",
41
+ "_ = load_dotenv()\n",
42
+ "\n",
43
+ "\n",
44
+ "# Setup directory paths\n",
45
+ "ROOT = \"set this directory to where MedRAX is, e.g. /home/MedRAX\"\n",
46
+ "PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
47
+ "BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
48
+ "MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
49
+ "FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
50
+ "\n",
51
+ "model_name = \"medrax\"\n",
52
+ "temperature = 0.2\n",
53
+ "medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
54
+ "log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
55
+ "logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
56
+ "device = \"cuda\""
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "def get_tools():\n",
66
+ " report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
67
+ " xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
68
+ " segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
69
+ " grounding_tool = XRayPhraseGroundingTool(\n",
70
+ " cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
71
+ " )\n",
72
+ " xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
73
+ " llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
74
+ "\n",
75
+ " return [\n",
76
+ " report_tool,\n",
77
+ " xray_classification_tool,\n",
78
+ " segmentation_tool,\n",
79
+ " grounding_tool,\n",
80
+ " xray_vqa_tool,\n",
81
+ " llava_med_tool,\n",
82
+ " ]\n",
83
+ "\n",
84
+ "\n",
85
+ "def get_agent(tools):\n",
86
+ " prompts = load_prompts_from_file(PROMPT_FILE)\n",
87
+ " prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
88
+ "\n",
89
+ " checkpointer = MemorySaver()\n",
90
+ " model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
91
+ " agent = Agent(\n",
92
+ " model,\n",
93
+ " tools=tools,\n",
94
+ " log_tools=True,\n",
95
+ " log_dir=\"logs\",\n",
96
+ " system_prompt=prompt,\n",
97
+ " checkpointer=checkpointer,\n",
98
+ " )\n",
99
+ " thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
100
+ " return agent, thread\n",
101
+ "\n",
102
+ "\n",
103
+ "def run_medrax(agent, thread, prompt, image_urls=[]):\n",
104
+ " messages = [\n",
105
+ " HumanMessage(\n",
106
+ " content=[\n",
107
+ " {\"type\": \"text\", \"text\": prompt},\n",
108
+ " ]\n",
109
+ " + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
110
+ " )\n",
111
+ " ]\n",
112
+ "\n",
113
+ " final_response = None\n",
114
+ " for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
115
+ " for v in event.values():\n",
116
+ " final_response = v\n",
117
+ "\n",
118
+ " final_response = final_response[\"messages\"][-1].content.strip()\n",
119
+ " agent_state = agent.workflow.get_state(thread)\n",
120
+ "\n",
121
+ " return final_response, str(agent_state)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 3,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
131
+ " # Parse required figures\n",
132
+ " try:\n",
133
+ " # Try multiple ways of parsing figures\n",
134
+ " if isinstance(question_data[\"figures\"], str):\n",
135
+ " try:\n",
136
+ " required_figures = json.loads(question_data[\"figures\"])\n",
137
+ " except json.JSONDecodeError:\n",
138
+ " required_figures = [question_data[\"figures\"]]\n",
139
+ " elif isinstance(question_data[\"figures\"], list):\n",
140
+ " required_figures = question_data[\"figures\"]\n",
141
+ " else:\n",
142
+ " required_figures = [str(question_data[\"figures\"])]\n",
143
+ " except Exception as e:\n",
144
+ " print(f\"Error parsing figures: {e}\")\n",
145
+ " required_figures = []\n",
146
+ "\n",
147
+ " # Ensure each figure starts with \"Figure \"\n",
148
+ " required_figures = [\n",
149
+ " fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
150
+ " ]\n",
151
+ "\n",
152
+ " subfigures = []\n",
153
+ " for figure in required_figures:\n",
154
+ " # Handle both regular figures and those with letter suffixes\n",
155
+ " base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
156
+ " figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
157
+ "\n",
158
+ " # Find matching figures in case details\n",
159
+ " matching_figures = [\n",
160
+ " case_figure\n",
161
+ " for case_figure in case_details.get(\"figures\", [])\n",
162
+ " if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
163
+ " ]\n",
164
+ "\n",
165
+ " if not matching_figures:\n",
166
+ " print(f\"No matching figure found for {figure} in case {case_id}\")\n",
167
+ " continue\n",
168
+ "\n",
169
+ " for case_figure in matching_figures:\n",
170
+ " # If a specific letter is specified, filter subfigures\n",
171
+ " if figure_letter:\n",
172
+ " matching_subfigures = [\n",
173
+ " subfig\n",
174
+ " for subfig in case_figure.get(\"subfigures\", [])\n",
175
+ " if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
176
+ " or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
177
+ " ]\n",
178
+ " subfigures.extend(matching_subfigures)\n",
179
+ " else:\n",
180
+ " # If no letter specified, add all subfigures\n",
181
+ " subfigures.extend(case_figure.get(\"subfigures\", []))\n",
182
+ "\n",
183
+ " # Add images to content\n",
184
+ " figure_prompt = \"\"\n",
185
+ " image_urls = []\n",
186
+ "\n",
187
+ " for subfig in subfigures:\n",
188
+ " if \"number\" in subfig:\n",
189
+ " subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
190
+ " subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
191
+ " figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
192
+ " if \"url\" in subfig:\n",
193
+ " image_urls.append(subfig[\"url\"])\n",
194
+ " else:\n",
195
+ " print(f\"Subfigure missing URL: {subfig}\")\n",
196
+ "\n",
197
+ " prompt = (\n",
198
+ " f\"Answer this question correctly using chain of thought reasoning and \"\n",
199
+ " \"carefully evaluating choices. Solve using your own vision and reasoning and then \"\n",
200
+ " \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
201
+ " f\"{question_data['question']}\\n{figure_prompt}\"\n",
202
+ " )\n",
203
+ "\n",
204
+ " try:\n",
205
+ " start_time = time.time()\n",
206
+ "\n",
207
+ " final_response, agent_state = run_medrax(\n",
208
+ " agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
209
+ " )\n",
210
+ " model_answer, agent_state = run_medrax(\n",
211
+ " agent=agent,\n",
212
+ " thread=thread,\n",
213
+ " prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
214
+ " )\n",
215
+ " duration = time.time() - start_time\n",
216
+ "\n",
217
+ " log_entry = {\n",
218
+ " \"case_id\": case_id,\n",
219
+ " \"question_id\": question_id,\n",
220
+ " \"timestamp\": datetime.now().isoformat(),\n",
221
+ " \"model\": model_name,\n",
222
+ " \"temperature\": temperature,\n",
223
+ " \"duration\": round(duration, 2),\n",
224
+ " \"usage\": \"\",\n",
225
+ " \"cost\": 0,\n",
226
+ " \"raw_response\": final_response,\n",
227
+ " \"model_answer\": model_answer.strip(),\n",
228
+ " \"correct_answer\": question_data[\"answer\"][0],\n",
229
+ " \"input\": {\n",
230
+ " \"messages\": prompt,\n",
231
+ " \"question_data\": {\n",
232
+ " \"question\": question_data[\"question\"],\n",
233
+ " \"explanation\": question_data[\"explanation\"],\n",
234
+ " \"metadata\": question_data.get(\"metadata\", {}),\n",
235
+ " \"figures\": question_data[\"figures\"],\n",
236
+ " },\n",
237
+ " \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
238
+ " \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
239
+ " },\n",
240
+ " \"agent_state\": agent_state,\n",
241
+ " }\n",
242
+ " logging.info(json.dumps(log_entry))\n",
243
+ " return final_response, model_answer.strip()\n",
244
+ "\n",
245
+ " except Exception as e:\n",
246
+ " log_entry = {\n",
247
+ " \"case_id\": case_id,\n",
248
+ " \"question_id\": question_id,\n",
249
+ " \"timestamp\": datetime.now().isoformat(),\n",
250
+ " \"model\": model_name,\n",
251
+ " \"temperature\": temperature,\n",
252
+ " \"status\": \"error\",\n",
253
+ " \"error\": str(e),\n",
254
+ " \"cost\": 0,\n",
255
+ " \"input\": {\n",
256
+ " \"messages\": prompt,\n",
257
+ " \"question_data\": {\n",
258
+ " \"question\": question_data[\"question\"],\n",
259
+ " \"explanation\": question_data[\"explanation\"],\n",
260
+ " \"metadata\": question_data.get(\"metadata\", {}),\n",
261
+ " \"figures\": question_data[\"figures\"],\n",
262
+ " },\n",
263
+ " \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
264
+ " \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
265
+ " },\n",
266
+ " }\n",
267
+ " logging.info(json.dumps(log_entry))\n",
268
+ " print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
269
+ " return \"\", \"\"\n",
270
+ "\n",
271
+ "\n",
272
+ "def load_benchmark_questions(case_id):\n",
273
+ " benchmark_dir = \"../benchmark/questions\"\n",
274
+ " return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
275
+ "\n",
276
+ "\n",
277
+ "def count_total_questions():\n",
278
+ " total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
279
+ " total_questions = sum(\n",
280
+ " len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
281
+ " for case_id in os.listdir(\"../benchmark/questions\")\n",
282
+ " )\n",
283
+ " return total_cases, total_questions\n",
284
+ "\n",
285
+ "\n",
286
+ "def main(tools):\n",
287
+ " with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
288
+ " data = json.load(file)\n",
289
+ "\n",
290
+ " total_cases, total_questions = count_total_questions()\n",
291
+ " cases_processed = 0\n",
292
+ " questions_processed = 0\n",
293
+ " skipped_questions = 0\n",
294
+ "\n",
295
+ " print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
296
+ "\n",
297
+ " for case_id, case_details in data.items():\n",
298
+ " if int(case_details[\"case_id\"]) <= 17158:\n",
299
+ " continue\n",
300
+ "\n",
301
+ " print(f\"----------------------------------------------------------------\")\n",
302
+ " agent, thread = get_agent(tools)\n",
303
+ "\n",
304
+ " question_files = load_benchmark_questions(case_id)\n",
305
+ " if not question_files:\n",
306
+ " continue\n",
307
+ "\n",
308
+ " cases_processed += 1\n",
309
+ " for question_file in question_files:\n",
310
+ " with open(question_file, \"r\") as file:\n",
311
+ " question_data = json.load(file)\n",
312
+ " question_id = os.path.basename(question_file).split(\".\")[0]\n",
313
+ "\n",
314
+ " # agent, thread = get_agent(tools)\n",
315
+ " questions_processed += 1\n",
316
+ " final_response, model_answer = create_multimodal_request(\n",
317
+ " question_data, case_details, case_id, question_id, agent, thread\n",
318
+ " )\n",
319
+ "\n",
320
+ " # Handle failed requests (create_multimodal_request returns \"\" on error)\n",
321
+ " if not final_response:\n",
322
+ " skipped_questions += 1\n",
323
+ " print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
324
+ " continue\n",
325
+ "\n",
326
+ " print(\n",
327
+ " f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
328
+ " )\n",
329
+ " print(f\"Case ID: {case_id}\")\n",
330
+ " print(f\"Question ID: {question_id}\")\n",
331
+ " print(f\"Final Response: {final_response}\")\n",
332
+ " print(f\"Model Answer: {model_answer}\")\n",
333
+ " print(f\"Correct Answer: {question_data['answer']}\")\n",
334
+ " print(f\"----------------------------------------------------------------\\n\")\n",
335
+ "\n",
336
+ " print(f\"\\nBenchmark Summary:\")\n",
337
+ " print(f\"Total Cases Processed: {cases_processed}\")\n",
338
+ " print(f\"Total Questions Processed: {questions_processed}\")\n",
339
+ " print(f\"Total Questions Skipped: {skipped_questions}\")"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "tools = get_tools()\n",
349
+ "main(tools)"
350
+ ]
351
+ }
352
+ ],
353
+ "metadata": {
354
+ "kernelspec": {
355
+ "display_name": "medmax",
356
+ "language": "python",
357
+ "name": "python3"
358
+ },
359
+ "language_info": {
360
+ "codemirror_mode": {
361
+ "name": "ipython",
362
+ "version": 3
363
+ },
364
+ "file_extension": ".py",
365
+ "mimetype": "text/x-python",
366
+ "name": "python",
367
+ "nbconvert_exporter": "python",
368
+ "pygments_lexer": "ipython3",
369
+ "version": "3.10.16"
370
+ }
371
+ },
372
+ "nbformat": 4,
373
+ "nbformat_minor": 2
374
+ }
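
The notebook prompts MedRAX twice: once for a chain-of-thought answer and once more for the bare letter. If that second reply still carries extra words, a small normalizer along these lines (a hypothetical helper, mirroring the regex fallback used in experiments/compare_runs.py) could be applied to model_answer before logging:

    import re

    def to_choice(answer: str) -> str:
        """Coerce a free-text reply into a single A-F choice letter, or '' if none is found."""
        answer = answer.strip().upper()
        if len(answer) == 1 and answer in "ABCDEF":
            return answer
        # Fall back to the first standalone A-F token in the text
        match = re.search(r"\b([A-F])\b", answer)
        return match.group(1) if match else ""
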
experiments/chexbench_gpt4.py ADDED
@@ -0,0 +1,405 @@
1
+ import json
2
+ import openai
3
+ import os
4
+ from datetime import datetime
5
+ import base64
6
+ import logging
7
+ from pathlib import Path
8
+ import time
9
+ from tqdm import tqdm
10
+ from typing import Dict, List, Optional, Union, Any
11
+
12
+ # Configuration constants
13
+ DEBUG_MODE = False
14
+ OUTPUT_DIR = "results"
15
+ MODEL_NAME = "gpt-4o-2024-05-13"
16
+ TEMPERATURE = 0.2
17
+ SUBSET = "Visual Question Answering"
18
+
19
+ # Set up logging configuration
20
+ logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
21
+ logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def get_mime_type(file_path: str) -> str:
26
+ """
27
+ Determine MIME type based on file extension.
28
+
29
+ Args:
30
+ file_path (str): Path to the file
31
+
32
+ Returns:
33
+ str: MIME type string for the file
34
+ """
35
+ extension = os.path.splitext(file_path)[1].lower()
36
+ mime_types = {
37
+ ".png": "image/png",
38
+ ".jpg": "image/jpeg",
39
+ ".jpeg": "image/jpeg",
40
+ ".gif": "image/gif",
41
+ }
42
+ return mime_types.get(extension, "application/octet-stream")
43
+
44
+
45
+ def encode_image(image_path: str) -> str:
46
+ """
47
+ Encode image to base64 with extensive error checking.
48
+
49
+ Args:
50
+ image_path (str): Path to the image file
51
+
52
+ Returns:
53
+ str: Base64 encoded image string
54
+
55
+ Raises:
56
+ FileNotFoundError: If image file does not exist
57
+ ValueError: If image file is empty or too large
58
+ Exception: For other image processing errors
59
+ """
60
+ logger.debug(f"Attempting to read image from: {image_path}")
61
+ if not os.path.exists(image_path):
62
+ raise FileNotFoundError(f"Image file not found: {image_path}")
63
+
64
+ # Add check for file size
65
+ file_size = os.path.getsize(image_path)
66
+ if file_size > 20 * 1024 * 1024: # 20MB limit
67
+ raise ValueError("Image file size exceeds 20MB limit")
68
+ if file_size == 0:
69
+ raise ValueError("Image file is empty")
70
+ logger.debug(f"Image file size: {file_size / 1024:.2f} KB")
71
+
72
+ try:
73
+ from PIL import Image
74
+
75
+ # Try to open and verify the image
76
+ with Image.open(image_path) as img:
77
+ # Get image details
78
+ width, height = img.size
79
+ format = img.format
80
+ mode = img.mode
81
+ logger.debug(
82
+ f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}"
83
+ )
84
+
85
+ if format not in ["PNG", "JPEG", "GIF"]:
86
+ raise ValueError(f"Unsupported image format: {format}")
87
+
88
+ with open(image_path, "rb") as image_file:
89
+ # Read the first few bytes to verify it's a valid PNG
90
+ header = image_file.read(8)
91
+ # if header != b'\x89PNG\r\n\x1a\n':
92
+ # logger.warning("File does not have a valid PNG signature")
93
+
94
+ # Reset file pointer and read entire file
95
+ image_file.seek(0)
96
+ encoded = base64.b64encode(image_file.read()).decode("utf-8")
97
+ encoded_length = len(encoded)
98
+ logger.debug(f"Base64 encoded length: {encoded_length} characters")
99
+
100
+ # Verify the encoded string is not empty and starts correctly
101
+ if encoded_length == 0:
102
+ raise ValueError("Base64 encoding produced empty string")
103
+ if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
104
+ logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
105
+
106
+ return encoded
107
+ except Exception as e:
108
+ logger.error(f"Error reading/encoding image: {str(e)}")
109
+ raise
110
+
111
+
112
+ def create_single_request(
113
+ image_path: str, question: str, options: Dict[str, str]
114
+ ) -> List[Dict[str, Any]]:
115
+ """
116
+ Create a single API request with image and question.
117
+
118
+ Args:
119
+ image_path (str): Path to the image file
120
+ question (str): Question text
121
+ options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'
122
+
123
+ Returns:
124
+ List[Dict[str, Any]]: List of message dictionaries for the API request
125
+
126
+ Raises:
127
+ Exception: For errors in request creation
128
+ """
129
+ if DEBUG_MODE:
130
+ logger.debug("Creating API request...")
131
+
132
+ prompt = f"""Given the following medical examination question:
133
+ Please answer this multiple choice question:
134
+
135
+ Question: {question}
136
+
137
+ Options:
138
+ A) {options['option_0']}
139
+ B) {options['option_1']}
140
+
141
+ Base your answer only on the provided image and select either A or B."""
142
+
143
+ try:
144
+ encoded_image = encode_image(image_path)
145
+ mime_type = get_mime_type(image_path)
146
+
147
+ if DEBUG_MODE:
148
+ logger.debug(f"Image encoded with MIME type: {mime_type}")
149
+
150
+ messages = [
151
+ {
152
+ "role": "system",
153
+ "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
154
+ },
155
+ {
156
+ "role": "user",
157
+ "content": [
158
+ {"type": "text", "text": prompt},
159
+ {
160
+ "type": "image_url",
161
+ "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
162
+ },
163
+ ],
164
+ },
165
+ ]
166
+
167
+ if DEBUG_MODE:
168
+ log_messages = json.loads(json.dumps(messages))
169
+ log_messages[1]["content"][1]["image_url"][
170
+ "url"
171
+ ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
172
+ logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
173
+
174
+ return messages
175
+
176
+ except Exception as e:
177
+ logger.error(f"Error creating request: {str(e)}")
178
+ raise
179
+
180
+
181
+ def check_answer(model_answer: str, correct_answer: int) -> bool:
182
+ """
183
+ Check if the model's answer matches the correct answer.
184
+
185
+ Args:
186
+ model_answer (str): The model's answer (A or B)
187
+ correct_answer (int): The correct answer index (0 for A, 1 for B)
188
+
189
+ Returns:
190
+ bool: True if answer is correct, False otherwise
191
+ """
192
+ if not isinstance(model_answer, str):
193
+ return False
194
+
195
+ # Clean the model answer to get just the letter
196
+ model_letter = model_answer.strip().upper()
197
+ if model_letter.startswith("A"):
198
+ model_index = 0
199
+ elif model_letter.startswith("B"):
200
+ model_index = 1
201
+ else:
202
+ return False
203
+
204
+ return model_index == correct_answer
205
+
206
+
207
+ def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
208
+ """
209
+ Save results to a JSON file with timestamp.
210
+
211
+ Args:
212
+ results (List[Dict[str, Any]]): List of result dictionaries
213
+ output_dir (str): Directory to save results
214
+
215
+ Returns:
216
+ str: Path to the saved file
217
+ """
218
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
219
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
220
+ output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
221
+
222
+ with open(output_file, "w") as f:
223
+ json.dump(results, f, indent=2)
224
+
225
+ logger.info(f"Batch results saved to {output_file}")
226
+ return output_file
227
+
228
+
229
+ def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
230
+ """
231
+ Calculate accuracy from results, handling error cases.
232
+
233
+ Args:
234
+ results (List[Dict[str, Any]]): List of result dictionaries
235
+
236
+ Returns:
237
+ tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
238
+ """
239
+ if not results:
240
+ return 0.0, 0, 0
241
+
242
+ total = len(results)
243
+ valid_results = [r for r in results if "output" in r]
244
+ correct = sum(
245
+ 1 for result in valid_results if result.get("output", {}).get("is_correct", False)
246
+ )
247
+
248
+ accuracy = (correct / total * 100) if total > 0 else 0
249
+ return accuracy, correct, total
250
+
251
+
252
+ def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
253
+ """
254
+ Calculate accuracy for the current batch.
255
+
256
+ Args:
257
+ results (List[Dict[str, Any]]): List of result dictionaries
258
+
259
+ Returns:
260
+ float: Accuracy percentage for the batch
261
+ """
262
+ valid_results = [r for r in results if "output" in r]
263
+ if not valid_results:
264
+ return 0.0
265
+ return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
266
+
267
+
268
+ def process_batch(
269
+ data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
270
+ ) -> List[Dict[str, Any]]:
271
+ """
272
+ Process a batch of examples and return results.
273
+
274
+ Args:
275
+ data (List[Dict[str, Any]]): List of data items to process
276
+ client (openai.OpenAI): OpenAI client instance
277
+ start_idx (int, optional): Starting index for batch. Defaults to 0
278
+ batch_size (int, optional): Size of batch to process. Defaults to 50
279
+
280
+ Returns:
281
+ List[Dict[str, Any]]: List of processed results
282
+ """
283
+ batch_results = []
284
+ end_idx = min(start_idx + batch_size, len(data))
285
+
286
+ pbar = tqdm(
287
+ range(start_idx, end_idx),
288
+ desc=f"Processing batch {start_idx//batch_size + 1}",
289
+ unit="example",
290
+ )
291
+
292
+ for index in pbar:
293
+ vqa_item = data[index]
294
+ options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
295
+
296
+ try:
297
+ messages = create_single_request(
298
+ image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
299
+ )
300
+
301
+ response = client.chat.completions.create(
302
+ model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
303
+ )
304
+
305
+ model_answer = response.choices[0].message.content.strip()
306
+ is_correct = check_answer(model_answer, vqa_item["answer"])
307
+
308
+ result = {
309
+ "timestamp": datetime.now().isoformat(),
310
+ "example_index": index,
311
+ "input": {
312
+ "question": vqa_item["question"],
313
+ "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
314
+ "image_path": vqa_item["image_path"],
315
+ },
316
+ "output": {
317
+ "model_answer": model_answer,
318
+ "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
319
+ "is_correct": is_correct,
320
+ "usage": {
321
+ "prompt_tokens": response.usage.prompt_tokens,
322
+ "completion_tokens": response.usage.completion_tokens,
323
+ "total_tokens": response.usage.total_tokens,
324
+ },
325
+ },
326
+ }
327
+ batch_results.append(result)
328
+
329
+ # Update progress bar with current accuracy
330
+ current_accuracy = calculate_batch_accuracy(batch_results)
331
+ pbar.set_description(
332
+ f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
333
+ f"({len(batch_results)}/{index-start_idx+1} examples)"
334
+ )
335
+
336
+ except Exception as e:
337
+ error_result = {
338
+ "timestamp": datetime.now().isoformat(),
339
+ "example_index": index,
340
+ "error": str(e),
341
+ "input": {
342
+ "question": vqa_item["question"],
343
+ "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
344
+ "image_path": vqa_item["image_path"],
345
+ },
346
+ }
347
+ batch_results.append(error_result)
348
+ if DEBUG_MODE:
349
+ pbar.write(f"Error processing example {index}: {str(e)}")
350
+
351
+ time.sleep(1) # Rate limiting
352
+
353
+ return batch_results
354
+
355
+
356
+ def main() -> None:
357
+ """
358
+ Main function to process the entire dataset.
359
+
360
+ Raises:
361
+ ValueError: If OPENAI_API_KEY is not set
362
+ Exception: For other processing errors
363
+ """
364
+ logger.info("Starting full dataset processing...")
365
+ json_path = "../data/chexbench_updated.json"
366
+
367
+ try:
368
+ api_key = os.getenv("OPENAI_API_KEY")
369
+ if not api_key:
370
+ raise ValueError("OPENAI_API_KEY environment variable is not set.")
371
+ client = openai.OpenAI(api_key=api_key)
372
+
373
+ with open(json_path, "r") as f:
374
+ data = json.load(f)
375
+
376
+ subset_data = data[SUBSET]
377
+ total_examples = len(subset_data)
378
+ logger.info(f"Found {total_examples} examples in {SUBSET} subset")
379
+
380
+ all_results = []
381
+ batch_size = 50 # Process in batches of 50 examples
382
+
383
+ # Process all examples in batches
384
+ for start_idx in range(0, total_examples, batch_size):
385
+ batch_results = process_batch(subset_data, client, start_idx, batch_size)
386
+ all_results.extend(batch_results)
387
+
388
+ # Save intermediate results after each batch
389
+ output_file = save_results_to_json(all_results, OUTPUT_DIR)
390
+
391
+ # Calculate and log overall progress
392
+ overall_accuracy, correct, total = calculate_accuracy(all_results)
393
+ logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
394
+ logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")
395
+
396
+ logger.info("Processing completed!")
397
+ logger.info(f"Final results saved to: {output_file}")
398
+
399
+ except Exception as e:
400
+ logger.error(f"Fatal error: {str(e)}")
401
+ raise
402
+
403
+
404
+ if __name__ == "__main__":
405
+ main()
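
For spot-checking a single example outside process_batch, a short sketch that reuses the functions defined in this script works; the image path, question, and options below are placeholders:

    import os
    import openai

    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    messages = create_single_request(
        image_path="images/example_cxr.png",
        question="Is a pleural effusion present?",
        options={"option_0": "Yes", "option_1": "No"},
    )
    response = client.chat.completions.create(
        model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
    )
    print(response.choices[0].message.content.strip())
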
experiments/compare_runs.py ADDED
@@ -0,0 +1,290 @@
1
+ import json
2
+ import argparse
3
+ import random
4
+ from typing import List, Dict, Any, Tuple
5
+ import re
6
+ from collections import defaultdict
7
+
8
+ # Define category order
9
+ CATEGORY_ORDER = [
10
+ "detection",
11
+ "classification",
12
+ "localization",
13
+ "comparison",
14
+ "relationship",
15
+ "diagnosis",
16
+ "characterization",
17
+ ]
18
+
19
+
20
+ def extract_letter_answer(answer: str) -> str:
21
+ """Extract just the letter answer from various answer formats.
22
+
23
+ Args:
24
+ answer: The answer string to extract a letter from
25
+
26
+ Returns:
27
+ str: The extracted letter in uppercase, or empty string if no letter found
28
+ """
29
+ if not answer:
30
+ return ""
31
+
32
+ # Convert to string and clean
33
+ answer = str(answer).strip()
34
+
35
+ # If it's just a single letter A-F, return it
36
+ if len(answer) == 1 and answer.upper() in "ABCDEF":
37
+ return answer.upper()
38
+
39
+ # Try to match patterns like "A)", "A.", "A ", etc.
40
+ match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
41
+ if match:
42
+ return match.group(1).upper()
43
+
44
+ # Try to find any standalone A-F letters preceded by space or start of string
45
+ # and followed by space, period, parenthesis or end of string
46
+ matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
47
+ if matches:
48
+ return matches[0].upper()
49
+
50
+ # Last resort: just find any A-F letter
51
+ letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
52
+ if letters:
53
+ return letters[0].upper()
54
+
55
+ # If no letter found, return original (cleaned)
56
+ return answer.strip().upper()
57
+
58
+
59
+ def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
60
+ """Parse JSON Lines file and extract valid predictions.
61
+
62
+ Args:
63
+ file_path: Path to the JSON Lines file to parse
64
+
65
+ Returns:
66
+ Tuple containing:
67
+ - str: Model name or file path if model name not found
68
+ - List[Dict[str, Any]]: List of valid prediction entries
69
+ """
70
+ valid_predictions = []
71
+ model_name = None
72
+
73
+ # First try to parse as LLaVA format
74
+ try:
75
+ with open(file_path, "r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+ if data.get("model") == "llava-med-v1.5-mistral-7b":
78
+ model_name = data["model"]
79
+ for result in data.get("results", []):
80
+ if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
81
+ # Extract answer with priority: model_answer > validated_answer > raw_output
82
+ model_answer = (
83
+ result.get("model_answer")
84
+ or result.get("validated_answer")
85
+ or result.get("raw_output", "")
86
+ )
87
+
88
+ # Add default categories for LLaVA results
89
+ prediction = {
90
+ "case_id": result["case_id"],
91
+ "question_id": result["question_id"],
92
+ "model_answer": model_answer,
93
+ "correct_answer": result["correct_answer"],
94
+ "input": {
95
+ "question_data": {
96
+ "metadata": {
97
+ "categories": [
98
+ "detection",
99
+ "classification",
100
+ "localization",
101
+ "comparison",
102
+ "relationship",
103
+ "diagnosis",
104
+ "characterization",
105
+ ]
106
+ }
107
+ }
108
+ },
109
+ }
110
+ valid_predictions.append(prediction)
111
+ return model_name, valid_predictions
112
+ except (json.JSONDecodeError, KeyError):
113
+ pass
114
+
115
+ # If not LLaVA format, process as original format
116
+ with open(file_path, "r", encoding="utf-8") as f:
117
+ for line in f:
118
+ if line.startswith("HTTP Request:"):
119
+ continue
120
+ try:
121
+ data = json.loads(line.strip())
122
+ if "model" in data:
123
+ model_name = data["model"]
124
+ if all(
125
+ k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
126
+ ):
127
+ valid_predictions.append(data)
128
+ except json.JSONDecodeError:
129
+ continue
130
+
131
+ return model_name if model_name else file_path, valid_predictions
132
+
133
+
134
+ def filter_common_questions(
135
+ predictions_list: List[List[Dict[str, Any]]]
136
+ ) -> List[List[Dict[str, Any]]]:
137
+ """Ensure only questions that exist across all models are evaluated.
138
+
139
+ Args:
140
+ predictions_list: List of prediction lists from different models
141
+
142
+ Returns:
143
+ List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
144
+ """
145
+ question_sets = [
146
+ set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
147
+ ]
148
+ common_questions = set.intersection(*question_sets)
149
+
150
+ return [
151
+ [p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
152
+ for preds in predictions_list
153
+ ]
154
+
155
+
156
+ def calculate_accuracy(
157
+ predictions: List[Dict[str, Any]]
158
+ ) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
159
+ """Compute overall and category-level accuracy.
160
+
161
+ Args:
162
+ predictions: List of prediction entries to analyze
163
+
164
+ Returns:
165
+ Tuple containing:
166
+ - float: Overall accuracy percentage
167
+ - int: Number of correct predictions
168
+ - int: Total number of predictions
169
+ - Dict[str, Dict[str, float]]: Category-level accuracy statistics
170
+ """
171
+ if not predictions:
172
+ return 0.0, 0, 0, {}
173
+
174
+ category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
175
+ correct = 0
176
+ total = 0
177
+ sample_size = min(5, len(predictions))
178
+ sampled_indices = random.sample(range(len(predictions)), sample_size)
179
+
180
+ print("\nSample extracted answers:")
181
+ for i in sampled_indices:
182
+ pred = predictions[i]
183
+ model_ans = extract_letter_answer(pred["model_answer"])
184
+ correct_ans = extract_letter_answer(pred["correct_answer"])
185
+ print(f"QID: {pred['question_id']}")
186
+ print(f" Raw Model Answer: {pred['model_answer']}")
187
+ print(f" Extracted Model Answer: {model_ans}")
188
+ print(f" Raw Correct Answer: {pred['correct_answer']}")
189
+ print(f" Extracted Correct Answer: {correct_ans}")
190
+ print("-" * 80)
191
+
192
+ for pred in predictions:
193
+ try:
194
+ model_ans = extract_letter_answer(pred["model_answer"])
195
+ correct_ans = extract_letter_answer(pred["correct_answer"])
196
+ categories = (
197
+ pred.get("input", {})
198
+ .get("question_data", {})
199
+ .get("metadata", {})
200
+ .get("categories", [])
201
+ )
202
+
203
+ if model_ans and correct_ans:
204
+ total += 1
205
+ is_correct = model_ans == correct_ans
206
+ if is_correct:
207
+ correct += 1
208
+
209
+ for category in categories:
210
+ category_performance[category]["total"] += 1
211
+ if is_correct:
212
+ category_performance[category]["correct"] += 1
213
+
214
+ except KeyError:
215
+ continue
216
+
217
+ category_accuracies = {
218
+ category: {
219
+ "accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
220
+ "total": stats["total"],
221
+ "correct": stats["correct"],
222
+ }
223
+ for category, stats in category_performance.items()
224
+ }
225
+
226
+ return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
227
+
228
+
229
+ def compare_models(file_paths: List[str]) -> None:
230
+ """Compare accuracy between multiple model prediction files.
231
+
232
+ Args:
233
+ file_paths: List of paths to model prediction files to compare
234
+ """
235
+ # Parse all files
236
+ parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
237
+ model_names, predictions_list = zip(*parsed_results)
238
+
239
+ # Get initial stats
240
+ print(f"\n📊 **Initial Accuracy**:")
241
+ results = []
242
+ category_results = []
243
+
244
+ for preds, name in zip(predictions_list, model_names):
245
+ acc, correct, total, category_acc = calculate_accuracy(preds)
246
+ results.append((acc, correct, total, name))
247
+ category_results.append(category_acc)
248
+ print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
249
+
250
+ # Get common questions across all models
251
+ filtered_predictions = filter_common_questions(predictions_list)
252
+ print(
253
+ f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
254
+ )
255
+
256
+ # Compute accuracy on common questions
257
+ print(f"\n📊 **Accuracy on Common Questions**:")
258
+ filtered_results = []
259
+ filtered_category_results = []
260
+
261
+ for preds, name in zip(filtered_predictions, model_names):
262
+ acc, correct, total, category_acc = calculate_accuracy(preds)
263
+ filtered_results.append((acc, correct, total, name))
264
+ filtered_category_results.append(category_acc)
265
+ print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
266
+
267
+ # Print category-wise accuracy
268
+ print("\nCategory Performance (Common Questions):")
269
+ for category in CATEGORY_ORDER:
270
+ print(f"\n{category.capitalize()}:")
271
+ for model_name, category_acc in zip(model_names, filtered_category_results):
272
+ stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
273
+ print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")
274
+
275
+
276
+ def main():
277
+ parser = argparse.ArgumentParser(
278
+ description="Compare accuracy across multiple model prediction files"
279
+ )
280
+ parser.add_argument("files", nargs="+", help="Paths to model prediction files")
281
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
282
+
283
+ args = parser.parse_args()
284
+ random.seed(args.seed)
285
+
286
+ compare_models(args.files)
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
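
The comparison can also be driven programmatically instead of via the CLI; the file names below are placeholders for real run outputs (one JSON-lines log and one final-results file):

    import random

    random.seed(42)  # mirrors the --seed default so the sampled debug prints are reproducible
    compare_models([
        "medrax_logs/medrax_20250101_000000.json",
        "benchmark_results/final_results_20250101_000000.json",
    ])
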
experiments/inspect_logs.py ADDED
@@ -0,0 +1,210 @@
1
+ from typing import Optional, List
2
+ import argparse
3
+ import json
4
+ import glob
5
+ from pathlib import Path
6
+ from datetime import datetime
7
+
8
+
9
+ def get_latest_log() -> str:
10
+ """Find the most recently modified log file in the current directory.
11
+
12
+ Returns:
13
+ str: Path to the most recently modified log file
14
+
15
+ Raises:
16
+ FileNotFoundError: If no log files are found in the current directory
17
+ """
18
+ logs = list(Path(".").glob("api_usage_*.json"))
19
+ if not logs:
20
+ raise FileNotFoundError("No log files found in the current directory.")
21
+ return str(max(logs, key=lambda p: p.stat().st_mtime))
22
+
23
+
24
+ def format_cost(entry: dict) -> str:
25
+ """Format cost if available, otherwise return 'N/A'
26
+
27
+ Args:
28
+ entry: Log entry dictionary containing cost information
29
+
30
+ Returns:
31
+ str: Formatted cost string with $ and 4 decimal places, or 'N/A' if cost not found
32
+ """
33
+ return f"${entry.get('cost', 'N/A'):.4f}" if "cost" in entry else "N/A"
34
+
35
+
36
+ def print_gpt4_entry(entry: dict) -> None:
37
+ """Print entry for GPT-4 format
38
+
39
+ Args:
40
+ entry: Log entry dictionary in GPT-4 format containing model info, inputs and outputs
41
+ """
42
+ print("\n=== Log Entry ===")
43
+ print(f"Model: {entry['model']}")
44
+ print(f"Case ID: {entry['case_id']}")
45
+ print(f"Question ID: {entry['question_id']}")
46
+
47
+ print("\n=== Model Input ===")
48
+ messages = entry["input"]["messages"]
49
+ print("System message:", messages[0]["content"])
50
+ user_content = messages[1]["content"]
51
+ print("\nUser prompt:", user_content[0]["text"])
52
+ print("\nImages provided:")
53
+ for content in user_content[1:]:
54
+ print(f" - {content['image_url']['url']}")
55
+
56
+ print("\n=== Model Output ===")
57
+ print(f"Answer: {entry['model_answer']}")
58
+ print(f"Correct: {entry['correct_answer']}")
59
+
60
+ print("\n=== Usage Stats ===")
61
+ print(f"Duration: {entry['duration']}s")
62
+ print(f"Cost: {format_cost(entry)}")
63
+ print(
64
+ f"Tokens: {entry['usage']['total_tokens']}",
65
+ f"(prompt: {entry['usage']['prompt_tokens']},",
66
+ f"completion: {entry['usage']['completion_tokens']})",
67
+ )
68
+
69
+
70
+ def print_llama_entry(entry: dict) -> None:
71
+ """Print entry for Llama-3.2 format
72
+
73
+ Args:
74
+ entry: Log entry dictionary in Llama format containing model info, inputs and outputs
75
+ """
76
+ print("\n=== Log Entry ===")
77
+ print(f"Model: {entry['model']}")
78
+ print(f"Case ID: {entry['case_id']}")
79
+ print(f"Question ID: {entry['question_id']}")
80
+
81
+ print("\n=== Model Input ===")
82
+ print(f"Question: {entry['input']['question_data']['question']}")
83
+ print("\nImages provided:")
84
+ for url in entry["input"]["image_urls"]:
85
+ print(f" - {url}")
86
+ if entry["input"]["image_captions"]:
87
+ print("\nImage captions:")
88
+ for caption in entry["input"]["image_captions"]:
89
+ if caption:
90
+ print(f" - {caption}")
91
+
92
+ print("\n=== Model Output ===")
93
+ print(f"Answer: {entry['model_answer']}")
94
+ print(f"Correct: {entry['correct_answer']}")
95
+
96
+ print("\n=== Usage Stats ===")
97
+ print(f"Duration: {entry['duration']}s")
98
+ if "usage" in entry:
99
+ print(
100
+ f"Tokens: {entry['usage']['total_tokens']}",
101
+ f"(prompt: {entry['usage']['prompt_tokens']},",
102
+ f"completion: {entry['usage']['completion_tokens']})",
103
+ )
104
+
105
+
106
+ def determine_model_type(entry: dict) -> str:
107
+ """Determine the model type from the entry
108
+
109
+ Args:
110
+ entry: Log entry dictionary containing model information
111
+
112
+ Returns:
113
+ str: Model type - 'gpt4', 'llama', or 'unknown'
114
+ """
115
+ model = entry.get("model", "").lower()
116
+ if "gpt-4" in model:
117
+ return "gpt4"
118
+ elif "llama" in model:
119
+ return "llama"
120
+ elif "chexagent" in model:
121
+ return "chexagent"
122
+ elif "medrax" in model:
123
+ return "medrax"
124
+ else:
125
+ return "unknown"
126
+
127
+
128
+ def print_log_entry(
129
+ log_file: Optional[str] = None,
130
+ num_entries: Optional[int] = None,
131
+ model_filter: Optional[str] = None,
132
+ ) -> None:
133
+ """Print log entries from the specified log file or the latest log file.
134
+
135
+ Args:
136
+ log_file: Path to the log file. If None, uses the latest log file.
137
+ num_entries: Number of entries to print. If None, prints all entries.
138
+ model_filter: Filter entries by model type ('gpt4' or 'llama'). If None, prints all.
139
+ """
140
+ if log_file is None:
141
+ log_file = get_latest_log()
142
+ print(f"Using latest log file: {log_file}")
143
+
144
+ entries_printed = 0
145
+ total_entries = 0
146
+ filtered_entries = 0
147
+
148
+ with open(log_file, "r") as f:
149
+ for line in f:
150
+ if line.startswith("HTTP"):
151
+ continue
152
+ try:
153
+ total_entries += 1
154
+ entry = json.loads(line)
155
+
156
+ # Apply model filter if specified
157
+ model_type = determine_model_type(entry)
158
+ if model_filter and model_type != model_filter:
159
+ filtered_entries += 1
160
+ continue
161
+
162
+ if model_type == "gpt4":
163
+ print_gpt4_entry(entry)
164
+ elif model_type == "llama":
165
+ print_llama_entry(entry)
166
+ else:
167
+ print(f"Unknown model type in entry: {entry['model']}")
168
+ continue
169
+
170
+ print("=" * 50)
171
+ entries_printed += 1
172
+ if num_entries and entries_printed >= num_entries:
173
+ break
174
+
175
+ except (json.JSONDecodeError, KeyError) as e:
176
+ print(f"Error processing entry: {e}")
177
+ continue
178
+
179
+ print(f"\nSummary:")
180
+ print(f"Total entries: {total_entries}")
181
+ print(f"Entries printed: {entries_printed}")
182
+ if model_filter:
183
+ print(f"Entries filtered: {filtered_entries}")
184
+
185
+
186
+ def main() -> None:
187
+ """Main entry point for the script"""
188
+ parser = argparse.ArgumentParser(
189
+ description="Parse and display log entries from API usage logs."
190
+ )
191
+ parser.add_argument("-l", "--log_file", nargs="?", help="Path to the log file (optional)")
192
+ parser.add_argument("-n", "--num_entries", type=int, help="Number of entries to display")
193
+ parser.add_argument(
194
+ "-m",
195
+ "--model",
196
+ choices=["gpt4", "llama"],
197
+ default="gpt4",
198
+ help="Model type to display (default: gpt4)",
199
+ )
200
+ args = parser.parse_args()
201
+
202
+ try:
203
+ print_log_entry(args.log_file, args.num_entries, args.model)
204
+ except FileNotFoundError as e:
205
+ print(f"Error: {e}")
206
+ exit(1)
207
+
208
+
209
+ if __name__ == "__main__":
210
+ main()
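
The entry point can also be called directly, equivalent to running the script with -n 3 -m llama; the log filename is a placeholder and can be omitted to fall back to the latest api_usage_*.json file:

    print_log_entry(
        log_file="api_usage_20250101_000000.json",
        num_entries=3,
        model_filter="llama",
    )
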
experiments/validate_logs.py ADDED
@@ -0,0 +1,162 @@
1
+ from typing import Dict, List, Tuple, Optional
2
+ import json
3
+ import sys
4
+ import glob
5
+ from pathlib import Path
6
+ from collections import defaultdict
7
+
8
+
9
+ def get_latest_log() -> str:
10
+ """Find the most recently modified log file in the current directory.
11
+
12
+ Returns:
13
+ str: Path to the most recently modified log file
14
+
15
+ Raises:
16
+ SystemExit: If no log files are found in current directory
17
+ """
18
+ log_pattern = "api_usage_*.json"
19
+ logs = list(Path(".").glob(log_pattern))
20
+ if not logs:
21
+ print(f"No files matching pattern '{log_pattern}' found in current directory")
22
+ sys.exit(1)
23
+ return str(max(logs, key=lambda p: p.stat().st_mtime))
24
+
25
+
26
+ def analyze_log_file(filename: str) -> Tuple[List[Dict], List[Dict], Dict[str, List[str]]]:
27
+ """Analyze a log file for entries missing images and errors.
28
+
29
+ Args:
30
+ filename: Path to the log file to analyze
31
+
32
+ Returns:
33
+ Tuple containing:
34
+ - List of entries with no images
35
+ - List of skipped/error entries
36
+ - Dict of processing errors by type
37
+
38
+ Raises:
39
+ SystemExit: If file cannot be found or read
40
+ """
41
+ no_images = []
42
+ errors = defaultdict(list)
43
+ skipped = []
44
+
45
+ try:
46
+ with open(filename, "r") as f:
47
+ for line_num, line in enumerate(f, 1):
48
+ # Skip HTTP request logs
49
+ if line.startswith("HTTP Request:") or line.strip() == "":
50
+ continue
51
+ try:
52
+ # Try to parse the JSON line
53
+ if not line.strip().startswith("{"):
54
+ continue
55
+ entry = json.loads(line.strip())
56
+ case_id = entry.get("case_id")
57
+ question_id = entry.get("question_id")
58
+
59
+ # Skip if we can't identify the question
60
+ if not case_id or not question_id:
61
+ continue
62
+
63
+ # Check for explicit skip/error status
64
+ if entry.get("status") in ["skipped", "error"]:
65
+ skipped.append(
66
+ {
67
+ "case_id": case_id,
68
+ "question_id": question_id,
69
+ "reason": entry.get("reason"),
70
+ "status": entry.get("status"),
71
+ }
72
+ )
73
+ continue
74
+
75
+ # Check user content for images
76
+ messages = entry.get("input", {}).get("messages", [])
77
+ has_image = False
78
+ for msg in messages:
79
+ content = msg.get("content", [])
80
+ if isinstance(content, list):
81
+ for item in content:
82
+ if isinstance(item, dict) and item.get("type") == "image_url":
83
+ has_image = True
84
+ break
85
+ if not has_image:
86
+ no_images.append(
87
+ {
88
+ "case_id": case_id,
89
+ "question_id": question_id,
90
+ "question": entry.get("input", {})
91
+ .get("question_data", {})
92
+ .get("question", "")[:100]
93
+ + "...", # First 100 chars of question
94
+ }
95
+ )
96
+ except json.JSONDecodeError:
97
+ errors["json_decode"].append(f"Line {line_num}: Invalid JSON")
98
+ continue
99
+ except Exception as e:
100
+ errors["other"].append(f"Line {line_num}: Error processing entry: {str(e)}")
101
+ except FileNotFoundError:
102
+ print(f"Error: Could not find log file: {filename}")
103
+ sys.exit(1)
104
+ except Exception as e:
105
+ print(f"Error reading file {filename}: {str(e)}")
106
+ sys.exit(1)
107
+
108
+ return no_images, skipped, errors
109
+
110
+
111
+ def print_results(
112
+ filename: str, no_images: List[Dict], skipped: List[Dict], errors: Dict[str, List[str]]
113
+ ) -> None:
114
+ """Print analysis results.
115
+
116
+ Args:
117
+ filename: Name of the analyzed log file
118
+ no_images: List of entries with no images
119
+ skipped: List of skipped/error entries
120
+ errors: Dict of processing errors by type
121
+ """
122
+ print(f"\nAnalyzing log file: {filename}")
123
+ print("\n=== Questions with No Images ===")
124
+ if no_images:
125
+ for entry in no_images:
126
+ print(f"\nCase ID: {entry['case_id']}")
127
+ print(f"Question ID: {entry['question_id']}")
128
+ print(f"Question Preview: {entry['question']}")
129
+ print(f"\nTotal questions without images: {len(no_images)}")
130
+
131
+ print("\n=== Skipped/Error Questions ===")
132
+ if skipped:
133
+ for entry in skipped:
134
+ print(f"\nCase ID: {entry['case_id']}")
135
+ print(f"Question ID: {entry['question_id']}")
136
+ print(f"Status: {entry['status']}")
137
+ print(f"Reason: {entry.get('reason', 'unknown')}")
138
+ print(f"\nTotal skipped/error questions: {len(skipped)}")
139
+
140
+ if errors:
141
+ print("\n=== Processing Errors ===")
142
+ for error_type, messages in errors.items():
143
+ if messages:
144
+ print(f"\n{error_type}:")
145
+ for msg in messages:
146
+ print(f" {msg}")
147
+
148
+
149
+ def main() -> None:
150
+ """Main entry point for log validation script."""
151
+ # If a file is specified as an argument, use it; otherwise find the latest log
152
+ if len(sys.argv) > 1:
153
+ log_file = sys.argv[1]
154
+ else:
155
+ log_file = get_latest_log()
156
+
157
+ no_images, skipped, errors = analyze_log_file(log_file)
158
+ print_results(log_file, no_images, skipped, errors)
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()
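
Note on usage: the validator can be reused programmatically in the same way. A minimal sketch with an illustrative filename, printing a one-line summary instead of the full report:

    # Sketch: reusing the validator from another script (filename is hypothetical).
    from validate_logs import analyze_log_file, print_results

    log_path = "api_usage_example.json"
    no_images, skipped, errors = analyze_log_file(log_path)
    print(f"{len(no_images)} entries without images, {len(skipped)} skipped/error entries")
    # print_results(log_path, no_images, skipped, errors)  # full report, as in main()
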
interface.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ from pathlib import Path
4
+ import time
5
+ import shutil
6
+ from typing import AsyncGenerator, List, Optional, Tuple
7
+ from gradio import ChatMessage
8
+
9
+
10
+ class ChatInterface:
11
+ """
12
+ A chat interface for interacting with a medical AI agent through Gradio.
13
+
14
+ Handles file uploads, message processing, and chat history management.
15
+ Supports both regular image files and DICOM medical imaging files.
16
+ """
17
+
18
+ def __init__(self, agent, tools_dict):
19
+ """
20
+ Initialize the chat interface.
21
+
22
+ Args:
23
+ agent: The medical AI agent to handle requests
24
+ tools_dict (dict): Dictionary of available tools for image processing
25
+ """
26
+ self.agent = agent
27
+ self.tools_dict = tools_dict
28
+ self.upload_dir = Path("temp")
29
+ self.upload_dir.mkdir(exist_ok=True)
30
+ self.current_thread_id = None
31
+ # Separate storage for original and display paths
32
+ self.original_file_path = None # For LLM (.dcm or other)
33
+ self.display_file_path = None # For UI (always viewable format)
34
+
35
+ def handle_upload(self, file_path: str) -> str:
36
+ """
37
+ Handle new file upload and set appropriate paths.
38
+
39
+ Args:
40
+ file_path (str): Path to the uploaded file
41
+
42
+ Returns:
43
+ str: Display path for UI, or None if no file uploaded
44
+ """
45
+ if not file_path:
46
+ return None
47
+
48
+ source = Path(file_path)
49
+ timestamp = int(time.time())
50
+
51
+ # Save original file with proper suffix
52
+ suffix = source.suffix.lower()
53
+ saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
54
+ shutil.copy2(file_path, saved_path) # Use file_path directly instead of source
55
+ self.original_file_path = str(saved_path)
56
+
57
+ # Handle DICOM conversion for display only
58
+ if suffix == ".dcm":
59
+ output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
60
+ self.display_file_path = output["image_path"]
61
+ else:
62
+ self.display_file_path = str(saved_path)
63
+
64
+ return self.display_file_path
65
+
66
+ def add_message(
67
+ self, message: str, display_image: str, history: List[dict]
68
+ ) -> Tuple[List[dict], gr.Textbox]:
69
+ """
70
+ Add a new message to the chat history.
71
+
72
+ Args:
73
+ message (str): Text message to add
74
+ display_image (str): Path to image being displayed
75
+ history (List[dict]): Current chat history
76
+
77
+ Returns:
78
+ Tuple[List[dict], gr.Textbox]: Updated history and textbox component
79
+ """
80
+ image_path = self.original_file_path or display_image
81
+ if image_path is not None:
82
+ history.append({"role": "user", "content": {"path": image_path}})
83
+ if message is not None:
84
+ history.append({"role": "user", "content": message})
85
+ return history, gr.Textbox(value=message, interactive=False)
86
+
87
+ async def process_message(
88
+ self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
89
+ ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
90
+ """
91
+ Process a message and generate responses.
92
+
93
+ Args:
94
+ message (str): User message to process
95
+ display_image (Optional[str]): Path to currently displayed image
96
+ chat_history (List[ChatMessage]): Current chat history
97
+
98
+ Yields:
99
+ Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
100
+ """
101
+ chat_history = chat_history or []
102
+
103
+ # Initialize thread if needed
104
+ if not self.current_thread_id:
105
+ self.current_thread_id = str(time.time())
106
+
107
+ messages = []
108
+ image_path = self.original_file_path or display_image
109
+ if image_path is not None:
110
+ messages.append({"role": "user", "content": f"path: {image_path}"})
111
+ if message is not None:
112
+ messages.append({"role": "user", "content": message})
113
+
114
+ try:
115
+ for event in self.agent.workflow.stream(
116
+ {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
117
+ ):
118
+ if isinstance(event, dict):
119
+ if "process" in event:
120
+ content = event["process"]["messages"][-1].content
121
+ if content:
122
+ content = re.sub(r"temp/[^\s]*", "", content)
123
+ chat_history.append(ChatMessage(role="assistant", content=content))
124
+ yield chat_history, self.display_file_path, ""
125
+
126
+ elif "execute" in event:
127
+ for message in event["execute"]["messages"]:
128
+ tool_name = message.name
129
+ tool_result = eval(message.content)[0]
130
+
131
+ if tool_result:
132
+ metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
133
+ formatted_result = " ".join(
134
+ line.strip() for line in str(tool_result).splitlines()
135
+ ).strip()
136
+ metadata["description"] = formatted_result
137
+ chat_history.append(
138
+ ChatMessage(
139
+ role="assistant",
140
+ content=formatted_result,
141
+ metadata=metadata,
142
+ )
143
+ )
144
+
145
+ # For image_visualizer, use display path
146
+ if tool_name == "image_visualizer":
147
+ self.display_file_path = tool_result["image_path"]
148
+ chat_history.append(
149
+ ChatMessage(
150
+ role="assistant",
151
+ # content=gr.Image(value=self.display_file_path),
152
+ content={"path": self.display_file_path},
153
+ )
154
+ )
155
+
156
+ yield chat_history, self.display_file_path, ""
157
+
158
+ except Exception as e:
159
+ chat_history.append(
160
+ ChatMessage(
161
+ role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
162
+ )
163
+ )
164
+ yield chat_history, self.display_file_path, ""
165
+
166
+
167
+ def create_demo(agent, tools_dict):
168
+ """
169
+ Create a Gradio demo interface for the medical AI agent.
170
+
171
+ Args:
172
+ agent: The medical AI agent to handle requests
173
+ tools_dict (dict): Dictionary of available tools for image processing
174
+
175
+ Returns:
176
+ gr.Blocks: Gradio Blocks interface
177
+ """
178
+ interface = ChatInterface(agent, tools_dict)
179
+
180
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
181
+ with gr.Column():
182
+ gr.Markdown(
183
+ """
184
+ # 🏥 MedRAX
185
+ Medical Reasoning Agent for Chest X-ray
186
+ """
187
+ )
188
+
189
+ with gr.Row():
190
+ with gr.Column(scale=3):
191
+ chatbot = gr.Chatbot(
192
+ [],
193
+ height=800,
194
+ container=True,
195
+ show_label=True,
196
+ elem_classes="chat-box",
197
+ type="messages",
198
+ label="Agent",
199
+ avatar_images=(
200
+ None,
201
+ "assets/medrax_logo.jpg",
202
+ ),
203
+ )
204
+ with gr.Row():
205
+ with gr.Column(scale=3):
206
+ txt = gr.Textbox(
207
+ show_label=False,
208
+ placeholder="Ask about the X-ray...",
209
+ container=False,
210
+ )
211
+
212
+ with gr.Column(scale=3):
213
+ image_display = gr.Image(
214
+ label="Image", type="filepath", height=700, container=True
215
+ )
216
+ with gr.Row():
217
+ upload_button = gr.UploadButton(
218
+ "📎 Upload X-Ray",
219
+ file_types=["image"],
220
+ )
221
+ dicom_upload = gr.UploadButton(
222
+ "📄 Upload DICOM",
223
+ file_types=["file"],
224
+ )
225
+ with gr.Row():
226
+ clear_btn = gr.Button("Clear Chat")
227
+ new_thread_btn = gr.Button("New Thread")
228
+
229
+ # Event handlers
230
+ def clear_chat():
231
+ interface.original_file_path = None
232
+ interface.display_file_path = None
233
+ return [], None
234
+
235
+ def new_thread():
236
+ interface.current_thread_id = str(time.time())
237
+ return [], interface.display_file_path
238
+
239
+ def handle_file_upload(file):
240
+ return interface.handle_upload(file.name)
241
+
242
+ chat_msg = txt.submit(
243
+ interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
244
+ )
245
+ bot_msg = chat_msg.then(
246
+ interface.process_message,
247
+ inputs=[txt, image_display, chatbot],
248
+ outputs=[chatbot, image_display, txt],
249
+ )
250
+ bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
251
+
252
+ upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
253
+
254
+ dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
255
+
256
+ clear_btn.click(clear_chat, outputs=[chatbot, image_display])
257
+ new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
258
+
259
+ return demo
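
Note on the streaming contract: process_message assumes agent.workflow.stream yields dictionaries keyed by the graph node names, "process" for model text and "execute" for tool messages. A minimal stub of that contract, useful for exercising the UI without any models loaded; every name below is hypothetical:

    # Sketch: a fake agent that mimics the event shapes the interface consumes.
    from types import SimpleNamespace

    class FakeWorkflow:
        def stream(self, inputs, config):
            # One "process" event with a single model message and no tool calls.
            msg = SimpleNamespace(content="The X-ray appears normal.", name=None)
            yield {"process": {"messages": [msg]}}

    fake_agent = SimpleNamespace(workflow=FakeWorkflow())
    # demo = create_demo(fake_agent, tools_dict={})  # would serve the UI against the stub
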
main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import *
3
+ from dotenv import load_dotenv
4
+ from transformers import logging
5
+
6
+ from langgraph.checkpoint.memory import MemorySaver
7
+ from langchain_openai import ChatOpenAI
10
+
11
+ from interface import create_demo
12
+ from medrax.agent import *
13
+ from medrax.tools import *
14
+ from medrax.utils import *
15
+
16
+ warnings.filterwarnings("ignore")
17
+ logging.set_verbosity_error()
18
+ _ = load_dotenv()
19
+
20
+
21
+ def initialize_agent(prompt_file, model_dir="/model-weights", temp_dir="temp", device="cuda"):
22
+ prompts = load_prompts_from_file(prompt_file)
23
+ prompt = prompts["MEDICAL_ASSISTANT"]
24
+
25
+ tools_dict = {
26
+ "ChestXRayClassifierTool": ChestXRayClassifierTool(device=device),
27
+ "ChestXRayReportGeneratorTool": ChestXRayReportGeneratorTool(
28
+ cache_dir=model_dir, device=device
29
+ ),
30
+ "ChestXRaySegmentationTool": ChestXRaySegmentationTool(device=device),
31
+ "LlavaMedTool": LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
32
+ "XRayVQATool": XRayVQATool(cache_dir=model_dir, device=device),
33
+ "ImageVisualizerTool": ImageVisualizerTool(),
34
+ "XRayPhraseGroundingTool": XRayPhraseGroundingTool(
35
+ cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
36
+ ),
37
+ "ChestXRayGeneratorTool": ChestXRayGeneratorTool(
38
+ model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
39
+ ),
40
+ "DicomProcessorTool": DicomProcessorTool(temp_dir=temp_dir),
41
+ }
42
+
43
+ checkpointer = MemorySaver()
44
+ model = ChatOpenAI(model="gpt-4o", temperature=0.7, top_p=0.95)
45
+ agent = Agent(
46
+ model,
47
+ tools=list(tools_dict.values()),
48
+ log_tools=True,
49
+ log_dir="logs",
50
+ system_prompt=prompt,
51
+ checkpointer=checkpointer,
52
+ )
53
+
54
+ print("Agent initialized")
55
+ return agent, tools_dict
56
+
57
+
58
+ if __name__ == "__main__":
59
+ print("Starting server...")
60
+ agent, tools_dict = initialize_agent("medrax/docs/system_prompts.txt")
61
+ demo = create_demo(agent, tools_dict)
62
+
63
+ demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
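
Note on setup: initialize_agent assumes local model weights under /model-weights and a CUDA device. For a lightweight smoke test, a reduced configuration that keeps only the tools without heavy weights is one option; this is a hedged sketch that assumes ImageVisualizerTool, DicomProcessorTool, and load_prompts_from_file are exported as used above:

    # Sketch: a reduced agent with only the lightweight tools (no GPU models).
    from langgraph.checkpoint.memory import MemorySaver
    from langchain_openai import ChatOpenAI

    from medrax.agent import Agent
    from medrax.tools import ImageVisualizerTool, DicomProcessorTool
    from medrax.utils import load_prompts_from_file

    prompts = load_prompts_from_file("medrax/docs/system_prompts.txt")
    tools = [ImageVisualizerTool(), DicomProcessorTool(temp_dir="temp")]

    agent = Agent(
        ChatOpenAI(model="gpt-4o", temperature=0.7),
        tools=tools,
        system_prompt=prompts["MEDICAL_ASSISTANT"],
        checkpointer=MemorySaver(),
    )
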
medrax/__init__.py ADDED
File without changes
medrax/agent/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .agent import AgentState, Agent
medrax/agent/agent.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import operator
3
+ from pathlib import Path
4
+ from dotenv import load_dotenv
5
+ from datetime import datetime
6
+ from typing import List, Dict, Any, TypedDict, Annotated, Optional
7
+
8
+ from langgraph.graph import StateGraph, END
9
+ from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
10
+ from langchain_core.language_models import BaseLanguageModel
11
+ from langchain_core.tools import BaseTool
12
+
13
+ _ = load_dotenv()
14
+
15
+
16
+ class ToolCallLog(TypedDict):
17
+ """
18
+ A TypedDict representing a log entry for a tool call.
19
+
20
+ Attributes:
21
+ timestamp (str): The timestamp of when the tool call was made.
22
+ tool_call_id (str): The unique identifier for the tool call.
23
+ name (str): The name of the tool that was called.
24
+ args (Any): The arguments passed to the tool.
25
+ content (str): The content or result of the tool call.
26
+ """
27
+
28
+ timestamp: str
29
+ tool_call_id: str
30
+ name: str
31
+ args: Any
32
+ content: str
33
+
34
+
35
+ class AgentState(TypedDict):
36
+ """
37
+ A TypedDict representing the state of an agent.
38
+
39
+ Attributes:
40
+ messages (Annotated[List[AnyMessage], operator.add]): A list of messages
41
+ representing the conversation history. The operator.add annotation
42
+ indicates that new messages should be appended to this list.
43
+ """
44
+
45
+ messages: Annotated[List[AnyMessage], operator.add]
46
+
47
+
48
+ class Agent:
49
+ """
50
+ A class representing an agent that processes requests and executes tools based on
51
+ language model responses.
52
+
53
+ Attributes:
54
+ model (BaseLanguageModel): The language model used for processing.
55
+ tools (Dict[str, BaseTool]): A dictionary of available tools.
56
+ checkpointer (Any): Manages and persists the agent's state.
57
+ system_prompt (str): The system instructions for the agent.
58
+ workflow (StateGraph): The compiled workflow for the agent's processing.
59
+ log_tools (bool): Whether to log tool calls.
60
+ log_path (Path): Path to save tool call logs.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ model: BaseLanguageModel,
66
+ tools: List[BaseTool],
67
+ checkpointer: Any = None,
68
+ system_prompt: str = "",
69
+ log_tools: bool = True,
70
+ log_dir: Optional[str] = "logs",
71
+ ):
72
+ """
73
+ Initialize the Agent.
74
+
75
+ Args:
76
+ model (BaseLanguageModel): The language model to use.
77
+ tools (List[BaseTool]): A list of available tools.
78
+ checkpointer (Any, optional): State persistence manager. Defaults to None.
79
+ system_prompt (str, optional): System instructions. Defaults to "".
80
+ log_tools (bool, optional): Whether to log tool calls. Defaults to True.
81
+ log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
82
+ """
83
+ self.system_prompt = system_prompt
84
+ self.log_tools = log_tools
85
+
86
+ if self.log_tools:
87
+ self.log_path = Path(log_dir or "logs")
88
+ self.log_path.mkdir(exist_ok=True)
89
+
90
+ # Define the agent workflow
91
+ workflow = StateGraph(AgentState)
92
+ workflow.add_node("process", self.process_request)
93
+ workflow.add_node("execute", self.execute_tools)
94
+ workflow.add_conditional_edges(
95
+ "process", self.has_tool_calls, {True: "execute", False: END}
96
+ )
97
+ workflow.add_edge("execute", "process")
98
+ workflow.set_entry_point("process")
99
+
100
+ self.workflow = workflow.compile(checkpointer=checkpointer)
101
+ self.tools = {t.name: t for t in tools}
102
+ self.model = model.bind_tools(tools)
103
+
104
+ def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
105
+ """
106
+ Process the request using the language model.
107
+
108
+ Args:
109
+ state (AgentState): The current state of the agent.
110
+
111
+ Returns:
112
+ Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
113
+ """
114
+ messages = state["messages"]
115
+ if self.system_prompt:
116
+ messages = [SystemMessage(content=self.system_prompt)] + messages
117
+ response = self.model.invoke(messages)
118
+ return {"messages": [response]}
119
+
120
+ def has_tool_calls(self, state: AgentState) -> bool:
121
+ """
122
+ Check if the response contains any tool calls.
123
+
124
+ Args:
125
+ state (AgentState): The current state of the agent.
126
+
127
+ Returns:
128
+ bool: True if tool calls exist, False otherwise.
129
+ """
130
+ response = state["messages"][-1]
131
+ return len(response.tool_calls) > 0
132
+
133
+ def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
134
+ """
135
+ Execute tool calls from the model's response.
136
+
137
+ Args:
138
+ state (AgentState): The current state of the agent.
139
+
140
+ Returns:
141
+ Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
142
+ """
143
+ tool_calls = state["messages"][-1].tool_calls
144
+ results = []
145
+
146
+ for call in tool_calls:
147
+ print(f"Executing tool: {call}")
148
+ if call["name"] not in self.tools:
149
+ print("\n....invalid tool....")
150
+ result = "invalid tool, please retry"
151
+ else:
152
+ result = self.tools[call["name"]].invoke(call["args"])
153
+
154
+ results.append(
155
+ ToolMessage(
156
+ tool_call_id=call["id"],
157
+ name=call["name"],
158
+ args=call["args"],
159
+ content=str(result),
160
+ )
161
+ )
162
+
163
+ self._save_tool_calls(results)
164
+ print("Returning to model processing!")
165
+
166
+ return {"messages": results}
167
+
168
+ def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
169
+ """
170
+ Save tool calls to a JSON file with timestamp-based naming.
171
+
172
+ Args:
173
+ tool_calls (List[ToolMessage]): List of tool calls to save.
174
+ """
175
+ if not self.log_tools:
176
+ return
177
+
178
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
179
+ filename = self.log_path / f"tool_calls_{timestamp}.json"
180
+
181
+ logs: List[ToolCallLog] = []
182
+ for call in tool_calls:
183
+ log_entry = {
184
+ "tool_call_id": call.tool_call_id,
185
+ "name": call.name,
186
+ "args": call.args,
187
+ "content": call.content,
188
+ "timestamp": datetime.now().isoformat(),
189
+ }
190
+ logs.append(log_entry)
191
+
192
+ with open(filename, "w") as f:
193
+ json.dump(logs, f, indent=4)
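
Note on direct use: outside the Gradio app, the compiled workflow can be invoked without the UI. A minimal sketch, assuming an OpenAI key is configured in the environment; the ping tool is purely illustrative and exists only to exercise the execute node:

    # Sketch: driving the agent workflow directly (ping tool is hypothetical).
    from langchain_core.messages import HumanMessage
    from langchain_core.tools import tool
    from langchain_openai import ChatOpenAI
    from langgraph.checkpoint.memory import MemorySaver

    from medrax.agent import Agent

    @tool
    def ping(text: str) -> str:
        """Echo the input back; a stand-in for a real MedRAX tool."""
        return f"pong: {text}"

    agent = Agent(ChatOpenAI(model="gpt-4o"), tools=[ping], checkpointer=MemorySaver())
    config = {"configurable": {"thread_id": "demo-thread"}}  # required by the checkpointer
    result = agent.workflow.invoke(
        {"messages": [HumanMessage(content="Call the ping tool with 'hello'.")]}, config
    )
    print(result["messages"][-1].content)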