Adibvafa committed
Commit eaca108 · 0 parent(s)
Initial commit
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitattributes +7 -0
- .gitignore +174 -0
- .vscode/launch.json +15 -0
- LICENSE +201 -0
- README.md +83 -0
- assets/medrax_logo.jpg +3 -0
- assets/medrax_logo.png +3 -0
- benchmark/__init__.py +0 -0
- benchmark/create_benchmark.py +352 -0
- benchmark/llm.py +42 -0
- benchmark/utils.py +78 -0
- data/eurorad_metadata.json +0 -0
- data/figures.py +74 -0
- data/get_cases.py +51 -0
- data/stats/age_distribution.png +3 -0
- data/stats/area_of_interest_distribution.png +3 -0
- data/stats/gender_distribution.png +3 -0
- demo/chest/LIDC.dcm +3 -0
- demo/chest/Pseudo.dcm +3 -0
- demo/chest/RIDER.dcm +3 -0
- demo/chest/TCGAA.dcm +3 -0
- demo/chest/__init__.py +0 -0
- demo/chest/effusion1.png +3 -0
- demo/chest/normal1.jpg +3 -0
- demo/chest/normal2.jpg +3 -0
- demo/chest/normal3.jpg +3 -0
- demo/chest/normal4.jpg +3 -0
- demo/chest/normal5.jpg +3 -0
- demo/chest/normal6.jpg +3 -0
- demo/chest/pneumonia1.jpg +3 -0
- demo/chest/pneumonia2.jpg +3 -0
- demo/chest/pneumonia3.jpg +3 -0
- demo/chest/pneumonia4.jpg +3 -0
- demo/chest/pneumonia5.jpg +3 -0
- experiments/README.md +63 -0
- experiments/analyze_axes.py +385 -0
- experiments/benchmark_chexagent.py +316 -0
- experiments/benchmark_gpt4o.py +331 -0
- experiments/benchmark_llama.py +443 -0
- experiments/benchmark_llavamed.py +541 -0
- experiments/benchmark_medrax.ipynb +374 -0
- experiments/chexbench_gpt4.py +405 -0
- experiments/compare_runs.py +290 -0
- experiments/inspect_logs.py +210 -0
- experiments/validate_logs.py +162 -0
- interface.py +259 -0
- main.py +63 -0
- medrax/__init__.py +0 -0
- medrax/agent/__init__.py +1 -0
- medrax/agent/agent.py +193 -0
.gitattributes
ADDED
@@ -0,0 +1,7 @@
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.dcm filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,174 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# ruff
+ruff-cache/
+.ruff_cache/
+
+afallah/
+
+logs/
+
+temp/
+
+.gradio/
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: main.py",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "main.py",
+            "console": "integratedTerminal"
+        }
+    ]
+}
LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
README.md
ADDED
@@ -0,0 +1,83 @@
+<h1 align="center">
+🤖 MedRAX: Medical Reasoning Agent for Chest X-ray 🏥
+</h1>
+<br>
+
+## Problem
+Medical professionals face significant challenges when using traditional Large Language Models (LLMs) for X-ray analysis. Standard LLMs often hallucinate, lack specialized medical imaging capabilities, and can miss critical diagnostic details. While separate tools exist for various aspects of X-ray analysis, the current fragmented approach requires doctors to juggle multiple systems, leading to inefficient workflows and potential oversights in patient care.
+<br>
+<br>
+
+## Our Solution
+MedRAX is an intelligent medical assistant that seamlessly integrates an LLM with specialized X-ray analysis tools, providing a unified interface for comprehensive X-ray analysis. Through natural conversation, medical professionals can leverage powerful tools while the system intelligently coordinates their usage behind the scenes.
+
+Our comprehensive toolset includes:
+- **ChestXRayReportGenerator**: Generates detailed, accurate medical reports from X-ray images
+- **ChestXRayClassifier**: Analyzes images for 18 different pathologies, providing probability scores for each condition
+- **ChestXRaySegmentation**: Precisely segments anatomical structures
+- **MedicalVisualQA**: Answers complex visual medical queries
+- **XRayPhraseGrounding**: Locates and visualizes specific medical findings in X-rays with bounding box precision
+- **ImageVisualizer**: Enhances and displays X-ray images for optimal viewing
+- **ChestXRayGenerator**: Generates synthetic chest X-rays for educational purposes
+- **DicomProcessor**: Handles DICOM file processing and analysis
+<br>
+
+## Technical Implementation
+MedRAX is built on a robust technical foundation:
+- **Core Architecture**: Leverages LangChain and LangGraph for sophisticated agent orchestration
+- **Language Model**: Powered by OpenAI's API for natural language understanding and generation
+- **Specialized Tools**: Integrates medical-domain fine-tuned models for various analysis tasks
+- **Interface**: Built with Gradio for an intuitive, chat-based user experience
+- **Modular Design**: Allows easy integration of additional specialized medical tools
+<br>
+
+## Potential Impact
+- Accelerates X-ray analysis while maintaining high accuracy
+- Reduces the likelihood of missed diagnoses through multi-tool verification
+- Provides valuable educational support for medical students and residents
+- Offers a scalable solution for facilities with limited specialist availability
+- Improves patient outcomes through comprehensive analysis
+- Streamlines workflow for medical professionals
+<br>
+
+## Setup and Usage
+
+### Prerequisites
+- GPU required for optimal performance
+- Python 3.8+
+- OpenAI API key
+
+### Installation
+1. Clone the repository:
+```bash
+git clone https://github.com/yourusername/MedRAX.git
+cd MedRAX
+```
+
+2. Install dependencies:
+```bash
+pip install -e .
+```
+
+3. Set up environment variables:
+```bash
+echo "OPENAI_API_KEY=your_key_here" > .env
+```
+
+### Running the Application
+Start the application:
+```bash
+python main.py
+```
+<br>
+
+## Developers
+- Adibvafa Fallahpour
+- Jun Ma
+- Hongwei Lyu
+<br>
+
+---
+<p align="center">
+Made with ❤️ in Toronto
+</p>
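The README's Technical Implementation section says agent orchestration is handled by LangChain and LangGraph. The sketch below only illustrates the underlying tool-calling loop, written against the OpenAI chat-completions API directly rather than the stack MedRAX actually uses, so it is not the project's implementation. The `classify_chest_xray` tool, its schema, and the returned probabilities are invented for illustration, and the sketch assumes the model chooses to call the tool on the first turn.

```python
import json
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def classify_chest_xray(image_path: str) -> dict:
    """Hypothetical stand-in for a specialized tool such as the pathology classifier."""
    return {"image": image_path, "pneumonia": 0.82, "effusion": 0.11}


TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "classify_chest_xray",
            "description": "Return pathology probabilities for a chest X-ray image.",
            "parameters": {
                "type": "object",
                "properties": {"image_path": {"type": "string"}},
                "required": ["image_path"],
            },
        },
    }
]

messages = [{"role": "user", "content": "Does demo/chest/pneumonia1.jpg show pneumonia?"}]
first = client.chat.completions.create(model="gpt-4o", messages=messages, tools=TOOLS)
call = first.choices[0].message.tool_calls[0]  # assumes the model decided to call the tool

# Run the selected tool locally, then hand its result back for a final, grounded answer.
result = classify_chest_xray(**json.loads(call.function.arguments))
messages.append(first.choices[0].message)
messages.append({"role": "tool", "tool_call_id": call.id, "content": json.dumps(result)})
final = client.chat.completions.create(model="gpt-4o", messages=messages)
print(final.choices[0].message.content)
```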
assets/medrax_logo.jpg
ADDED
Git LFS Details
assets/medrax_logo.png
ADDED
Git LFS Details
benchmark/__init__.py
ADDED
File without changes
benchmark/create_benchmark.py
ADDED
@@ -0,0 +1,352 @@
+#!/usr/bin/env python3
+"""
+Medical X-ray Question Generation Benchmark aka ChestAgentBench
+
+This script generates clinical questions from X-ray case data of the Eurorad dataset using GPT-4o.
+It structures questions across different analytical categories and saves them as JSON.
+"""
+
+import os
+import re
+import json
+from typing import *
+from pprint import pprint
+
+import openai
+import numpy as np
+from scipy import stats
+import plotly.graph_objects as go
+from tqdm import tqdm
+
+from benchmark.utils import load_eurorad_dataset
+from benchmark.llm import get_llm_response
+
+# Constants
+DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
+DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")
+
+SYSTEM_PROMPT = """
+You are an expert medical benchmark creation assistant.
+Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
+""".strip()
+
+CATEGORIES_META = {
+    "detection": "Identify and locate specific findings in the chest X-ray.",
+    "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
+    "enumeration": "Count the number of target findings in the chest X-ray.",
+    "localization": "Locate a given finding in the chest X-ray.",
+    "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
+    "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
+    "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
+    "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
+    "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
+}
+CATEGORIES = list(CATEGORIES_META.keys())
+
+CATEGORY_COMBINATIONS = [
+    ["detection", "localization", "characterization", "reasoning"],  # Detailed Finding Analysis
+    ["detection", "classification", "relationship", "reasoning"],  # Pattern Recognition & Relations
+    ["localization", "comparison", "relationship", "reasoning"],  # Spatial Understanding
+    ["classification", "comparison", "diagnosis", "reasoning"],  # Clinical Decision Making
+    ["classification", "characterization", "diagnosis", "reasoning"],  # Diagnostic Characterization
+]
+
+DEFAULT_SECTIONS = [
+    "history",
+    "image_finding",
+    "discussion",
+    "differential_diagnosis",
+    "diagnosis",
+    "figures",
+]
+
+
+class Question:
+    """A class to generate clinical questions from case data.
+
+    This class handles creating structured clinical questions by combining case data with
+    specified categories and difficulty levels.
+
+    Attributes:
+        type (str): The type of question (e.g. multiple choice)
+        difficulty (str): Difficulty level of the question
+        case_data (Dict[str, Any]): Dictionary containing the clinical case data
+        case_content (str): Formatted case data from selected sections
+        case_id (str): Unique identifier for the case
+        categories (List[str]): List of analytical categories this question tests
+        sections (List[str]): Case sections to include in question
+        raw_content (Optional[str]): Raw LLM response to the question prompt
+        content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
+    """
+
+    def __init__(
+        self,
+        type: str,
+        difficulty: str,
+        case_data: Dict[str, Any],
+        categories: List[str],
+        sections: List[str] = [
+            "history",
+            "image_finding",
+            "discussion",
+            "differential_diagnosis",
+            "diagnosis",
+            "figures",
+        ],
+        system_prompt: str = "You are an expert medical benchmark creation assistant.",
+    ) -> None:
+        self.type = type
+        self.difficulty = difficulty
+        self.case_data = case_data
+        self.case_id = case_data["case_id"]
+        self.categories = categories
+        self.sections = sections
+        self.system_prompt = system_prompt
+        self.case_content = self.select_case_sections()
+        self.raw_content: Optional[str] = None
+        self.content: Optional[Dict[str, str]] = None
+
+    def create_question_prompt(self) -> str:
+        """Creates a formatted prompt for generating a clinical question.
+
+        Returns:
+            str: A structured prompt containing the question parameters and clinical data
+        """
+        category_descriptions = "\n".join(
+            f"{category}: {desc}"
+            for category, desc in CATEGORIES_META.items()
+            if category in self.categories
+        )
+
+        return f"""
+        You must follow these guidelines:
+        1. Questions must be answerable using only context and chest X-rays.
+            - Questions must explicitly mention the referenced figures
+            - Questions can only reference the chest X-ray figures
+
+        2. Questions must have unambiguous, verifiable answers, and should:
+            - Challenge the agent's analytical capabilities
+            - Require multi-step reasoning
+            - Test ability to make precise observations
+            - Evaluate capability to derive insights and findings from the chest X-ray
+
+        3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex enough to require the use of such tools.
+
+
+        Create a {self.difficulty} {self.type} clinical question that integrates the following:
+
+        {category_descriptions}
+
+        based on the following clinical case:
+
+        {self.case_content}
+
+        Do not use any information derived from the CT and MRI images. Do not provide any information and findings about the chest X-rays.
+        Your question should require the agent to derive insights and findings from the chest X-ray by itself.
+        Your answer should be verifiable directly in the context of the case.
+        You can only use the image findings that come from the chest X-ray figures.
+
+        Your response must follow this exact format:
+        THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question]
+        QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
+        FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
+        EXPLANATION: [short explanation of why your answer is verifiable in the case]
+        ANSWER: [correct answer e.g. "A"]
+        """.strip().replace(
+            "        ", ""
+        )  # remove tabs
+
+    def select_case_sections(self) -> str:
+        """Extract and format selected sections from case data into paragraphs.
+
+        Returns:
+            str: Formatted string with case sections and content
+        """
+        section_mapping = {
+            "history": ("history", "No history provided."),
+            "image_finding": ("image_finding", "No findings provided."),
+            "discussion": ("discussion", "No discussion provided."),
+            "differential_diagnosis": (
+                "differential_diagnosis",
+                "No differential diagnosis provided.",
+            ),
+            "diagnosis": ("diagnosis", "No diagnosis provided."),
+            "figures": ("figures", "No figures provided."),
+        }
+
+        formatted = []
+        for section in self.sections:
+            if section in section_mapping:
+                key, default = section_mapping[section]
+                content = self.case_data.get(key, default)
+
+                if key == "figures":
+                    figures_text = []
+                    for figure in content:
+                        for subfig in figure["subfigures"]:
+                            figures_text.append(f"{subfig['number']}: {subfig['caption']}")
+                    content = "\n".join(figures_text)
+
+                formatted.append(f"{section}:\n{content}")
+
+        return "\n\n".join(formatted)
+
+    def create_question(
+        self,
+        client: openai.OpenAI,
+        temperature: float = 0.7,
+        top_p: float = 0.95,
+        max_tokens: int = 500,
+        model: str = "gpt-4o",
+    ) -> str:
+        """Create a clinical question using LLM.
+
+        Args:
+            client (openai.OpenAI): OpenAI client instance
+            temperature (float): Controls randomness in responses. Defaults to 0.7.
+            top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
+            max_tokens (int): Max tokens in model response. Defaults to 500.
+            model (str): OpenAI model to use. Defaults to "gpt-4o".
+
+        Returns:
+            str: LLM response containing formatted question components
+        """
+        self.raw_content = get_llm_response(
+            client=client,
+            prompt=self.create_question_prompt(),
+            system_prompt=self.system_prompt,
+            temperature=temperature,
+            top_p=top_p,
+            max_tokens=max_tokens,
+            model=model,
+        )
+        self.content = self.extract_content()
+
+        return self.raw_content
+
+    def extract_content(self) -> Dict[str, str]:
+        """Extract sections from raw LLM response using regex patterns.
+
+        Returns:
+            Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer
+        """
+        keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]
+
+        content = {}
+        for kw in keywords:
+            pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
+            match = re.search(pattern, self.raw_content, re.DOTALL)
+            content[kw.lower()] = match.group(1).strip() if match else None
+
+        return content
+
+    def save(self, output_path: str) -> Dict[str, Any]:
+        """Save question content and metadata as a JSON file.
+
+        Args:
+            output_path (str): Directory path where the JSON file will be saved
+
+        Returns:
+            Dict[str, Any]: Question data including content (thoughts, question, figures, options,
+                explanation, answer) and metadata (type, difficulty, categories, etc.)
+        """
+        question_metadata = self.content.copy()
+
+        # Add metadata
+        question_metadata["metadata"] = {
+            "case_id": self.case_id,
+            "type": self.type,
+            "difficulty": self.difficulty,
+            "categories": self.categories,
+            "sections": self.sections,
+        }
+
+        # Create a directory for the case
+        case_dir = os.path.join(output_path, str(self.case_id))
+        os.makedirs(case_dir, exist_ok=True)
+
+        # Save the question metadata to a JSON file
+        output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
+        with open(output_file, "w") as f:
+            json.dump(question_metadata, f, indent=2)
+
+        return question_metadata
+
+
+def generate_questions(
+    dataset: Dict[str, Any],
+    client: openai.OpenAI,
+    output_dir: str,
+    skip_first: int = 100,
+    temperature: float = 0.7,
+    top_p: float = 0.95,
+    max_tokens: int = 1200,
+    model: str = "gpt-4o",
+) -> None:
+    """Generate questions for each case and category combination.
+
+    Args:
+        dataset: Dictionary of case data
+        client: OpenAI client instance
+        output_dir: Directory to save generated questions
+        skip_first: Number of initial cases to skip
+        temperature: LLM temperature parameter
+        top_p: LLM top_p parameter
+        max_tokens: Maximum tokens for LLM response
+        model: LLM model name
+    """
+    target_cases = sorted(list(dataset.keys()), key=int)[-len(dataset) : -skip_first]
+
+    for case_id in tqdm(target_cases, desc="Processing cases"):
+        case_data = dataset[case_id]
+
+        for category in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
+            question = Question(
+                type="multiple choice (A/B/C/D/E/F)",
+                difficulty="complex",
+                case_data=case_data,
+                categories=category,
+                sections=DEFAULT_SECTIONS,
+                system_prompt=SYSTEM_PROMPT,
+            )
+
+            response = question.create_question(
+                client=client,
+                temperature=temperature,
+                top_p=top_p,
+                max_tokens=max_tokens,
+                model=model,
+            )
+            question.save(output_dir)
+
+
+def main():
+    """Main execution function."""
+    client = openai.OpenAI()
+
+    # Load and verify dataset
+    dataset = load_eurorad_dataset(
+        DATASET_PATH,
+        section="Chest Imaging",
+        as_dict=True,
+        filter_by_caption=[
+            "xray",
+            "x-ray",
+            "x ray",
+            "ray",
+            "xr",
+            "radiograph",
+        ],
+    )
+    print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")
+
+    # Optional: Print sample case for verification
+    case_data = dataset["16798"]
+    pprint(case_data, sort_dicts=False)
+
+    # Generate questions
+    generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")
+
+
+if __name__ == "__main__":
+    main()
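For reference, a minimal usage sketch of the Question class added above. The case record here is a made-up stand-in for one entry of eurorad_metadata.json (only the fields read by select_case_sections are filled in), and OPENAI_API_KEY is assumed to be set in the environment.

```python
import openai

from benchmark.create_benchmark import DEFAULT_SECTIONS, SYSTEM_PROMPT, Question

# Hypothetical case record mirroring the schema the Question class expects.
case = {
    "case_id": "99999",
    "history": "65-year-old with cough and fever.",
    "image_finding": "Right lower zone consolidation on the chest X-ray.",
    "discussion": "Findings are consistent with lobar pneumonia.",
    "differential_diagnosis": "Pneumonia; pulmonary infarction.",
    "diagnosis": "Right lower lobe pneumonia.",
    "figures": [{"subfigures": [{"number": "Figure 1", "caption": "Frontal chest X-ray."}]}],
}

question = Question(
    type="multiple choice (A/B/C/D/E/F)",
    difficulty="complex",
    case_data=case,
    categories=["detection", "localization", "characterization", "reasoning"],
    sections=DEFAULT_SECTIONS,
    system_prompt=SYSTEM_PROMPT,
)

# Ask GPT-4o for one benchmark item, then write it to benchmark/questions/99999/.
question.create_question(client=openai.OpenAI(), model="gpt-4o")
question.save("benchmark/questions")
```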
benchmark/llm.py
ADDED
@@ -0,0 +1,42 @@
+import openai
+from typing import List
+
+
+def get_llm_response(
+    client: openai.OpenAI,
+    prompt: str,
+    system_prompt: str = "You are a helpful assistant.",
+    model: str = "gpt-4o-mini",
+    temperature: float = 0.7,
+    top_p: float = 0.95,
+    max_tokens: int = 500,
+) -> str:
+    """
+    Get response from OpenAI language model.
+
+    Args:
+        client (openai.OpenAI): OpenAI client
+        prompt (str): The user prompt/question to send to the model
+        system_prompt (str, optional): System prompt to set model behavior.
+        model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
+        temperature (float, optional): Controls randomness in responses. Defaults to 0.7.
+        top_p (float, optional): Controls diversity via nucleus sampling. Defaults to 0.95.
+        max_tokens (int, optional): Max tokens in model response. Defaults to 500.
+
+    Returns:
+        str: The model's response text
+    """
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt},
+    ]
+
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=max_tokens,
+    )
+
+    return response.choices[0].message.content
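A short usage sketch of get_llm_response; the prompt and system prompt are illustrative, and OPENAI_API_KEY is assumed to be set in the environment.

```python
import openai

from benchmark.llm import get_llm_response

client = openai.OpenAI()  # picks up OPENAI_API_KEY from the environment

# Single round-trip: system prompt sets behavior, prompt carries the actual question.
answer = get_llm_response(
    client=client,
    prompt="List three findings that distinguish pleural effusion from consolidation on a chest X-ray.",
    system_prompt="You are a concise radiology tutor.",
    model="gpt-4o-mini",
    max_tokens=200,
)
print(answer)
```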
benchmark/utils.py
ADDED
@@ -0,0 +1,78 @@
+import os
+import json
+from typing import Dict, List
+
+
+def load_eurorad_dataset(
+    dataset_path: str,
+    section: str = "any",
+    as_dict: bool = False,
+    filter_by_caption: List[str] = [
+        "xray",
+        "x-ray",
+        "x ray",
+        "ray",
+        "xr",
+        "radiograph",
+        "radiogram",
+        "plain film",
+    ],
+) -> List[Dict] | Dict[str, Dict]:
+    """
+    Load a dataset from a JSON file.
+
+    Args:
+        dataset_path (str): Path to the JSON dataset file.
+        section (str, optional): Section of the dataset to load. Defaults to "any".
+        as_dict (bool, optional): Whether to return data as dict. Defaults to False.
+        filter_by_caption (List[str], optional): List of strings to filter cases by caption content. Defaults to common X-ray terms.
+
+    Returns:
+        List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.
+
+    Raises:
+        FileNotFoundError: If dataset_path does not exist
+        json.JSONDecodeError: If file is not valid JSON
+    """
+
+    with open(dataset_path, "r", encoding="utf-8") as file:
+        data = json.load(file)
+
+    if filter_by_caption:
+        filtered_data = {}
+        for case_id, case in data.items():
+            if any(
+                any(x in subfig["caption"].lower() for x in filter_by_caption)
+                for figure in case["figures"]
+                for subfig in figure["subfigures"]
+            ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
+                filtered_data[case_id] = case
+        data = filtered_data
+
+    if section != "any":
+        section = section.strip().lower()
+        if not as_dict:
+            data = [
+                item for item in data.values() if item.get("section", "").strip().lower() == section
+            ]
+        else:
+            data = {
+                k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
+            }
+
+    elif not as_dict:
+        data = list(data.values())
+
+    return data
+
+
+def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
+    """
+    Save a dataset to a JSON file.
+
+    Args:
+        dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
+        dataset_path (str): Path where the JSON dataset file will be saved.
+    """
+    with open(dataset_path, "w", encoding="utf-8") as file:
+        json.dump(dataset, file)
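A brief sketch of how these two helpers combine, assuming eurorad_metadata.json has been placed under data/; the output filename is an arbitrary choice for the example.

```python
from benchmark.utils import load_eurorad_dataset, save_dataset

# Keep only chest-imaging cases whose captions or findings mention an X-ray
# (the default filter_by_caption terms handle the matching).
cases = load_eurorad_dataset(
    "data/eurorad_metadata.json",
    section="Chest Imaging",
    as_dict=True,
)
print(f"{len(cases)} chest X-ray cases loaded")

# Persist the filtered subset for later benchmark generation.
save_dataset(cases, "data/eurorad_chest_xray_subset.json")
```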
data/eurorad_metadata.json
ADDED
The diff for this file is too large to render. See raw diff.
data/figures.py
ADDED
@@ -0,0 +1,74 @@
+import json
+import os
+from pathlib import Path
+import requests
+from tqdm import tqdm
+
+
+def download_eurorad_figures(metadata_path: str, output_dir: str) -> None:
+    """
+    Download figures from Eurorad dataset and save them organized by case_id.
+
+    Args:
+        metadata_path: Path to the eurorad_metadata.json file
+        output_dir: Base directory where figures will be saved
+
+    The figures will be saved as:
+        {output_dir}/{case_id}/{figure_number}.jpg
+    Example:
+        figures/189/Figure_1a.jpg
+    """
+    # Create output directory if it doesn't exist
+    output_path = Path(output_dir)
+    output_path.mkdir(exist_ok=True)
+
+    # Load metadata
+    with open(metadata_path) as f:
+        metadata = json.load(f)
+
+    # Iterate through all cases with progress bar
+    for case_id in tqdm(metadata, desc="Downloading cases", unit="case"):
+        case = metadata[case_id]
+        case_dir = output_path / str(case["case_id"])
+        case_dir.mkdir(exist_ok=True)
+
+        # Process all figures and their subfigures
+        for figure in case["figures"]:
+            for subfig in figure["subfigures"]:
+
+                # Remove leading and trailing whitespace and convert to lowercase
+                subfig_name = f"{subfig['number'].strip().replace(' ', '_').lower()}.jpg"
+                subfig_path = Path(case_dir) / subfig_name
+
+                save_figure(
+                    url=subfig["url"],
+                    output_path=subfig_path,
+                )
+
+
+def save_figure(url: str, output_path: Path) -> None:
+    """
+    Download and save a single figure.
+
+    Args:
+        url: URL of the figure to download
+        output_path: Path where the figure should be saved
+    """
+    if output_path.exists():
+        return
+
+    try:
+        response = requests.get(url, timeout=10)
+        response.raise_for_status()
+        with open(output_path, "wb") as f:
+            f.write(response.content)
+    except Exception as e:
+        print(f"Error downloading {url}: {e}")
+
+
+if __name__ == "__main__":
+    root = os.path.dirname(os.path.abspath(__file__))
+    download_eurorad_figures(
+        metadata_path=os.path.join(root, "eurorad_metadata.json"),
+        output_dir=os.path.join(root, "figures"),
+    )
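Running the file directly already downloads everything next to the metadata; below is a sketch of calling the function with custom paths instead, assuming data/ is importable as a package from the repository root (adjust the import or run the script directly otherwise).

```python
from data.figures import download_eurorad_figures

# Fetch every subfigure into data/figures/<case_id>/<figure_number>.jpg,
# skipping files that already exist; failed downloads are logged and skipped.
download_eurorad_figures(
    metadata_path="data/eurorad_metadata.json",
    output_dir="data/figures",
)
```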
data/get_cases.py
ADDED
@@ -0,0 +1,51 @@
+import requests
+from bs4 import BeautifulSoup
+import time
+import json
+from tqdm import tqdm
+
+
+def get_response(url):
+    headers = {
+        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
+    }
+    return requests.get(url, headers=headers)
+
+def get_case_numbers_from_page(page):
+    url = f"https://www.eurorad.org/advanced-search?sort_by=published_at&sort_order=ASC&page={page}&filter%5B0%5D=section%3A40"
+
+    # Remove proxy usage since it's likely triggering the protection
+    response = get_response(url)
+    print(response.text)
+
+    soup = BeautifulSoup(response.text, "html.parser")
+    spans = soup.find_all("span", class_="case__number small")
+
+    # Remove '#' from the span text and strip extra whitespace
+    numbers = [span.text.strip().replace("#", "").strip() for span in spans]
+    return numbers
+
+
+def main():
+    total_pages = 107  # Pages 0 through 106
+    all_numbers = []
+
+    for page in tqdm(range(total_pages)):
+        numbers = get_case_numbers_from_page(page)
+        all_numbers.extend(numbers)
+
+        if page != total_pages - 1 and len(numbers) != 9:
+            print(f"Warning: Page {page} returned {len(numbers)} cases instead of 9")
+
+        # Be kind to the server – avoid hitting it too fast
+        time.sleep(1)
+        break
+
+    with open('case_numbers.json', 'w') as f:
+        json.dump(all_numbers, f)
+
+    print(f"Saved {len(all_numbers)} case numbers to case_numbers.json")
+
+
+if __name__ == "__main__":
+    main()
data/stats/age_distribution.png
ADDED
Git LFS Details
data/stats/area_of_interest_distribution.png
ADDED
Git LFS Details
data/stats/gender_distribution.png
ADDED
Git LFS Details
demo/chest/LIDC.dcm
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11d25b1d34dff083057de994fef7da3dcef75bd7b334823ec6cb9c16b3ba0338
+size 17071804
demo/chest/Pseudo.dcm
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b35ae460fb5f62eb6d6c4c5117f6683100ad92c5fb6ba1a3c36da39703c4652
+size 7535280
demo/chest/RIDER.dcm
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc15f7afa5434991e1359f596433870ad611b42227db87d484d31976545de7fd
+size 7534066
demo/chest/TCGAA.dcm
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e8137290ac823d3da3c00ce3e18120123eaa62a786934c7afc52a989b0b64cf
+size 7535274
demo/chest/__init__.py
ADDED
File without changes
demo/chest/effusion1.png
ADDED
Git LFS Details
demo/chest/normal1.jpg
ADDED
Git LFS Details
demo/chest/normal2.jpg
ADDED
Git LFS Details
demo/chest/normal3.jpg
ADDED
Git LFS Details
demo/chest/normal4.jpg
ADDED
Git LFS Details
demo/chest/normal5.jpg
ADDED
Git LFS Details
demo/chest/normal6.jpg
ADDED
Git LFS Details
demo/chest/pneumonia1.jpg
ADDED
Git LFS Details
demo/chest/pneumonia2.jpg
ADDED
Git LFS Details
demo/chest/pneumonia3.jpg
ADDED
Git LFS Details
demo/chest/pneumonia4.jpg
ADDED
Git LFS Details
demo/chest/pneumonia5.jpg
ADDED
Git LFS Details
experiments/README.md
ADDED
@@ -0,0 +1,63 @@
# Experiments

Below are instructions for running experiments with our novel ChestAgentBench and with CheXbench, the previous state-of-the-art benchmark. ChestAgentBench is a comprehensive benchmark containing over 2,500 complex medical queries across 8 diverse categories.

### ChestAgentBench

To run GPT-4o on ChestAgentBench, enter the `experiments` directory and run the following script:
```bash
python benchmark_gpt4o.py
```

To run Llama 3.2 Vision 90B on ChestAgentBench, run the following:
```bash
python benchmark_llama.py
```

To run CheXagent on ChestAgentBench, run the following:
```bash
python benchmark_chexagent.py
```

To run LLaVA-Med on ChestAgentBench, you'll need to clone their repo, follow their setup instructions, and copy the following script into it:
```bash
mv benchmark_llavamed.py ~/LLaVA-Med/llava/serve
python -m llava.serve.benchmark_llavamed --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
```

If you want to inspect the logs, you can run the following; it selects the most recent log file by default:
```bash
python inspect_logs.py [optional: log-file] -n [num-logs]
```

Finally, to analyze results, run:
```bash
python analyze_axes.py results/[logfile].json ../benchmark/questions/ --model [gpt4|llama|chexagent|llava-med] --max-questions [optional:int]
```

### CheXbench

To run the models on CheXbench, you can use `chexbench_gpt4.py` as a reference. You'll need to download the dataset files locally and upload them for each request. Rad-ReStruct and Open-I use the same set of images, so you can download the `NLMCXR.zip` file just once and copy the images to both directories.

You can find the datasets here:
1. [SLAKE: A Semantically-Labeled Knowledge-Enhanced Dataset for Medical Visual Question Answering](https://www.med-vqa.com/slake/). Save this to `MedMAX/data/slake`.
2. [Rad-ReStruct: A Novel VQA Benchmark and Method for Structured Radiology Reporting](https://github.com/ChantalMP/Rad-ReStruct). Save the images to `MedMAX/data/rad-restruct/images`.
3. [Open-I Service of the National Library of Medicine](https://openi.nlm.nih.gov/faq). Save the images to `MedMAX/data/openi/images`.

Once you're finished, fix the paths in the `chexbench.json` file to your local paths using the `MedMax/data/fix_chexbench.py` script.


### Compare Runs
Analyze a single file based on overall accuracy and along different axes:
```
python compare_runs.py results/medmax.json
```

For a direct evaluation comparing **2** models on the exact same questions:
```
python compare_runs.py results/medmax.json results/gpt4o.json
```

For a direct evaluation comparing **ALL** models on the exact same questions (add as many model log files as you want):
```
python compare_runs.py results/medmax.json results/gpt4o.json results/llama.json results/chexagent.json results/llavamed.json
```

experiments/analyze_axes.py
ADDED
@@ -0,0 +1,385 @@
from typing import Dict, List, Optional, Tuple, Union, Any
import json
import os
import sys
import argparse
from collections import defaultdict
from tqdm import tqdm

QUESTION_TYPES = {
    "Detailed Finding Analysis": ["detection", "localization", "characterization"],
    "Pattern Recognition & Relations": ["detection", "classification", "relationship"],
    "Spatial Understanding": ["localization", "comparison", "relationship"],
    "Clinical Decision Making": ["classification", "comparison", "diagnosis"],
    "Diagnostic Classification": ["classification", "characterization", "diagnosis"],
}


def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
    """
    Extract just the letter from various answer formats.

    Args:
        answer: The answer text to extract letter from

    Returns:
        Optional[str]: The extracted letter in uppercase, or None if no letter found
    """
    if not answer:
        return None

    # Convert to string and clean
    answer = str(answer).strip()

    # If it's just a single letter, return it
    if len(answer) == 1 and answer.isalpha():
        return answer.upper()

    # Try to extract letter from format like "A)" or "A."
    if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
        return answer[0].upper()

    # Try to extract letter from format like "A) Some text"
    if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
        return answer[0].upper()

    return None


def analyze_gpt4_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in GPT-4 format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    processed_questions = 0

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        # Check if we've hit the maximum questions
        if max_questions is not None and processed_questions >= max_questions:
            break
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                processed_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_llama_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in Llama format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        if line.startswith("HTTP Request:"):
            continue

        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def analyze_chexagent_results(
    results_file: str, max_questions: Optional[int] = None
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Analyze results in CheXagent format.

    Args:
        results_file: Path to results file
        max_questions: Maximum number of questions to analyze

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
    all_questions = 0
    all_correct = 0
    correct_ids = []
    incorrect_ids = []

    with open(results_file, "r") as f:
        lines = f.readlines()

    # If max_questions is set, limit the number of lines processed
    if max_questions is not None:
        lines = lines[:max_questions]

    for line in tqdm(lines, desc="Analyzing Benchmark Results"):
        try:
            entry = json.loads(line)
            metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
            question_id = entry.get("question_id")

            model_letter = extract_answer_letter(entry.get("model_answer"))
            correct_letter = extract_answer_letter(entry.get("correct_answer"))

            if model_letter and correct_letter:
                all_questions += 1
                is_correct = model_letter == correct_letter

                if is_correct:
                    all_correct += 1
                    correct_ids.append(question_id)
                else:
                    incorrect_ids.append(question_id)

                for category in metadata.get("categories", []):
                    category_performance[category]["total"] += 1
                    if is_correct:
                        category_performance[category]["correct"] += 1

        except json.JSONDecodeError:
            continue

    return process_results(
        category_performance, all_questions, all_correct, correct_ids, incorrect_ids
    )


def process_results(
    category_performance: Dict,
    all_questions: int,
    all_correct: int,
    correct_ids: Optional[List[str]] = None,
    incorrect_ids: Optional[List[str]] = None,
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
    """
    Process raw results into final statistics.

    Args:
        category_performance: Dict containing performance by category
        all_questions: Total number of questions
        all_correct: Total number of correct answers
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions

    Returns:
        Tuple containing:
        - overall_accuracy (float)
        - category_accuracies (Dict)
        - question_type_stats (Dict)
        - correct_ids (List[str])
        - incorrect_ids (List[str])
    """
    category_accuracies = {
        category: {
            "accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
            "total": stats["total"],
            "correct": stats["correct"],
        }
        for category, stats in category_performance.items()
    }

    question_type_stats = {}
    for qtype, categories in QUESTION_TYPES.items():
        total = sum(
            category_performance[cat]["total"] for cat in categories if cat in category_performance
        )
        correct = sum(
            category_performance[cat]["correct"]
            for cat in categories
            if cat in category_performance
        )

        question_type_stats[qtype] = {
            "accuracy": (correct / total * 100) if total > 0 else 0,
            "total": total,
            "correct": correct,
        }

    overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0

    return (
        overall_accuracy,
        category_accuracies,
        question_type_stats,
        correct_ids or [],
        incorrect_ids or [],
    )


def print_analysis(
    overall_accuracy: float,
    category_accuracies: Dict,
    question_type_stats: Dict,
    correct_ids: List[str],
    incorrect_ids: List[str],
    model_name: str,
) -> None:
    """
    Print analysis results.

    Args:
        overall_accuracy: Overall accuracy percentage
        category_accuracies: Dict containing accuracy metrics by category
        question_type_stats: Dict containing stats by question type
        correct_ids: List of IDs for correctly answered questions
        incorrect_ids: List of IDs for incorrectly answered questions
        model_name: Name of the model being analyzed
    """
    total_questions = len(correct_ids) + len(incorrect_ids)
    print(
        f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
    )

    print("\nCategory Performance:")
    sorted_categories = sorted(
        category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
    )
    for category, metrics in sorted_categories:
        print(f"{category}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")

    print("\nQuestion Type Performance:")
    sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
    for qtype, metrics in sorted_types:
        print(f"\n{qtype}:")
        print(f"  Accuracy: {metrics['accuracy']:.2f}%")
        print(f"  Total Questions: {metrics['total']}")
        print(f"  Correct Questions: {metrics['correct']}")
        print(f"  Categories: {', '.join(QUESTION_TYPES[qtype])}")

    # Save question IDs to JSON
    question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}

    output_filename = f"{model_name}_question_ids.json"
    with open(output_filename, "w") as f:
        json.dump(question_ids, f, indent=2)

    print(f"\nQuestion IDs have been saved to {output_filename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze benchmark results")
    parser.add_argument("results_file", help="Path to results file")
    parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
    parser.add_argument(
        "--model",
        choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
        default="gpt4",
        help="Specify model format (default: gpt4)",
    )
    parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
    args = parser.parse_args()

    if args.model == "gpt4":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    elif args.model == "llama":
        results = analyze_llama_results(args.results_file, args.max_questions)
    elif args.model == "chexagent":
        results = analyze_chexagent_results(args.results_file, args.max_questions)
    elif args.model == "medrax":
        results = analyze_gpt4_results(args.results_file, args.max_questions)
    else:
        parser.error(f"Unsupported model: {args.model}")

    print_analysis(*results, args.model)

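For reference, a sketch of the shape of one JSONL line that `analyze_gpt4_results` parses; the IDs and answer text are hypothetical and only mirror the keys the parser reads, and it assumes `extract_answer_letter` from the script above is in scope:

```python
import json

# Hypothetical results-log entry with the fields the analyzer expects.
sample_entry = {
    "question_id": "case123_q1",  # hypothetical ID
    "model_answer": "B) Right-sided pleural effusion",
    "correct_answer": "B",
    "input": {
        "question_data": {
            "metadata": {"categories": ["detection", "localization"]}
        }
    },
}
line = json.dumps(sample_entry)

# Both answers are reduced to a single letter before comparison.
entry = json.loads(line)
print(extract_answer_letter(entry["model_answer"]))    # -> "B"
print(extract_answer_letter(entry["correct_answer"]))  # -> "B"
```
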
experiments/benchmark_chexagent.py
ADDED
@@ -0,0 +1,316 @@
import re
import json
import os
import glob
import time
import logging
from datetime import datetime
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Configure model settings
MODEL_NAME = "StanfordAIMI/CheXagent-2-3b"
DTYPE = torch.bfloat16
DEVICE = "cuda"

# Configure logging
log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Initialize the CheXagent model and tokenizer.

    Returns:
        tuple containing:
        - AutoModelForCausalLM: The initialized CheXagent model
        - AutoTokenizer: The initialized tokenizer
    """
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", trust_remote_code=True
    )
    model = model.to(DTYPE)
    model.eval()
    return model, tokenizer


def create_inference_request(
    question_data: dict,
    case_details: dict,
    case_id: str,
    question_id: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> str | None:
    """Create and execute an inference request for the CheXagent model.

    Args:
        question_data: Dictionary containing question details and metadata
        case_details: Dictionary containing case information and image paths
        case_id: Unique identifier for the medical case
        question_id: Unique identifier for the question
        model: The initialized CheXagent model
        tokenizer: The initialized tokenizer

    Returns:
        str | None: Single letter answer (A-F) if successful, None if failed
    """
    system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
Rules:
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
2. Do not add periods, explanations, or any other text
3. Do not use markdown or formatting
4. Do not restate the question
5. Do not explain your reasoning

Examples of valid responses:
A
B
C

Examples of invalid responses:
"A."
"Answer: B"
"C) This shows..."
"The answer is D"
"""

    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    # Parse required figures
    try:
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    # Get image paths
    image_paths = []
    for figure in required_figures:
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        for case_figure in matching_figures:
            subfigures = []
            if figure_letter:
                subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
            else:
                subfigures = case_figure.get("subfigures", [])

            for subfig in subfigures:
                if "local_path" in subfig:
                    image_paths.append("medrax/data/" + subfig["local_path"])

    if not image_paths:
        print(f"No local images found for case {case_id}, question {question_id}")
        return None

    try:
        start_time = time.time()

        # Prepare input for the model
        query = tokenizer.from_list_format(
            [*[{"image": path} for path in image_paths], {"text": prompt}]
        )
        conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}]
        input_ids = tokenizer.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        )

        # Generate response
        with torch.no_grad():
            output = model.generate(
                input_ids.to(DEVICE),
                do_sample=False,
                num_beams=1,
                temperature=1.0,
                top_p=1.0,
                use_cache=True,
                max_new_tokens=512,
            )[0]

        response = tokenizer.decode(output[input_ids.size(1) : -1])
        duration = time.time() - start_time

        # Clean response
        clean_answer = validate_answer(response)

        # Log response
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "duration": round(duration, 2),
            "model_answer": clean_answer,
            "correct_answer": question_data["answer"],
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_paths": image_paths,
            },
        }
        logging.info(json.dumps(log_entry))
        return clean_answer

    except Exception as e:
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": MODEL_NAME,
            "status": "error",
            "error": str(e),
            "input": {
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_paths": image_paths,
            },
        }
        logging.info(json.dumps(log_entry))
        return None


def validate_answer(response_text: str) -> str | None:
    """Enforce strict single-letter response format.

    Args:
        response_text: Raw response text from the model

    Returns:
        str | None: Single uppercase letter (A-F) if valid, None if invalid
    """
    if not response_text:
        return None

    # Remove all whitespace and convert to uppercase
    cleaned = response_text.strip().upper()

    # Check if it's exactly one valid letter
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned

    # If not, try to extract just the letter
    match = re.search(r"([A-F])", cleaned)
    return match.group(1) if match else None


def load_benchmark_questions(case_id: str) -> list[str]:
    """Find all question files for a given case ID.

    Args:
        case_id: Unique identifier for the medical case

    Returns:
        list[str]: List of paths to question JSON files
    """
    benchmark_dir = "../benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        tuple containing:
        - int: Total number of cases
        - int: Total number of questions
    """
    total_cases = len(glob.glob("../benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("../benchmark/questions")
    )
    return total_cases, total_questions


def main():
    # Load the cases with local paths
    with open("medrax/data/updated_cases.json", "r") as file:
        data = json.load(file)

    # Initialize model and tokenizer
    model, tokenizer = initialize_model()

    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"\nBeginning inference with {MODEL_NAME}")
    print(f"Found {total_cases} cases with {total_questions} total questions")

    # Process each case with progress bar
    for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in tqdm(
            question_files, desc=f"Processing questions for case {case_id}", leave=False
        ):
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1
            answer = create_inference_request(
                question_data, case_details, case_id, question_id, model, tokenizer
            )

            if answer is None:
                skipped_questions += 1
                continue

            print(f"\nCase {case_id}, Question {question_id}")
            print(f"Model Answer: {answer}")
            print(f"Correct Answer: {question_data['answer']}")

    print(f"\nInference Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


if __name__ == "__main__":
    main()

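A sketch of a minimal question file in the layout that `load_benchmark_questions` globs for; the case ID, option text, and category labels are hypothetical and only mirror the fields the benchmark scripts read:

```python
import json
import pathlib

# Hypothetical question entry with the keys accessed by the benchmark scripts:
# question, answer, explanation, figures, and metadata.
question = {
    "question": "Which finding is present?\nA) Pneumothorax\nB) Pleural effusion\nC) Normal study",
    "answer": "B",
    "explanation": "Blunting of the costophrenic angle indicates an effusion.",
    "figures": ["Figure 1a"],
    "metadata": {"categories": ["detection", "characterization"]},
}

# Question files live under ../benchmark/questions/<case_id>/<case_id>_<n>.json.
path = pathlib.Path("../benchmark/questions/12345/12345_1.json")
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(question, indent=2))
```
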
experiments/benchmark_gpt4o.py
ADDED
@@ -0,0 +1,331 @@
import json
import openai
import os
import glob
import time
import logging
from datetime import datetime
from tenacity import retry, wait_exponential, stop_after_attempt

model_name = "chatgpt-4o-latest"
temperature = 0.2
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")


def calculate_cost(
    prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
) -> float:
    """Calculate the cost of API usage based on token counts.

    Args:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens in the completion
        model: Model name to use for pricing, defaults to chatgpt-4o-latest

    Returns:
        float: Cost in USD
    """
    pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
    rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
    return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000


@retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
def create_multimodal_request(
    question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
) -> openai.types.chat.ChatCompletion:
    """Create and send a multimodal request to the OpenAI API.

    Args:
        question_data: Dictionary containing question details and figures
        case_details: Dictionary containing case information and figures
        case_id: Identifier for the medical case
        question_id: Identifier for the specific question
        client: OpenAI client instance

    Returns:
        openai.types.chat.ChatCompletion: API response object, or None if request fails
    """
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{question_data['question']}
Base your answer only on the provided images and case information."""

    content = [{"type": "text", "text": prompt}]

    # Parse required figures
    try:
        # Try multiple ways of parsing figures
        if isinstance(question_data["figures"], str):
            try:
                required_figures = json.loads(question_data["figures"])
            except json.JSONDecodeError:
                required_figures = [question_data["figures"]]
        elif isinstance(question_data["figures"], list):
            required_figures = question_data["figures"]
        else:
            required_figures = [str(question_data["figures"])]
    except Exception as e:
        print(f"Error parsing figures: {e}")
        required_figures = []

    # Ensure each figure starts with "Figure "
    required_figures = [
        fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
    ]

    subfigures = []
    for figure in required_figures:
        # Handle both regular figures and those with letter suffixes
        base_figure_num = "".join(filter(str.isdigit, figure))
        figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None

        # Find matching figures in case details
        matching_figures = [
            case_figure
            for case_figure in case_details.get("figures", [])
            if case_figure["number"] == f"Figure {base_figure_num}"
        ]

        if not matching_figures:
            print(f"No matching figure found for {figure} in case {case_id}")
            continue

        for case_figure in matching_figures:
            # If a specific letter is specified, filter subfigures
            if figure_letter:
                matching_subfigures = [
                    subfig
                    for subfig in case_figure.get("subfigures", [])
                    if subfig.get("number", "").lower().endswith(figure_letter.lower())
                    or subfig.get("label", "").lower() == figure_letter.lower()
                ]
                subfigures.extend(matching_subfigures)
            else:
                # If no letter specified, add all subfigures
                subfigures.extend(case_figure.get("subfigures", []))

    # Add images to content
    for subfig in subfigures:
        if "url" in subfig:
            content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
        else:
            print(f"Subfigure missing URL: {subfig}")

    # If no images found, log and return None
    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        return None

    messages = [
        {
            "role": "system",
            "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
        },
        {"role": "user", "content": content},
    ]

    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for case {case_id}, question {question_id}")
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "skipped",
            "reason": "no_images",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return None

    try:
        start_time = time.time()

        response = client.chat.completions.create(
            model=model_name, messages=messages, max_tokens=50, temperature=temperature
        )
        duration = time.time() - start_time

        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
            "cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
            "model_answer": response.choices[0].message.content,
            "correct_answer": question_data["answer"],
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        return response

    except openai.RateLimitError:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "reason": "rate_limit",
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(
            f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
            flush=True,
        )
        time.sleep(20)
        raise
    except Exception as e:
        log_entry = {
            "case_id": case_id,
            "question_id": question_id,
            "timestamp": datetime.now().isoformat(),
            "model": model_name,
            "temperature": temperature,
            "status": "error",
            "error": str(e),
            "cost": 0,
            "input": {
                "messages": messages,
                "question_data": {
                    "question": question_data["question"],
                    "explanation": question_data["explanation"],
                    "metadata": question_data.get("metadata", {}),
                    "figures": question_data["figures"],
                },
                "image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
                "image_captions": [subfig.get("caption", "") for subfig in subfigures],
            },
        }
        logging.info(json.dumps(log_entry))
        print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
        raise


def load_benchmark_questions(case_id: str) -> list:
    """Load benchmark questions for a given case.

    Args:
        case_id: Identifier for the medical case

    Returns:
        list: List of paths to question files
    """
    benchmark_dir = "../benchmark/questions"
    return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")


def count_total_questions() -> tuple[int, int]:
    """Count total number of cases and questions in benchmark.

    Returns:
        tuple: (total_cases, total_questions)
    """
    total_cases = len(glob.glob("../benchmark/questions/*"))
    total_questions = sum(
        len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
        for case_id in os.listdir("../benchmark/questions")
    )
    return total_cases, total_questions


def main() -> None:
    """Main function to run the benchmark evaluation."""
    with open("../data/eurorad_metadata.json", "r") as file:
        data = json.load(file)

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    global client
    client = openai.OpenAI(api_key=api_key)

    total_cases, total_questions = count_total_questions()
    cases_processed = 0
    questions_processed = 0
    skipped_questions = 0

    print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")

    for case_id, case_details in data.items():
        question_files = load_benchmark_questions(case_id)
        if not question_files:
            continue

        cases_processed += 1
        for question_file in question_files:
            with open(question_file, "r") as file:
                question_data = json.load(file)
            question_id = os.path.basename(question_file).split(".")[0]

            questions_processed += 1
            response = create_multimodal_request(
                question_data, case_details, case_id, question_id, client
            )

            # Handle cases where response is None
            if response is None:
                skipped_questions += 1
                print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
                continue

            print(
                f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
            )
            print(f"Case ID: {case_id}")
            print(f"Question ID: {question_id}")
            print(f"Model Answer: {response.choices[0].message.content}")
            print(f"Correct Answer: {question_data['answer']}\n")

    print(f"\nBenchmark Summary:")
    print(f"Total Cases Processed: {cases_processed}")
    print(f"Total Questions Processed: {questions_processed}")
    print(f"Total Questions Skipped: {skipped_questions}")


if __name__ == "__main__":
    main()

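A small worked example of the pricing arithmetic in `calculate_cost` above, using the rates it defines ($5.00 per million prompt tokens and $15.00 per million completion tokens):

```python
# 1,000 prompt tokens + 200 completion tokens at the rates above.
prompt_tokens, completion_tokens = 1_000, 200
cost = (prompt_tokens * 5.0 + completion_tokens * 15.0) / 1_000_000
print(f"${cost:.4f}")  # -> $0.0080
```
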
experiments/benchmark_llama.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Any, Union
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import glob
|
| 6 |
+
import time
|
| 7 |
+
import logging
|
| 8 |
+
import socket
|
| 9 |
+
import requests
|
| 10 |
+
import httpx
|
| 11 |
+
import backoff
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from tenacity import retry, wait_exponential, stop_after_attempt
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
|
| 16 |
+
# Configure model settings
|
| 17 |
+
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
|
| 18 |
+
temperature = 0.2
|
| 19 |
+
|
| 20 |
+
# Configure logging
|
| 21 |
+
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
| 22 |
+
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def verify_dns() -> bool:
|
| 26 |
+
"""Verify DNS resolution and connectivity.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
bool: True if DNS resolution succeeds, False otherwise
|
| 30 |
+
"""
|
| 31 |
+
try:
|
| 32 |
+
# Try to resolve openrouter.ai
|
| 33 |
+
socket.gethostbyname("openrouter.ai")
|
| 34 |
+
return True
|
| 35 |
+
except socket.gaierror:
|
| 36 |
+
print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
|
| 37 |
+
# Modify resolv.conf to use Google DNS
|
| 38 |
+
try:
|
| 39 |
+
with open("/etc/resolv.conf", "w") as f:
|
| 40 |
+
f.write("nameserver 8.8.8.8\n")
|
| 41 |
+
return True
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print(f"Failed to update DNS settings: {e}")
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def verify_connection() -> bool:
|
| 48 |
+
"""Verify connection to OpenRouter API.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
bool: True if connection succeeds, False otherwise
|
| 52 |
+
"""
|
| 53 |
+
try:
|
| 54 |
+
response = requests.get("https://openrouter.ai/api/v1/status", timeout=10)
|
| 55 |
+
return response.status_code == 200
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Connection test failed: {e}")
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def initialize_client() -> OpenAI:
|
| 62 |
+
"""Initialize the OpenRouter client with proper timeout settings and connection verification.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
OpenAI: Configured OpenAI client for OpenRouter
|
| 66 |
+
|
| 67 |
+
Raises:
|
| 68 |
+
ValueError: If OPENROUTER_API_KEY environment variable is not set
|
| 69 |
+
ConnectionError: If DNS verification or connection test fails
|
| 70 |
+
"""
|
| 71 |
+
api_key = os.getenv("OPENROUTER_API_KEY")
|
| 72 |
+
if not api_key:
|
| 73 |
+
raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
|
| 74 |
+
|
| 75 |
+
# Configure timeout settings for the client
|
| 76 |
+
timeout_settings = 120 # Increased timeout for large images/responses
|
| 77 |
+
|
| 78 |
+
# Verify DNS and connection
|
| 79 |
+
if not verify_dns():
|
| 80 |
+
raise ConnectionError("DNS verification failed. Please check your network settings.")
|
| 81 |
+
|
| 82 |
+
if not verify_connection():
|
| 83 |
+
raise ConnectionError(
|
| 84 |
+
"Cannot connect to OpenRouter. Please check your internet connection."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Set up client with retry and timeout settings
|
| 88 |
+
return OpenAI(
|
| 89 |
+
base_url="https://openrouter.ai/api/v1",
|
| 90 |
+
api_key=api_key,
|
| 91 |
+
timeout=timeout_settings,
|
| 92 |
+
http_client=httpx.Client(
|
| 93 |
+
timeout=timeout_settings, transport=httpx.HTTPTransport(retries=3)
|
| 94 |
+
),
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@backoff.on_exception(
|
| 99 |
+
backoff.expo,
|
| 100 |
+
(ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
|
| 101 |
+
max_tries=5,
|
| 102 |
+
max_time=300, # Maximum total time to try in seconds
|
| 103 |
+
)
|
| 104 |
+
def create_multimodal_request(
|
| 105 |
+
question_data: Dict[str, Any],
|
| 106 |
+
case_details: Dict[str, Any],
|
| 107 |
+
case_id: str,
|
| 108 |
+
question_id: str,
|
| 109 |
+
client: OpenAI,
|
| 110 |
+
) -> Optional[Any]:
|
| 111 |
+
"""Create and send a multimodal request to the model.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
question_data: Dictionary containing question details
|
| 115 |
+
case_details: Dictionary containing case information
|
| 116 |
+
case_id: ID of the medical case
|
| 117 |
+
question_id: ID of the specific question
|
| 118 |
+
client: OpenAI client instance
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Optional[Any]: Model response if successful, None if skipped
|
| 122 |
+
|
| 123 |
+
Raises:
|
| 124 |
+
ConnectionError: If connection fails
|
| 125 |
+
TimeoutError: If request times out
|
| 126 |
+
Exception: For other errors
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
|
| 130 |
+
Rules:
|
| 131 |
+
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
|
| 132 |
+
2. Do not add periods, explanations, or any other text
|
| 133 |
+
3. Do not use markdown or formatting
|
| 134 |
+
4. Do not restate the question
|
| 135 |
+
5. Do not explain your reasoning
|
| 136 |
+
|
| 137 |
+
Examples of valid responses:
|
| 138 |
+
A
|
| 139 |
+
B
|
| 140 |
+
C
|
| 141 |
+
|
| 142 |
+
Examples of invalid responses:
|
| 143 |
+
"A."
|
| 144 |
+
"Answer: B"
|
| 145 |
+
"C) This shows..."
|
| 146 |
+
"The answer is D"
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
prompt = f"""Given the following medical case:
|
| 150 |
+
Please answer this multiple choice question:
|
| 151 |
+
{question_data['question']}
|
| 152 |
+
Base your answer only on the provided images and case information."""
|
| 153 |
+
|
| 154 |
+
# Parse required figures
|
| 155 |
+
try:
|
| 156 |
+
if isinstance(question_data["figures"], str):
|
| 157 |
+
try:
|
| 158 |
+
required_figures = json.loads(question_data["figures"])
|
| 159 |
+
except json.JSONDecodeError:
|
| 160 |
+
required_figures = [question_data["figures"]]
|
| 161 |
+
elif isinstance(question_data["figures"], list):
|
| 162 |
+
required_figures = question_data["figures"]
|
| 163 |
+
else:
|
| 164 |
+
required_figures = [str(question_data["figures"])]
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f"Error parsing figures: {e}")
|
| 167 |
+
required_figures = []
|
| 168 |
+
|
| 169 |
+
required_figures = [
|
| 170 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
| 171 |
+
]
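# Normalization example: figure references such as "3" or "3a" become "Figure 3" / "Figure 3a",
# so they can be matched against the "number" field of each entry in case_details["figures"].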
|
| 172 |
+
|
| 173 |
+
# Process subfigures and prepare content
|
| 174 |
+
content = [{"type": "text", "text": prompt}]
|
| 175 |
+
image_urls = []
|
| 176 |
+
image_captions = []
|
| 177 |
+
|
| 178 |
+
for figure in required_figures:
|
| 179 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
| 180 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
| 181 |
+
|
| 182 |
+
matching_figures = [
|
| 183 |
+
case_figure
|
| 184 |
+
for case_figure in case_details.get("figures", [])
|
| 185 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
| 186 |
+
]
|
| 187 |
+
|
| 188 |
+
for case_figure in matching_figures:
|
| 189 |
+
subfigures = []
|
| 190 |
+
if figure_letter:
|
| 191 |
+
subfigures = [
|
| 192 |
+
subfig
|
| 193 |
+
for subfig in case_figure.get("subfigures", [])
|
| 194 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
| 195 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
| 196 |
+
]
|
| 197 |
+
else:
|
| 198 |
+
subfigures = case_figure.get("subfigures", [])
|
| 199 |
+
|
| 200 |
+
for subfig in subfigures:
|
| 201 |
+
if "url" in subfig:
|
| 202 |
+
content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
|
| 203 |
+
image_urls.append(subfig["url"])
|
| 204 |
+
image_captions.append(subfig.get("caption", ""))
|
| 205 |
+
|
| 206 |
+
if len(content) == 1: # Only the text prompt exists
|
| 207 |
+
print(f"No images found for case {case_id}, question {question_id}")
|
| 208 |
+
# Log the skipped question
|
| 209 |
+
log_entry = {
|
| 210 |
+
"case_id": case_id,
|
| 211 |
+
"question_id": question_id,
|
| 212 |
+
"timestamp": datetime.now().isoformat(),
|
| 213 |
+
"model": MODEL_NAME,
|
| 214 |
+
"status": "skipped",
|
| 215 |
+
"reason": "no_images",
|
| 216 |
+
"input": {
|
| 217 |
+
"question_data": {
|
| 218 |
+
"question": question_data["question"],
|
| 219 |
+
"explanation": question_data["explanation"],
|
| 220 |
+
"metadata": question_data.get("metadata", {}),
|
| 221 |
+
"figures": question_data["figures"],
|
| 222 |
+
},
|
| 223 |
+
"image_urls": image_urls,
|
| 224 |
+
},
|
| 225 |
+
}
|
| 226 |
+
logging.info(json.dumps(log_entry))
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
+
try:
|
| 230 |
+
start_time = time.time()
|
| 231 |
+
|
| 232 |
+
response = client.chat.completions.create(
|
| 233 |
+
model=MODEL_NAME,
|
| 234 |
+
temperature=temperature,
|
| 235 |
+
messages=[
|
| 236 |
+
{"role": "system", "content": system_prompt},
|
| 237 |
+
{"role": "user", "content": content},
|
| 238 |
+
],
|
| 239 |
+
)
|
| 240 |
+
duration = time.time() - start_time
|
| 241 |
+
|
| 242 |
+
# Get raw response
|
| 243 |
+
raw_answer = response.choices[0].message.content
|
| 244 |
+
|
| 245 |
+
# Validate and clean
|
| 246 |
+
clean_answer = validate_answer(raw_answer)
|
| 247 |
+
|
| 248 |
+
if not clean_answer:
|
| 249 |
+
print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
|
| 250 |
+
print(f"Raw response: {raw_answer}")
|
| 251 |
+
|
| 252 |
+
# Update response object with cleaned answer
|
| 253 |
+
response.choices[0].message.content = clean_answer
|
| 254 |
+
|
| 255 |
+
# Log response
|
| 256 |
+
log_entry = {
|
| 257 |
+
"case_id": case_id,
|
| 258 |
+
"question_id": question_id,
|
| 259 |
+
"timestamp": datetime.now().isoformat(),
|
| 260 |
+
"model": MODEL_NAME,
|
| 261 |
+
"temperature": temperature,
|
| 262 |
+
"duration": round(duration, 2),
|
| 263 |
+
"usage": {
|
| 264 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
| 265 |
+
"completion_tokens": response.usage.completion_tokens,
|
| 266 |
+
"total_tokens": response.usage.total_tokens,
|
| 267 |
+
},
|
| 268 |
+
"model_answer": response.choices[0].message.content,
|
| 269 |
+
"correct_answer": question_data["answer"],
|
| 270 |
+
"input": {
|
| 271 |
+
"question_data": {
|
| 272 |
+
"question": question_data["question"],
|
| 273 |
+
"explanation": question_data["explanation"],
|
| 274 |
+
"metadata": question_data.get("metadata", {}),
|
| 275 |
+
"figures": question_data["figures"],
|
| 276 |
+
},
|
| 277 |
+
"image_urls": image_urls,
|
| 278 |
+
},
|
| 279 |
+
}
|
| 280 |
+
logging.info(json.dumps(log_entry))
|
| 281 |
+
return response
|
| 282 |
+
|
| 283 |
+
except ConnectionError as e:
|
| 284 |
+
print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
|
| 285 |
+
print("Retrying after a longer delay...")
|
| 286 |
+
time.sleep(30) # Add a longer delay before retry
|
| 287 |
+
raise
|
| 288 |
+
except TimeoutError as e:
|
| 289 |
+
print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
|
| 290 |
+
print("Retrying with increased timeout...")
|
| 291 |
+
raise
|
| 292 |
+
except Exception as e:
|
| 293 |
+
# Log failed requests too
|
| 294 |
+
log_entry = {
|
| 295 |
+
"case_id": case_id,
|
| 296 |
+
"question_id": question_id,
|
| 297 |
+
"timestamp": datetime.now().isoformat(),
|
| 298 |
+
"model": MODEL_NAME,
|
| 299 |
+
"temperature": temperature,
|
| 300 |
+
"status": "error",
|
| 301 |
+
"error": str(e),
|
| 302 |
+
"input": {
|
| 303 |
+
"question_data": {
|
| 304 |
+
"question": question_data["question"],
|
| 305 |
+
"explanation": question_data["explanation"],
|
| 306 |
+
"metadata": question_data.get("metadata", {}),
|
| 307 |
+
"figures": question_data["figures"],
|
| 308 |
+
},
|
| 309 |
+
"image_urls": image_urls,
|
| 310 |
+
},
|
| 311 |
+
}
|
| 312 |
+
logging.info(json.dumps(log_entry))
|
| 313 |
+
raise
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def extract_answer(response_text: str) -> Optional[str]:
|
| 317 |
+
"""Extract single letter answer from model response.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
response_text: Raw text response from model
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
Optional[str]: Single letter answer if found, None otherwise
|
| 324 |
+
"""
|
| 325 |
+
# Convert to uppercase and remove periods
|
| 326 |
+
text = response_text.upper().replace(".", "")
|
| 327 |
+
|
| 328 |
+
# Look for common patterns
|
| 329 |
+
patterns = [
|
| 330 |
+
r"ANSWER:\s*([A-F])", # Matches "ANSWER: X"
|
| 331 |
+
r"OPTION\s*([A-F])", # Matches "OPTION X"
|
| 332 |
+
r"([A-F])\)", # Matches "X)"
|
| 333 |
+
r"\b([A-F])\b", # Matches single letter
|
| 334 |
+
]
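# Patterns are ordered from most to least specific; the first pattern that matches wins, so an
# explicit "ANSWER: C" is preferred over a stray standalone letter elsewhere in the text.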
|
| 335 |
+
|
| 336 |
+
for pattern in patterns:
|
| 337 |
+
matches = re.findall(pattern, text)
|
| 338 |
+
if matches:
|
| 339 |
+
return matches[0]
|
| 340 |
+
|
| 341 |
+
return None
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
| 345 |
+
"""Enforce strict single-letter response format.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
response_text: Raw text response from model
|
| 349 |
+
|
| 350 |
+
Returns:
|
| 351 |
+
Optional[str]: Valid single letter answer if found, None otherwise
|
| 352 |
+
"""
|
| 353 |
+
if not response_text:
|
| 354 |
+
return None
|
| 355 |
+
|
| 356 |
+
# Remove all whitespace and convert to uppercase
|
| 357 |
+
cleaned = response_text.strip().upper()
|
| 358 |
+
|
| 359 |
+
# Check if it's exactly one valid letter
|
| 360 |
+
if len(cleaned) == 1 and cleaned in "ABCDEF":
|
| 361 |
+
return cleaned
|
| 362 |
+
|
| 363 |
+
# If not, try to extract just the letter
|
| 364 |
+
match = re.search(r"([A-F])", cleaned)
|
| 365 |
+
return match.group(1) if match else None
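# Illustrative behaviour (comment only, not executed):
#   validate_answer("B")  -> "B"
#   validate_answer("b.") -> "B"    (upper-cased; the regex fallback skips the period)
#   validate_answer("")   -> None
# Caveat: the fallback returns the first A-F character found anywhere, so a verbose reply such
# as "Answer: C" would resolve to "A"; the strict system prompt is what keeps replies to a bare letter.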
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
| 369 |
+
"""Find all question files for a given case ID.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
case_id: ID of the medical case
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
List[str]: List of paths to question files
|
| 376 |
+
"""
|
| 377 |
+
benchmark_dir = "../benchmark/questions"
|
| 378 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def count_total_questions() -> Tuple[int, int]:
|
| 382 |
+
"""Count total number of cases and questions.
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
Tuple[int, int]: (total_cases, total_questions)
|
| 386 |
+
"""
|
| 387 |
+
total_cases = len(glob.glob("../benchmark/questions/*"))
|
| 388 |
+
total_questions = sum(
|
| 389 |
+
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
|
| 390 |
+
for case_id in os.listdir("../benchmark/questions")
|
| 391 |
+
)
|
| 392 |
+
return total_cases, total_questions
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def main():
|
| 396 |
+
with open("../data/eurorad_metadata.json", "r") as file:
|
| 397 |
+
data = json.load(file)
|
| 398 |
+
|
| 399 |
+
client = initialize_client()
|
| 400 |
+
total_cases, total_questions = count_total_questions()
|
| 401 |
+
cases_processed = 0
|
| 402 |
+
questions_processed = 0
|
| 403 |
+
skipped_questions = 0
|
| 404 |
+
|
| 405 |
+
print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
|
| 406 |
+
|
| 407 |
+
for case_id, case_details in data.items():
|
| 408 |
+
question_files = load_benchmark_questions(case_id)
|
| 409 |
+
if not question_files:
|
| 410 |
+
continue
|
| 411 |
+
|
| 412 |
+
cases_processed += 1
|
| 413 |
+
for question_file in question_files:
|
| 414 |
+
with open(question_file, "r") as file:
|
| 415 |
+
question_data = json.load(file)
|
| 416 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
| 417 |
+
|
| 418 |
+
questions_processed += 1
|
| 419 |
+
response = create_multimodal_request(
|
| 420 |
+
question_data, case_details, case_id, question_id, client
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
if response is None:
|
| 424 |
+
skipped_questions += 1
|
| 425 |
+
print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
|
| 426 |
+
continue
|
| 427 |
+
|
| 428 |
+
print(
|
| 429 |
+
f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
|
| 430 |
+
)
|
| 431 |
+
print(f"Case ID: {case_id}")
|
| 432 |
+
print(f"Question ID: {question_id}")
|
| 433 |
+
print(f"Model Answer: {response.choices[0].message.content}")
|
| 434 |
+
print(f"Correct Answer: {question_data['answer']}\n")
|
| 435 |
+
|
| 436 |
+
print(f"\nBenchmark Summary:")
|
| 437 |
+
print(f"Total Cases Processed: {cases_processed}")
|
| 438 |
+
print(f"Total Questions Processed: {questions_processed}")
|
| 439 |
+
print(f"Total Questions Skipped: {skipped_questions}")
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
if __name__ == "__main__":
|
| 443 |
+
main()
|
experiments/benchmark_llavamed.py
ADDED
|
@@ -0,0 +1,541 @@
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import requests
|
| 4 |
+
import base64
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from llava.conversation import conv_templates
|
| 8 |
+
import time
|
| 9 |
+
import os
|
| 10 |
+
import glob
|
| 11 |
+
import logging
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import re
|
| 15 |
+
from typing import Dict, List, Optional, Union, Any, Tuple
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def process_image(image_path: str, target_size: int = 640) -> Image.Image:
|
| 19 |
+
"""Process and resize an image to match model requirements.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
image_path: Path to the input image file
|
| 23 |
+
target_size: Target size for both width and height in pixels
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
PIL.Image: Processed and padded image with dimensions (target_size, target_size)
|
| 27 |
+
"""
|
| 28 |
+
image = Image.open(image_path)
|
| 29 |
+
if image.mode != "RGB":
|
| 30 |
+
image = image.convert("RGB")
|
| 31 |
+
|
| 32 |
+
# Calculate scaling to maintain aspect ratio
|
| 33 |
+
ratio = min(target_size / image.width, target_size / image.height)
|
| 34 |
+
new_size = (int(image.width * ratio), int(image.height * ratio))
|
| 35 |
+
|
| 36 |
+
# Resize image
|
| 37 |
+
image = image.resize(new_size, Image.LANCZOS)
|
| 38 |
+
|
| 39 |
+
# Create new image with padding
|
| 40 |
+
new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
|
| 41 |
+
# Paste resized image in center
|
| 42 |
+
offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
|
| 43 |
+
new_image.paste(image, offset)
|
| 44 |
+
|
| 45 |
+
return new_image
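# Illustrative behaviour (comment only): a 1024x768 input is scaled to 640x480 and pasted onto
# the centre of a 640x640 black canvas, preserving aspect ratio while always returning a square
# image of the target size.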
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
| 49 |
+
"""Extract and validate a single-letter response from the model's output.
|
| 50 |
+
Handles multiple response formats and edge cases.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
response_text: The full text output from the model
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
A single letter answer (A-F) or None if no valid answer found
|
| 57 |
+
"""
|
| 58 |
+
if not response_text:
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
# Clean the response text
|
| 62 |
+
cleaned = response_text.strip()
|
| 63 |
+
|
| 64 |
+
# Comprehensive set of patterns to extract the answer
|
| 65 |
+
extraction_patterns = [
|
| 66 |
+
# Strict format with explicit letter answer
|
| 67 |
+
r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
|
| 68 |
+
# Patterns for extracting from longer descriptions
|
| 69 |
+
r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
| 70 |
+
r"\b(?:answer|option)\s*([A-F])[):]\s*",
|
| 71 |
+
# Patterns for extracting from descriptive sentences
|
| 72 |
+
r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
| 73 |
+
r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
|
| 74 |
+
# Patterns with contextual words
|
| 75 |
+
r"characteriz[e]?d?\s+by\s+([A-F])\b",
|
| 76 |
+
r"indicat[e]?s?\s+([A-F])\b",
|
| 77 |
+
# Fallback to Option X or Letter X formats
|
| 78 |
+
r"Option\s*([A-F])\b",
|
| 79 |
+
r"\b([A-F])\)\s*",
|
| 80 |
+
# Fallback to standalone letter
|
| 81 |
+
r"^\s*([A-F])\s*$",
|
| 82 |
+
]
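# Patterns are tried in order, so the stricter "The SINGLE LETTER answer is: X" format requested
# in the system prompt wins over the looser descriptive and standalone-letter fallbacks.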
|
| 83 |
+
|
| 84 |
+
# Try each pattern
|
| 85 |
+
for pattern in extraction_patterns:
|
| 86 |
+
matches = re.findall(pattern, cleaned, re.IGNORECASE)
|
| 87 |
+
for match in matches:
|
| 88 |
+
# Ensure match is a single valid letter
|
| 89 |
+
if isinstance(match, tuple):
|
| 90 |
+
match = match[0] if match[0] in "ABCDEF" else None
|
| 91 |
+
if match and match.upper() in "ABCDEF":
|
| 92 |
+
return match.upper()
|
| 93 |
+
|
| 94 |
+
# Final fallback: look for standalone letters in context
|
| 95 |
+
context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
|
| 96 |
+
context_letters = [m for m in context_matches if m in "ABCDEF"]
|
| 97 |
+
if context_letters:
|
| 98 |
+
return context_letters[0]
|
| 99 |
+
|
| 100 |
+
# No valid answer found
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
| 105 |
+
"""Find all question files for a given case ID.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
case_id: The ID of the medical case
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
List of paths to question JSON files
|
| 112 |
+
"""
|
| 113 |
+
benchmark_dir = "MedMAX/benchmark/questions"
|
| 114 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def count_total_questions() -> Tuple[int, int]:
|
| 118 |
+
"""Count total number of cases and questions in benchmark.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Tuple containing (total_cases, total_questions)
|
| 122 |
+
"""
|
| 123 |
+
total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
|
| 124 |
+
total_questions = sum(
|
| 125 |
+
len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
|
| 126 |
+
for case_id in os.listdir("MedMAX/benchmark/questions")
|
| 127 |
+
)
|
| 128 |
+
return total_cases, total_questions
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def create_inference_request(
|
| 132 |
+
question_data: Dict[str, Any],
|
| 133 |
+
case_details: Dict[str, Any],
|
| 134 |
+
case_id: str,
|
| 135 |
+
question_id: str,
|
| 136 |
+
worker_addr: str,
|
| 137 |
+
model_name: str,
|
| 138 |
+
raw_output: bool = False,
|
| 139 |
+
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
|
| 140 |
+
"""Create and send inference request to worker.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
question_data: Dictionary containing question details and figures
|
| 144 |
+
case_details: Dictionary containing case information and figures
|
| 145 |
+
case_id: Identifier for the medical case
|
| 146 |
+
question_id: Identifier for the specific question
|
| 147 |
+
worker_addr: Address of the worker endpoint
|
| 148 |
+
model_name: Name of the model to use
|
| 149 |
+
raw_output: Whether to return raw model output
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
If raw_output is False: Tuple of (validated_answer, duration)
|
| 153 |
+
If raw_output is True: Dictionary with full inference details
|
| 154 |
+
"""
|
| 155 |
+
system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
|
| 156 |
+
"""
|
| 157 |
+
|
| 158 |
+
prompt = f"""Given the following medical case:
|
| 159 |
+
Please answer this multiple choice question:
|
| 160 |
+
{question_data['question']}
|
| 161 |
+
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
|
| 162 |
+
|
| 163 |
+
try:
|
| 164 |
+
# Parse required figures
|
| 165 |
+
if isinstance(question_data["figures"], str):
|
| 166 |
+
try:
|
| 167 |
+
required_figures = json.loads(question_data["figures"])
|
| 168 |
+
except json.JSONDecodeError:
|
| 169 |
+
required_figures = [question_data["figures"]]
|
| 170 |
+
elif isinstance(question_data["figures"], list):
|
| 171 |
+
required_figures = question_data["figures"]
|
| 172 |
+
else:
|
| 173 |
+
required_figures = [str(question_data["figures"])]
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(f"Error parsing figures: {e}")
|
| 176 |
+
required_figures = []
|
| 177 |
+
|
| 178 |
+
required_figures = [
|
| 179 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
| 180 |
+
]
|
| 181 |
+
|
| 182 |
+
# Get image paths
|
| 183 |
+
image_paths = []
|
| 184 |
+
for figure in required_figures:
|
| 185 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
| 186 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
| 187 |
+
|
| 188 |
+
matching_figures = [
|
| 189 |
+
case_figure
|
| 190 |
+
for case_figure in case_details.get("figures", [])
|
| 191 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
| 192 |
+
]
|
| 193 |
+
|
| 194 |
+
for case_figure in matching_figures:
|
| 195 |
+
subfigures = []
|
| 196 |
+
if figure_letter:
|
| 197 |
+
subfigures = [
|
| 198 |
+
subfig
|
| 199 |
+
for subfig in case_figure.get("subfigures", [])
|
| 200 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
| 201 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
| 202 |
+
]
|
| 203 |
+
else:
|
| 204 |
+
subfigures = case_figure.get("subfigures", [])
|
| 205 |
+
|
| 206 |
+
for subfig in subfigures:
|
| 207 |
+
if "local_path" in subfig:
|
| 208 |
+
image_paths.append("MedMAX/data/" + subfig["local_path"])
|
| 209 |
+
|
| 210 |
+
if not image_paths:
|
| 211 |
+
print(f"No local images found for case {case_id}, question {question_id}")
|
| 212 |
+
return "skipped", 0.0 # Return a special 'skipped' marker
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
start_time = time.time()
|
| 216 |
+
|
| 217 |
+
# Process each image
|
| 218 |
+
processed_images = [process_image(path) for path in image_paths]
|
| 219 |
+
|
| 220 |
+
# Create conversation
|
| 221 |
+
conv = conv_templates["mistral_instruct"].copy()
|
| 222 |
+
|
| 223 |
+
# Add image and message
|
| 224 |
+
if "<image>" not in prompt:
|
| 225 |
+
text = prompt + "\n<image>"
|
| 226 |
+
else:
|
| 227 |
+
text = prompt
|
| 228 |
+
|
| 229 |
+
message = (text, processed_images[0], "Default") # Currently handling first image
|
| 230 |
+
conv.append_message(conv.roles[0], message)
|
| 231 |
+
conv.append_message(conv.roles[1], None)
|
| 232 |
+
|
| 233 |
+
prompt = conv.get_prompt()
|
| 234 |
+
headers = {"User-Agent": "LLaVA-Med Client"}
|
| 235 |
+
pload = {
|
| 236 |
+
"model": model_name,
|
| 237 |
+
"prompt": prompt,
|
| 238 |
+
"max_new_tokens": 150, # Reduce this since we only need one letter
|
| 239 |
+
"temperature": 0.5, # Lower temperature for more focused responses
|
| 240 |
+
"stop": conv.sep2,
|
| 241 |
+
"images": conv.get_images(),
|
| 242 |
+
"top_p": 1, # Lower top_p for more focused sampling
|
| 243 |
+
"frequency_penalty": 0.0,
|
| 244 |
+
"presence_penalty": 0.0,
|
| 245 |
+
}
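# Assumption: these fields mirror the generation parameters accepted by the LLaVA-Med
# worker's /worker_generate_stream endpoint used below, which streams null-delimited JSON
# chunks that are reassembled into the final response text.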
|
| 246 |
+
|
| 247 |
+
max_retries = 3
|
| 248 |
+
retry_delay = 5
|
| 249 |
+
response_text = None
|
| 250 |
+
|
| 251 |
+
for attempt in range(max_retries):
|
| 252 |
+
try:
|
| 253 |
+
response = requests.post(
|
| 254 |
+
worker_addr + "/worker_generate_stream",
|
| 255 |
+
headers=headers,
|
| 256 |
+
json=pload,
|
| 257 |
+
stream=True,
|
| 258 |
+
timeout=30,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
complete_output = ""
|
| 262 |
+
for chunk in response.iter_lines(
|
| 263 |
+
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
| 264 |
+
):
|
| 265 |
+
if chunk:
|
| 266 |
+
data = json.loads(chunk.decode("utf-8"))
|
| 267 |
+
if data["error_code"] == 0:
|
| 268 |
+
output = data["text"].split("[/INST]")[-1]
|
| 269 |
+
complete_output = output
|
| 270 |
+
else:
|
| 271 |
+
print(f"\nError: {data['text']} (error_code: {data['error_code']})")
|
| 272 |
+
if attempt < max_retries - 1:
|
| 273 |
+
time.sleep(retry_delay)
|
| 274 |
+
break
|
| 275 |
+
return None, None
|
| 276 |
+
|
| 277 |
+
if complete_output:
|
| 278 |
+
response_text = complete_output
|
| 279 |
+
break
|
| 280 |
+
|
| 281 |
+
except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
|
| 282 |
+
if attempt < max_retries - 1:
|
| 283 |
+
print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
|
| 284 |
+
time.sleep(retry_delay)
|
| 285 |
+
else:
|
| 286 |
+
print(f"\nFailed after {max_retries} attempts: {str(e)}")
|
| 287 |
+
return None, None
|
| 288 |
+
|
| 289 |
+
duration = time.time() - start_time
|
| 290 |
+
|
| 291 |
+
if raw_output:
|
| 292 |
+
inference_details = {
|
| 293 |
+
"raw_output": response_text,
|
| 294 |
+
"validated_answer": validate_answer(response_text),
|
| 295 |
+
"duration": duration,
|
| 296 |
+
"prompt": prompt,
|
| 297 |
+
"system_prompt": system_prompt,
|
| 298 |
+
"image_paths": image_paths,
|
| 299 |
+
"payload": pload,
|
| 300 |
+
}
|
| 301 |
+
return inference_details
|
| 302 |
+
|
| 303 |
+
return validate_answer(response_text), duration
|
| 304 |
+
|
| 305 |
+
except Exception as e:
|
| 306 |
+
print(f"Error in inference request: {str(e)}")
|
| 307 |
+
return None, None
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
| 311 |
+
"""Remove image-related and large data from the payload to keep the log lean.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
payload: Original request payload dictionary
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
Cleaned payload dictionary with large data removed
|
| 318 |
+
"""
|
| 319 |
+
if not payload:
|
| 320 |
+
return None
|
| 321 |
+
|
| 322 |
+
# Create a copy of the payload to avoid modifying the original
|
| 323 |
+
cleaned_payload = payload.copy()
|
| 324 |
+
|
| 325 |
+
# Remove large or sensitive data
|
| 326 |
+
if "images" in cleaned_payload:
|
| 327 |
+
del cleaned_payload["images"]
|
| 328 |
+
|
| 329 |
+
return cleaned_payload
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def main():
|
| 333 |
+
parser = argparse.ArgumentParser()
|
| 334 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
| 335 |
+
parser.add_argument("--worker-address", type=str)
|
| 336 |
+
parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
|
| 337 |
+
parser.add_argument("--output-dir", type=str, default="benchmark_results")
|
| 338 |
+
parser.add_argument(
|
| 339 |
+
"--raw-output", action="store_true", help="Return raw model output without validation"
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--num-cases",
|
| 343 |
+
type=int,
|
| 344 |
+
help="Number of cases to process if looking at raw outputs",
|
| 345 |
+
default=2,
|
| 346 |
+
)
|
| 347 |
+
args = parser.parse_args()
|
| 348 |
+
|
| 349 |
+
# Setup output directory
|
| 350 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 351 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 352 |
+
|
| 353 |
+
# Setup live logging files
|
| 354 |
+
live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
|
| 355 |
+
final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
|
| 356 |
+
|
| 357 |
+
# Initialize live log file
|
| 358 |
+
with open(live_log_filename, "w") as live_log_file:
|
| 359 |
+
live_log_file.write("[\n") # Start of JSON array
|
| 360 |
+
|
| 361 |
+
# Setup logging
|
| 362 |
+
logging.basicConfig(
|
| 363 |
+
filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
|
| 364 |
+
level=logging.INFO,
|
| 365 |
+
format="%(message)s",
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Get worker address
|
| 369 |
+
if args.worker_address:
|
| 370 |
+
worker_addr = args.worker_address
|
| 371 |
+
else:
|
| 372 |
+
try:
|
| 373 |
+
requests.post(args.controller_address + "/refresh_all_workers")
|
| 374 |
+
ret = requests.post(args.controller_address + "/list_models")
|
| 375 |
+
models = ret.json()["models"]
|
| 376 |
+
ret = requests.post(
|
| 377 |
+
args.controller_address + "/get_worker_address", json={"model": args.model_name}
|
| 378 |
+
)
|
| 379 |
+
worker_addr = ret.json()["address"]
|
| 380 |
+
print(f"Worker address: {worker_addr}")
|
| 381 |
+
except requests.exceptions.RequestException as e:
|
| 382 |
+
print(f"Failed to connect to controller: {e}")
|
| 383 |
+
return
|
| 384 |
+
|
| 385 |
+
if worker_addr == "":
|
| 386 |
+
print("No available worker")
|
| 387 |
+
return
|
| 388 |
+
|
| 389 |
+
# Load cases with local paths
|
| 390 |
+
with open("MedMAX/data/updated_cases.json", "r") as file:
|
| 391 |
+
data = json.load(file)
|
| 392 |
+
|
| 393 |
+
total_cases, total_questions = count_total_questions()
|
| 394 |
+
print(f"\nStarting benchmark with {args.model_name}")
|
| 395 |
+
print(f"Found {total_cases} cases with {total_questions} total questions")
|
| 396 |
+
|
| 397 |
+
results = {
|
| 398 |
+
"model": args.model_name,
|
| 399 |
+
"timestamp": datetime.now().isoformat(),
|
| 400 |
+
"total_cases": total_cases,
|
| 401 |
+
"total_questions": total_questions,
|
| 402 |
+
"results": [],
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
cases_processed = 0
|
| 406 |
+
questions_processed = 0
|
| 407 |
+
correct_answers = 0
|
| 408 |
+
skipped_questions = 0
|
| 409 |
+
total_processed_entries = 0
|
| 410 |
+
|
| 411 |
+
# Process each case
|
| 412 |
+
for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
|
| 413 |
+
question_files = load_benchmark_questions(case_id)
|
| 414 |
+
if not question_files:
|
| 415 |
+
continue
|
| 416 |
+
|
| 417 |
+
cases_processed += 1
|
| 418 |
+
for question_file in tqdm(
|
| 419 |
+
question_files, desc=f"Processing questions for case {case_id}", leave=False
|
| 420 |
+
):
|
| 421 |
+
with open(question_file, "r") as file:
|
| 422 |
+
question_data = json.load(file)
|
| 423 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
| 424 |
+
|
| 425 |
+
questions_processed += 1
|
| 426 |
+
|
| 427 |
+
# Get model's answer
|
| 428 |
+
inference_result = create_inference_request(
|
| 429 |
+
question_data,
|
| 430 |
+
case_details,
|
| 431 |
+
case_id,
|
| 432 |
+
question_id,
|
| 433 |
+
worker_addr,
|
| 434 |
+
args.model_name,
|
| 435 |
+
raw_output=True, # Always use raw output for detailed logging
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Handle skipped questions
|
| 439 |
+
if inference_result == ("skipped", 0.0):
|
| 440 |
+
skipped_questions += 1
|
| 441 |
+
print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
|
| 442 |
+
|
| 443 |
+
# Log skipped question
|
| 444 |
+
skipped_entry = {
|
| 445 |
+
"case_id": case_id,
|
| 446 |
+
"question_id": question_id,
|
| 447 |
+
"status": "skipped",
|
| 448 |
+
"reason": "No images found",
|
| 449 |
+
}
|
| 450 |
+
with open(live_log_filename, "a") as live_log_file:
|
| 451 |
+
json.dump(skipped_entry, live_log_file, indent=2)
|
| 452 |
+
live_log_file.write(",\n") # Add comma for next entry
|
| 453 |
+
|
| 454 |
+
continue
|
| 455 |
+
|
| 456 |
+
# Extract information
|
| 457 |
+
answer = inference_result["validated_answer"]
|
| 458 |
+
duration = inference_result["duration"]
|
| 459 |
+
|
| 460 |
+
# Prepare detailed logging entry
|
| 461 |
+
log_entry = {
|
| 462 |
+
"case_id": case_id,
|
| 463 |
+
"question_id": question_id,
|
| 464 |
+
"question": question_data["question"],
|
| 465 |
+
"correct_answer": question_data["answer"],
|
| 466 |
+
"raw_output": inference_result["raw_output"],
|
| 467 |
+
"validated_answer": answer,
|
| 468 |
+
"model_answer": answer,
|
| 469 |
+
"is_correct": answer == question_data["answer"] if answer else False,
|
| 470 |
+
"duration": duration,
|
| 471 |
+
"system_prompt": inference_result["system_prompt"],
|
| 472 |
+
"input_prompt": inference_result["prompt"],
|
| 473 |
+
"image_paths": inference_result["image_paths"],
|
| 474 |
+
"payload": clean_payload(inference_result["payload"]),
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
# Write to live log file
|
| 478 |
+
with open(live_log_filename, "a") as live_log_file:
|
| 479 |
+
json.dump(log_entry, live_log_file, indent=2)
|
| 480 |
+
live_log_file.write(",\n") # Add comma for next entry
|
| 481 |
+
|
| 482 |
+
# Print to console
|
| 483 |
+
print(f"\nCase {case_id}, Question {question_id}")
|
| 484 |
+
print(f"Model Answer: {answer}")
|
| 485 |
+
print(f"Correct Answer: {question_data['answer']}")
|
| 486 |
+
print(f"Time taken: {duration:.2f}s")
|
| 487 |
+
|
| 488 |
+
# Track correct answers
|
| 489 |
+
if answer == question_data["answer"]:
|
| 490 |
+
correct_answers += 1
|
| 491 |
+
|
| 492 |
+
# Append to results
|
| 493 |
+
results["results"].append(log_entry)
|
| 494 |
+
total_processed_entries += 1
|
| 495 |
+
|
| 496 |
+
# Optional: stop the inner question loop once the requested number of cases has been processed (raw-output mode)
|
| 497 |
+
if args.raw_output and cases_processed == args.num_cases:
|
| 498 |
+
break
|
| 499 |
+
|
| 500 |
+
# Optional: also stop the outer case loop once the requested number of cases has been processed
|
| 501 |
+
if args.raw_output and cases_processed == args.num_cases:
|
| 502 |
+
break
|
| 503 |
+
|
| 504 |
+
# Close live log file
|
| 505 |
+
with open(live_log_filename, "a") as live_log_file:
|
| 506 |
+
# Remove trailing comma and close JSON array
|
| 507 |
+
live_log_file.seek(live_log_file.tell() - 2, 0) # Go back 2 chars to remove ',\n'
|
| 508 |
+
live_log_file.write("\n]")
|
| 509 |
+
|
| 510 |
+
# Calculate final statistics
|
| 511 |
+
results["summary"] = {
|
| 512 |
+
"cases_processed": cases_processed,
|
| 513 |
+
"questions_processed": questions_processed,
|
| 514 |
+
"total_processed_entries": total_processed_entries,
|
| 515 |
+
"correct_answers": correct_answers,
|
| 516 |
+
"skipped_questions": skipped_questions,
|
| 517 |
+
"accuracy": (
|
| 518 |
+
correct_answers / (questions_processed - skipped_questions)
|
| 519 |
+
if (questions_processed - skipped_questions) > 0
|
| 520 |
+
else 0
|
| 521 |
+
),
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
# Save final results
|
| 525 |
+
with open(final_results_filename, "w") as f:
|
| 526 |
+
json.dump(results, f, indent=2)
|
| 527 |
+
|
| 528 |
+
print(f"\nBenchmark Summary:")
|
| 529 |
+
print(f"Total Cases Processed: {cases_processed}")
|
| 530 |
+
print(f"Total Questions Processed: {questions_processed}")
|
| 531 |
+
print(f"Total Processed Entries: {total_processed_entries}")
|
| 532 |
+
print(f"Correct Answers: {correct_answers}")
|
| 533 |
+
print(f"Skipped Questions: {skipped_questions}")
|
| 534 |
+
print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%")
|
| 535 |
+
print(f"\nResults saved to {args.output_dir}")
|
| 536 |
+
print(f"Live log: {live_log_filename}")
|
| 537 |
+
print(f"Final results: {final_results_filename}")
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
if __name__ == "__main__":
|
| 541 |
+
main()
|
experiments/benchmark_medrax.ipynb
ADDED
|
@@ -0,0 +1,374 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import operator\n",
|
| 10 |
+
"import warnings\n",
|
| 11 |
+
"from typing import *\n",
|
| 12 |
+
"import traceback\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"import os\n",
|
| 15 |
+
"import torch\n",
|
| 16 |
+
"from dotenv import load_dotenv\n",
|
| 17 |
+
"from IPython.display import Image\n",
|
| 18 |
+
"from langgraph.checkpoint.memory import MemorySaver\n",
|
| 19 |
+
"from langgraph.graph import END, StateGraph\n",
|
| 20 |
+
"from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
|
| 21 |
+
"from langchain_openai import ChatOpenAI\n",
|
| 22 |
+
"from transformers import logging\n",
|
| 23 |
+
"import matplotlib.pyplot as plt\n",
|
| 24 |
+
"import numpy as np\n",
|
| 25 |
+
"import re\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"from medrax.agent import *\n",
|
| 28 |
+
"from medrax.tools import *\n",
|
| 29 |
+
"from medrax.utils import *\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"import json\n",
|
| 32 |
+
"import openai\n",
|
| 33 |
+
"import os\n",
|
| 34 |
+
"import glob\n",
|
| 35 |
+
"import time\n",
|
| 36 |
+
"import logging\n",
|
| 37 |
+
"from datetime import datetime\n",
|
| 38 |
+
"from tenacity import retry, wait_exponential, stop_after_attempt\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"warnings.filterwarnings(\"ignore\")\n",
|
| 41 |
+
"_ = load_dotenv()\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"# Setup directory paths\n",
|
| 45 |
+
"ROOT = \"set this directory to where MedRAX is, .e.g /home/MedRAX\"\n",
|
| 46 |
+
"PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
|
| 47 |
+
"BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
|
| 48 |
+
"MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
|
| 49 |
+
"FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"model_name = \"medrax\"\n",
|
| 52 |
+
"temperature = 0.2\n",
|
| 53 |
+
"medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
|
| 54 |
+
"log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
|
| 55 |
+
"logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
|
| 56 |
+
"device = \"cuda\""
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": 2,
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"def get_tools():\n",
|
| 66 |
+
" report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
|
| 67 |
+
" xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
|
| 68 |
+
" segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
|
| 69 |
+
" grounding_tool = XRayPhraseGroundingTool(\n",
|
| 70 |
+
" cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
|
| 71 |
+
" )\n",
|
| 72 |
+
" xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
|
| 73 |
+
" llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" return [\n",
|
| 76 |
+
" report_tool,\n",
|
| 77 |
+
" xray_classification_tool,\n",
|
| 78 |
+
" segmentation_tool,\n",
|
| 79 |
+
" grounding_tool,\n",
|
| 80 |
+
" xray_vqa_tool,\n",
|
| 81 |
+
" llava_med_tool,\n",
|
| 82 |
+
" ]\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"def get_agent(tools):\n",
|
| 86 |
+
" prompts = load_prompts_from_file(PROMPT_FILE)\n",
|
| 87 |
+
" prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
|
| 88 |
+
"\n",
|
| 89 |
+
" checkpointer = MemorySaver()\n",
|
| 90 |
+
" model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
|
| 91 |
+
" agent = Agent(\n",
|
| 92 |
+
" model,\n",
|
| 93 |
+
" tools=tools,\n",
|
| 94 |
+
" log_tools=True,\n",
|
| 95 |
+
" log_dir=\"logs\",\n",
|
| 96 |
+
" system_prompt=prompt,\n",
|
| 97 |
+
" checkpointer=checkpointer,\n",
|
| 98 |
+
" )\n",
|
| 99 |
+
" thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
|
| 100 |
+
" return agent, thread\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"def run_medrax(agent, thread, prompt, image_urls=[]):\n",
|
| 104 |
+
" messages = [\n",
|
| 105 |
+
" HumanMessage(\n",
|
| 106 |
+
" content=[\n",
|
| 107 |
+
" {\"type\": \"text\", \"text\": prompt},\n",
|
| 108 |
+
" ]\n",
|
| 109 |
+
" + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
|
| 110 |
+
" )\n",
|
| 111 |
+
" ]\n",
|
| 112 |
+
"\n",
|
| 113 |
+
" final_response = None\n",
|
| 114 |
+
" for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
|
| 115 |
+
" for v in event.values():\n",
|
| 116 |
+
" final_response = v\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" final_response = final_response[\"messages\"][-1].content.strip()\n",
|
| 119 |
+
" agent_state = agent.workflow.get_state(thread)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
" return final_response, str(agent_state)"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": 3,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
|
| 131 |
+
" # Parse required figures\n",
|
| 132 |
+
" try:\n",
|
| 133 |
+
" # Try multiple ways of parsing figures\n",
|
| 134 |
+
" if isinstance(question_data[\"figures\"], str):\n",
|
| 135 |
+
" try:\n",
|
| 136 |
+
" required_figures = json.loads(question_data[\"figures\"])\n",
|
| 137 |
+
" except json.JSONDecodeError:\n",
|
| 138 |
+
" required_figures = [question_data[\"figures\"]]\n",
|
| 139 |
+
" elif isinstance(question_data[\"figures\"], list):\n",
|
| 140 |
+
" required_figures = question_data[\"figures\"]\n",
|
| 141 |
+
" else:\n",
|
| 142 |
+
" required_figures = [str(question_data[\"figures\"])]\n",
|
| 143 |
+
" except Exception as e:\n",
|
| 144 |
+
" print(f\"Error parsing figures: {e}\")\n",
|
| 145 |
+
" required_figures = []\n",
|
| 146 |
+
"\n",
|
| 147 |
+
" # Ensure each figure starts with \"Figure \"\n",
|
| 148 |
+
" required_figures = [\n",
|
| 149 |
+
" fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
|
| 150 |
+
" ]\n",
|
| 151 |
+
"\n",
|
| 152 |
+
" subfigures = []\n",
|
| 153 |
+
" for figure in required_figures:\n",
|
| 154 |
+
" # Handle both regular figures and those with letter suffixes\n",
|
| 155 |
+
" base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
|
| 156 |
+
" figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
|
| 157 |
+
"\n",
|
| 158 |
+
" # Find matching figures in case details\n",
|
| 159 |
+
" matching_figures = [\n",
|
| 160 |
+
" case_figure\n",
|
| 161 |
+
" for case_figure in case_details.get(\"figures\", [])\n",
|
| 162 |
+
" if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
|
| 163 |
+
" ]\n",
|
| 164 |
+
"\n",
|
| 165 |
+
" if not matching_figures:\n",
|
| 166 |
+
" print(f\"No matching figure found for {figure} in case {case_id}\")\n",
|
| 167 |
+
" continue\n",
|
| 168 |
+
"\n",
|
| 169 |
+
" for case_figure in matching_figures:\n",
|
| 170 |
+
" # If a specific letter is specified, filter subfigures\n",
|
| 171 |
+
" if figure_letter:\n",
|
| 172 |
+
" matching_subfigures = [\n",
|
| 173 |
+
" subfig\n",
|
| 174 |
+
" for subfig in case_figure.get(\"subfigures\", [])\n",
|
| 175 |
+
" if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
|
| 176 |
+
" or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
|
| 177 |
+
" ]\n",
|
| 178 |
+
" subfigures.extend(matching_subfigures)\n",
|
| 179 |
+
" else:\n",
|
| 180 |
+
" # If no letter specified, add all subfigures\n",
|
| 181 |
+
" subfigures.extend(case_figure.get(\"subfigures\", []))\n",
|
| 182 |
+
"\n",
|
| 183 |
+
" # Add images to content\n",
|
| 184 |
+
" figure_prompt = \"\"\n",
|
| 185 |
+
" image_urls = []\n",
|
| 186 |
+
"\n",
|
| 187 |
+
" for subfig in subfigures:\n",
|
| 188 |
+
" if \"number\" in subfig:\n",
|
| 189 |
+
" subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
|
| 190 |
+
" subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
|
| 191 |
+
" figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
|
| 192 |
+
" if \"url\" in subfig:\n",
|
| 193 |
+
" image_urls.append(subfig[\"url\"])\n",
|
| 194 |
+
" else:\n",
|
| 195 |
+
" print(f\"Subfigure missing URL: {subfig}\")\n",
|
| 196 |
+
"\n",
|
| 197 |
+
" prompt = (\n",
|
| 198 |
+
" f\"Answer this question correctly using chain of thought reasoning and \"\n",
|
| 199 |
+
" \"carefully evaluating choices. Solve using our own vision and reasoning and then\"\n",
|
| 200 |
+
" \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
|
| 201 |
+
" f\"{question_data['question']}\\n{figure_prompt}\"\n",
|
| 202 |
+
" )\n",
|
| 203 |
+
"\n",
|
| 204 |
+
" try:\n",
|
| 205 |
+
" start_time = time.time()\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" final_response, agent_state = run_medrax(\n",
|
| 208 |
+
" agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
|
| 209 |
+
" )\n",
|
| 210 |
+
" model_answer, agent_state = run_medrax(\n",
|
| 211 |
+
" agent=agent,\n",
|
| 212 |
+
" thread=thread,\n",
|
| 213 |
+
" prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
|
| 214 |
+
" )\n",
|
| 215 |
+
" duration = time.time() - start_time\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" log_entry = {\n",
|
| 218 |
+
" \"case_id\": case_id,\n",
|
| 219 |
+
" \"question_id\": question_id,\n",
|
| 220 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
| 221 |
+
" \"model\": model_name,\n",
|
| 222 |
+
" \"temperature\": temperature,\n",
|
| 223 |
+
" \"duration\": round(duration, 2),\n",
|
| 224 |
+
" \"usage\": \"\",\n",
|
| 225 |
+
" \"cost\": 0,\n",
|
| 226 |
+
" \"raw_response\": final_response,\n",
|
| 227 |
+
" \"model_answer\": model_answer.strip(),\n",
|
| 228 |
+
" \"correct_answer\": question_data[\"answer\"][0],\n",
|
| 229 |
+
" \"input\": {\n",
|
| 230 |
+
" \"messages\": prompt,\n",
|
| 231 |
+
" \"question_data\": {\n",
|
| 232 |
+
" \"question\": question_data[\"question\"],\n",
|
| 233 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
| 234 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
| 235 |
+
" \"figures\": question_data[\"figures\"],\n",
|
| 236 |
+
" },\n",
|
| 237 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
| 238 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
| 239 |
+
" },\n",
|
| 240 |
+
" \"agent_state\": agent_state,\n",
|
| 241 |
+
" }\n",
|
| 242 |
+
" logging.info(json.dumps(log_entry))\n",
|
| 243 |
+
" return final_response, model_answer.strip()\n",
|
| 244 |
+
"\n",
|
| 245 |
+
" except Exception as e:\n",
|
| 246 |
+
" log_entry = {\n",
|
| 247 |
+
" \"case_id\": case_id,\n",
|
| 248 |
+
" \"question_id\": question_id,\n",
|
| 249 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
| 250 |
+
" \"model\": model_name,\n",
|
| 251 |
+
" \"temperature\": temperature,\n",
|
| 252 |
+
" \"status\": \"error\",\n",
|
| 253 |
+
" \"error\": str(e),\n",
|
| 254 |
+
" \"cost\": 0,\n",
|
| 255 |
+
" \"input\": {\n",
|
| 256 |
+
" \"messages\": prompt,\n",
|
| 257 |
+
" \"question_data\": {\n",
|
| 258 |
+
" \"question\": question_data[\"question\"],\n",
|
| 259 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
| 260 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
| 261 |
+
" \"figures\": question_data[\"figures\"],\n",
|
| 262 |
+
" },\n",
|
| 263 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
| 264 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
| 265 |
+
" },\n",
|
| 266 |
+
" }\n",
|
| 267 |
+
" logging.info(json.dumps(log_entry))\n",
|
| 268 |
+
" print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
|
| 269 |
+
" return \"\", \"\"\n",
|
| 270 |
+
"\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"def load_benchmark_questions(case_id):\n",
|
| 273 |
+
" benchmark_dir = \"../benchmark/questions\"\n",
|
| 274 |
+
" return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"\n",
|
| 277 |
+
"def count_total_questions():\n",
|
| 278 |
+
" total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
|
| 279 |
+
" total_questions = sum(\n",
|
| 280 |
+
" len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
|
| 281 |
+
" for case_id in os.listdir(\"../benchmark/questions\")\n",
|
| 282 |
+
" )\n",
|
| 283 |
+
" return total_cases, total_questions\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"\n",
|
| 286 |
+
"def main(tools):\n",
|
| 287 |
+
" with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
|
| 288 |
+
" data = json.load(file)\n",
|
| 289 |
+
"\n",
|
| 290 |
+
" total_cases, total_questions = count_total_questions()\n",
|
| 291 |
+
" cases_processed = 0\n",
|
| 292 |
+
" questions_processed = 0\n",
|
| 293 |
+
" skipped_questions = 0\n",
|
| 294 |
+
"\n",
|
| 295 |
+
" print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
|
| 296 |
+
"\n",
|
| 297 |
+
" for case_id, case_details in data.items():\n",
|
| 298 |
+
" if int(case_details[\"case_id\"]) <= 17158:\n",
|
| 299 |
+
" continue\n",
|
| 300 |
+
"\n",
|
| 301 |
+
" print(f\"----------------------------------------------------------------\")\n",
|
| 302 |
+
" agent, thread = get_agent(tools)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
" question_files = load_benchmark_questions(case_id)\n",
|
| 305 |
+
" if not question_files:\n",
|
| 306 |
+
" continue\n",
|
| 307 |
+
"\n",
|
| 308 |
+
" cases_processed += 1\n",
|
| 309 |
+
" for question_file in question_files:\n",
|
| 310 |
+
" with open(question_file, \"r\") as file:\n",
|
| 311 |
+
" question_data = json.load(file)\n",
|
| 312 |
+
" question_id = os.path.basename(question_file).split(\".\")[0]\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" # agent, thread = get_agent(tools)\n",
|
| 315 |
+
" questions_processed += 1\n",
|
| 316 |
+
" final_response, model_answer = create_multimodal_request(\n",
|
| 317 |
+
" question_data, case_details, case_id, question_id, agent, thread\n",
|
| 318 |
+
" )\n",
|
| 319 |
+
"\n",
|
| 320 |
+
" # Handle cases where response is None\n",
|
| 321 |
+
" if final_response is None:\n",
|
| 322 |
+
" skipped_questions += 1\n",
|
| 323 |
+
" print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
|
| 324 |
+
" continue\n",
|
| 325 |
+
"\n",
|
| 326 |
+
" print(\n",
|
| 327 |
+
" f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
|
| 328 |
+
" )\n",
|
| 329 |
+
" print(f\"Case ID: {case_id}\")\n",
|
| 330 |
+
" print(f\"Question ID: {question_id}\")\n",
|
| 331 |
+
" print(f\"Final Response: {final_response}\")\n",
|
| 332 |
+
" print(f\"Model Answer: {model_answer}\")\n",
|
| 333 |
+
" print(f\"Correct Answer: {question_data['answer']}\")\n",
|
| 334 |
+
" print(f\"----------------------------------------------------------------\\n\")\n",
|
| 335 |
+
"\n",
|
| 336 |
+
" print(f\"\\nBenchmark Summary:\")\n",
|
| 337 |
+
" print(f\"Total Cases Processed: {cases_processed}\")\n",
|
| 338 |
+
" print(f\"Total Questions Processed: {questions_processed}\")\n",
|
| 339 |
+
" print(f\"Total Questions Skipped: {skipped_questions}\")"
|
| 340 |
+
]
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
"cell_type": "code",
|
| 344 |
+
"execution_count": null,
|
| 345 |
+
"metadata": {},
|
| 346 |
+
"outputs": [],
|
| 347 |
+
"source": [
|
| 348 |
+
"tools = get_tools()\n",
|
| 349 |
+
"main(tools)"
|
| 350 |
+
]
|
| 351 |
+
}
|
| 352 |
+
],
|
| 353 |
+
"metadata": {
|
| 354 |
+
"kernelspec": {
|
| 355 |
+
"display_name": "medmax",
|
| 356 |
+
"language": "python",
|
| 357 |
+
"name": "python3"
|
| 358 |
+
},
|
| 359 |
+
"language_info": {
|
| 360 |
+
"codemirror_mode": {
|
| 361 |
+
"name": "ipython",
|
| 362 |
+
"version": 3
|
| 363 |
+
},
|
| 364 |
+
"file_extension": ".py",
|
| 365 |
+
"mimetype": "text/x-python",
|
| 366 |
+
"name": "python",
|
| 367 |
+
"nbconvert_exporter": "python",
|
| 368 |
+
"pygments_lexer": "ipython3",
|
| 369 |
+
"version": "3.10.16"
|
| 370 |
+
}
|
| 371 |
+
},
|
| 372 |
+
"nbformat": 4,
|
| 373 |
+
"nbformat_minor": 2
|
| 374 |
+
}
|
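Note on the benchmark layout: the notebook's load_benchmark_questions and count_total_questions helpers assume one JSON file per question, grouped into per-case directories under ../benchmark/questions. A minimal sketch of that layout (the case and question IDs below are hypothetical, not taken from the repository):

benchmark/questions/
  17159/
    17159_q1.json
    17159_q2.json
  17160/
    17160_q1.json

Each question file is expected to provide at least "answer" and "figures", plus optional "metadata", since those keys are read directly by create_multimodal_request and the main loop.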
experiments/chexbench_gpt4.py
ADDED
|
@@ -0,0 +1,405 @@
|
| 1 |
+
import json
|
| 2 |
+
import openai
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import base64
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import time
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from typing import Dict, List, Optional, Union, Any
|
| 11 |
+
|
| 12 |
+
# Configuration constants
|
| 13 |
+
DEBUG_MODE = False
|
| 14 |
+
OUTPUT_DIR = "results"
|
| 15 |
+
MODEL_NAME = "gpt-4o-2024-05-13"
|
| 16 |
+
TEMPERATURE = 0.2
|
| 17 |
+
SUBSET = "Visual Question Answering"
|
| 18 |
+
|
| 19 |
+
# Set up logging configuration
|
| 20 |
+
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
|
| 21 |
+
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_mime_type(file_path: str) -> str:
|
| 26 |
+
"""
|
| 27 |
+
Determine MIME type based on file extension.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
file_path (str): Path to the file
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
str: MIME type string for the file
|
| 34 |
+
"""
|
| 35 |
+
extension = os.path.splitext(file_path)[1].lower()
|
| 36 |
+
mime_types = {
|
| 37 |
+
".png": "image/png",
|
| 38 |
+
".jpg": "image/jpeg",
|
| 39 |
+
".jpeg": "image/jpeg",
|
| 40 |
+
".gif": "image/gif",
|
| 41 |
+
}
|
| 42 |
+
return mime_types.get(extension, "application/octet-stream")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def encode_image(image_path: str) -> str:
|
| 46 |
+
"""
|
| 47 |
+
Encode image to base64 with extensive error checking.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
image_path (str): Path to the image file
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
str: Base64 encoded image string
|
| 54 |
+
|
| 55 |
+
Raises:
|
| 56 |
+
FileNotFoundError: If image file does not exist
|
| 57 |
+
ValueError: If image file is empty or too large
|
| 58 |
+
Exception: For other image processing errors
|
| 59 |
+
"""
|
| 60 |
+
logger.debug(f"Attempting to read image from: {image_path}")
|
| 61 |
+
if not os.path.exists(image_path):
|
| 62 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
| 63 |
+
|
| 64 |
+
# Add check for file size
|
| 65 |
+
file_size = os.path.getsize(image_path)
|
| 66 |
+
if file_size > 20 * 1024 * 1024: # 20MB limit
|
| 67 |
+
raise ValueError("Image file size exceeds 20MB limit")
|
| 68 |
+
if file_size == 0:
|
| 69 |
+
raise ValueError("Image file is empty")
|
| 70 |
+
logger.debug(f"Image file size: {file_size / 1024:.2f} KB")
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
from PIL import Image
|
| 74 |
+
|
| 75 |
+
# Try to open and verify the image
|
| 76 |
+
with Image.open(image_path) as img:
|
| 77 |
+
# Get image details
|
| 78 |
+
width, height = img.size
|
| 79 |
+
format = img.format
|
| 80 |
+
mode = img.mode
|
| 81 |
+
logger.debug(
|
| 82 |
+
f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if format not in ["PNG", "JPEG", "GIF"]:
|
| 86 |
+
raise ValueError(f"Unsupported image format: {format}")
|
| 87 |
+
|
| 88 |
+
with open(image_path, "rb") as image_file:
|
| 89 |
+
# Read the first few bytes to verify it's a valid PNG
|
| 90 |
+
header = image_file.read(8)
|
| 91 |
+
# if header != b'\x89PNG\r\n\x1a\n':
|
| 92 |
+
# logger.warning("File does not have a valid PNG signature")
|
| 93 |
+
|
| 94 |
+
# Reset file pointer and read entire file
|
| 95 |
+
image_file.seek(0)
|
| 96 |
+
encoded = base64.b64encode(image_file.read()).decode("utf-8")
|
| 97 |
+
encoded_length = len(encoded)
|
| 98 |
+
logger.debug(f"Base64 encoded length: {encoded_length} characters")
|
| 99 |
+
|
| 100 |
+
# Verify the encoded string is not empty and starts correctly
|
| 101 |
+
if encoded_length == 0:
|
| 102 |
+
raise ValueError("Base64 encoding produced empty string")
|
| 103 |
+
if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
|
| 104 |
+
logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
|
| 105 |
+
|
| 106 |
+
return encoded
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(f"Error reading/encoding image: {str(e)}")
|
| 109 |
+
raise
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_single_request(
|
| 113 |
+
image_path: str, question: str, options: Dict[str, str]
|
| 114 |
+
) -> List[Dict[str, Any]]:
|
| 115 |
+
"""
|
| 116 |
+
Create a single API request with image and question.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
image_path (str): Path to the image file
|
| 120 |
+
question (str): Question text
|
| 121 |
+
options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
List[Dict[str, Any]]: List of message dictionaries for the API request
|
| 125 |
+
|
| 126 |
+
Raises:
|
| 127 |
+
Exception: For errors in request creation
|
| 128 |
+
"""
|
| 129 |
+
if DEBUG_MODE:
|
| 130 |
+
logger.debug("Creating API request...")
|
| 131 |
+
|
| 132 |
+
prompt = f"""Given the following medical examination question:
|
| 133 |
+
Please answer this multiple choice question:
|
| 134 |
+
|
| 135 |
+
Question: {question}
|
| 136 |
+
|
| 137 |
+
Options:
|
| 138 |
+
A) {options['option_0']}
|
| 139 |
+
B) {options['option_1']}
|
| 140 |
+
|
| 141 |
+
Base your answer only on the provided image and select either A or B."""
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
encoded_image = encode_image(image_path)
|
| 145 |
+
mime_type = get_mime_type(image_path)
|
| 146 |
+
|
| 147 |
+
if DEBUG_MODE:
|
| 148 |
+
logger.debug(f"Image encoded with MIME type: {mime_type}")
|
| 149 |
+
|
| 150 |
+
messages = [
|
| 151 |
+
{
|
| 152 |
+
"role": "system",
|
| 153 |
+
"content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"role": "user",
|
| 157 |
+
"content": [
|
| 158 |
+
{"type": "text", "text": prompt},
|
| 159 |
+
{
|
| 160 |
+
"type": "image_url",
|
| 161 |
+
"image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
|
| 162 |
+
},
|
| 163 |
+
],
|
| 164 |
+
},
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
if DEBUG_MODE:
|
| 168 |
+
log_messages = json.loads(json.dumps(messages))
|
| 169 |
+
log_messages[1]["content"][1]["image_url"][
|
| 170 |
+
"url"
|
| 171 |
+
] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
|
| 172 |
+
logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
|
| 173 |
+
|
| 174 |
+
return messages
|
| 175 |
+
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.error(f"Error creating request: {str(e)}")
|
| 178 |
+
raise
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def check_answer(model_answer: str, correct_answer: int) -> bool:
|
| 182 |
+
"""
|
| 183 |
+
Check if the model's answer matches the correct answer.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
model_answer (str): The model's answer (A or B)
|
| 187 |
+
correct_answer (int): The correct answer index (0 for A, 1 for B)
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
bool: True if answer is correct, False otherwise
|
| 191 |
+
"""
|
| 192 |
+
if not isinstance(model_answer, str):
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
# Clean the model answer to get just the letter
|
| 196 |
+
model_letter = model_answer.strip().upper()
|
| 197 |
+
if model_letter.startswith("A"):
|
| 198 |
+
model_index = 0
|
| 199 |
+
elif model_letter.startswith("B"):
|
| 200 |
+
model_index = 1
|
| 201 |
+
else:
|
| 202 |
+
return False
|
| 203 |
+
|
| 204 |
+
return model_index == correct_answer
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
|
| 208 |
+
"""
|
| 209 |
+
Save results to a JSON file with timestamp.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
| 213 |
+
output_dir (str): Directory to save results
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
str: Path to the saved file
|
| 217 |
+
"""
|
| 218 |
+
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
| 219 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 220 |
+
output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
|
| 221 |
+
|
| 222 |
+
with open(output_file, "w") as f:
|
| 223 |
+
json.dump(results, f, indent=2)
|
| 224 |
+
|
| 225 |
+
logger.info(f"Batch results saved to {output_file}")
|
| 226 |
+
return output_file
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
|
| 230 |
+
"""
|
| 231 |
+
Calculate accuracy from results, handling error cases.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
|
| 238 |
+
"""
|
| 239 |
+
if not results:
|
| 240 |
+
return 0.0, 0, 0
|
| 241 |
+
|
| 242 |
+
total = len(results)
|
| 243 |
+
valid_results = [r for r in results if "output" in r]
|
| 244 |
+
correct = sum(
|
| 245 |
+
1 for result in valid_results if result.get("output", {}).get("is_correct", False)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
accuracy = (correct / total * 100) if total > 0 else 0
|
| 249 |
+
return accuracy, correct, total
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
|
| 253 |
+
"""
|
| 254 |
+
Calculate accuracy for the current batch.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
float: Accuracy percentage for the batch
|
| 261 |
+
"""
|
| 262 |
+
valid_results = [r for r in results if "output" in r]
|
| 263 |
+
if not valid_results:
|
| 264 |
+
return 0.0
|
| 265 |
+
return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def process_batch(
|
| 269 |
+
data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
|
| 270 |
+
) -> List[Dict[str, Any]]:
|
| 271 |
+
"""
|
| 272 |
+
Process a batch of examples and return results.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
data (List[Dict[str, Any]]): List of data items to process
|
| 276 |
+
client (openai.OpenAI): OpenAI client instance
|
| 277 |
+
start_idx (int, optional): Starting index for batch. Defaults to 0
|
| 278 |
+
batch_size (int, optional): Size of batch to process. Defaults to 50
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
List[Dict[str, Any]]: List of processed results
|
| 282 |
+
"""
|
| 283 |
+
batch_results = []
|
| 284 |
+
end_idx = min(start_idx + batch_size, len(data))
|
| 285 |
+
|
| 286 |
+
pbar = tqdm(
|
| 287 |
+
range(start_idx, end_idx),
|
| 288 |
+
desc=f"Processing batch {start_idx//batch_size + 1}",
|
| 289 |
+
unit="example",
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
for index in pbar:
|
| 293 |
+
vqa_item = data[index]
|
| 294 |
+
options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
messages = create_single_request(
|
| 298 |
+
image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
response = client.chat.completions.create(
|
| 302 |
+
model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
model_answer = response.choices[0].message.content.strip()
|
| 306 |
+
is_correct = check_answer(model_answer, vqa_item["answer"])
|
| 307 |
+
|
| 308 |
+
result = {
|
| 309 |
+
"timestamp": datetime.now().isoformat(),
|
| 310 |
+
"example_index": index,
|
| 311 |
+
"input": {
|
| 312 |
+
"question": vqa_item["question"],
|
| 313 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
| 314 |
+
"image_path": vqa_item["image_path"],
|
| 315 |
+
},
|
| 316 |
+
"output": {
|
| 317 |
+
"model_answer": model_answer,
|
| 318 |
+
"correct_answer": "A" if vqa_item["answer"] == 0 else "B",
|
| 319 |
+
"is_correct": is_correct,
|
| 320 |
+
"usage": {
|
| 321 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
| 322 |
+
"completion_tokens": response.usage.completion_tokens,
|
| 323 |
+
"total_tokens": response.usage.total_tokens,
|
| 324 |
+
},
|
| 325 |
+
},
|
| 326 |
+
}
|
| 327 |
+
batch_results.append(result)
|
| 328 |
+
|
| 329 |
+
# Update progress bar with current accuracy
|
| 330 |
+
current_accuracy = calculate_batch_accuracy(batch_results)
|
| 331 |
+
pbar.set_description(
|
| 332 |
+
f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
|
| 333 |
+
f"({len(batch_results)}/{index-start_idx+1} examples)"
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
except Exception as e:
|
| 337 |
+
error_result = {
|
| 338 |
+
"timestamp": datetime.now().isoformat(),
|
| 339 |
+
"example_index": index,
|
| 340 |
+
"error": str(e),
|
| 341 |
+
"input": {
|
| 342 |
+
"question": vqa_item["question"],
|
| 343 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
| 344 |
+
"image_path": vqa_item["image_path"],
|
| 345 |
+
},
|
| 346 |
+
}
|
| 347 |
+
batch_results.append(error_result)
|
| 348 |
+
if DEBUG_MODE:
|
| 349 |
+
pbar.write(f"Error processing example {index}: {str(e)}")
|
| 350 |
+
|
| 351 |
+
time.sleep(1) # Rate limiting
|
| 352 |
+
|
| 353 |
+
return batch_results
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def main() -> None:
|
| 357 |
+
"""
|
| 358 |
+
Main function to process the entire dataset.
|
| 359 |
+
|
| 360 |
+
Raises:
|
| 361 |
+
ValueError: If OPENAI_API_KEY is not set
|
| 362 |
+
Exception: For other processing errors
|
| 363 |
+
"""
|
| 364 |
+
logger.info("Starting full dataset processing...")
|
| 365 |
+
json_path = "../data/chexbench_updated.json"
|
| 366 |
+
|
| 367 |
+
try:
|
| 368 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
| 369 |
+
if not api_key:
|
| 370 |
+
raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
| 371 |
+
client = openai.OpenAI(api_key=api_key)
|
| 372 |
+
|
| 373 |
+
with open(json_path, "r") as f:
|
| 374 |
+
data = json.load(f)
|
| 375 |
+
|
| 376 |
+
subset_data = data[SUBSET]
|
| 377 |
+
total_examples = len(subset_data)
|
| 378 |
+
logger.info(f"Found {total_examples} examples in {SUBSET} subset")
|
| 379 |
+
|
| 380 |
+
all_results = []
|
| 381 |
+
batch_size = 50 # Process in batches of 50 examples
|
| 382 |
+
|
| 383 |
+
# Process all examples in batches
|
| 384 |
+
for start_idx in range(0, total_examples, batch_size):
|
| 385 |
+
batch_results = process_batch(subset_data, client, start_idx, batch_size)
|
| 386 |
+
all_results.extend(batch_results)
|
| 387 |
+
|
| 388 |
+
# Save intermediate results after each batch
|
| 389 |
+
output_file = save_results_to_json(all_results, OUTPUT_DIR)
|
| 390 |
+
|
| 391 |
+
# Calculate and log overall progress
|
| 392 |
+
overall_accuracy, correct, total = calculate_accuracy(all_results)
|
| 393 |
+
logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
|
| 394 |
+
logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")
|
| 395 |
+
|
| 396 |
+
logger.info("Processing completed!")
|
| 397 |
+
logger.info(f"Final results saved to: {output_file}")
|
| 398 |
+
|
| 399 |
+
except Exception as e:
|
| 400 |
+
logger.error(f"Fatal error: {str(e)}")
|
| 401 |
+
raise
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
if __name__ == "__main__":
|
| 405 |
+
main()
|
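Usage note for chexbench_gpt4.py: the script expects the OPENAI_API_KEY environment variable, reads ../data/chexbench_updated.json, and writes timestamped results under results/. A minimal invocation sketch, assuming it is run from the experiments directory so the relative paths resolve (the key below is a placeholder):

export OPENAI_API_KEY=<your-key>
cd experiments
python chexbench_gpt4.py

Results are saved to results/batch_results_<timestamp>.json after every batch of 50 examples, with running accuracy logged to the console.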
experiments/compare_runs.py
ADDED
|
@@ -0,0 +1,290 @@
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
from typing import List, Dict, Any, Tuple
|
| 5 |
+
import re
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
# Define category order
|
| 9 |
+
CATEGORY_ORDER = [
|
| 10 |
+
"detection",
|
| 11 |
+
"classification",
|
| 12 |
+
"localization",
|
| 13 |
+
"comparison",
|
| 14 |
+
"relationship",
|
| 15 |
+
"diagnosis",
|
| 16 |
+
"characterization",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def extract_letter_answer(answer: str) -> str:
|
| 21 |
+
"""Extract just the letter answer from various answer formats.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
answer: The answer string to extract a letter from
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
str: The extracted letter in uppercase, or empty string if no letter found
|
| 28 |
+
"""
|
| 29 |
+
if not answer:
|
| 30 |
+
return ""
|
| 31 |
+
|
| 32 |
+
# Convert to string and clean
|
| 33 |
+
answer = str(answer).strip()
|
| 34 |
+
|
| 35 |
+
# If it's just a single letter A-F, return it
|
| 36 |
+
if len(answer) == 1 and answer.upper() in "ABCDEF":
|
| 37 |
+
return answer.upper()
|
| 38 |
+
|
| 39 |
+
# Try to match patterns like "A)", "A.", "A ", etc.
|
| 40 |
+
match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
|
| 41 |
+
if match:
|
| 42 |
+
return match.group(1).upper()
|
| 43 |
+
|
| 44 |
+
# Try to find any standalone A-F letters preceded by space or start of string
|
| 45 |
+
# and followed by space, period, parenthesis or end of string
|
| 46 |
+
matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
|
| 47 |
+
if matches:
|
| 48 |
+
return matches[0].upper()
|
| 49 |
+
|
| 50 |
+
# Last resort: just find any A-F letter
|
| 51 |
+
letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
|
| 52 |
+
if letters:
|
| 53 |
+
return letters[0].upper()
|
| 54 |
+
|
| 55 |
+
# If no letter found, return original (cleaned)
|
| 56 |
+
return answer.strip().upper()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
|
| 60 |
+
"""Parse JSON Lines file and extract valid predictions.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
file_path: Path to the JSON Lines file to parse
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Tuple containing:
|
| 67 |
+
- str: Model name or file path if model name not found
|
| 68 |
+
- List[Dict[str, Any]]: List of valid prediction entries
|
| 69 |
+
"""
|
| 70 |
+
valid_predictions = []
|
| 71 |
+
model_name = None
|
| 72 |
+
|
| 73 |
+
# First try to parse as LLaVA format
|
| 74 |
+
try:
|
| 75 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 76 |
+
data = json.load(f)
|
| 77 |
+
if data.get("model") == "llava-med-v1.5-mistral-7b":
|
| 78 |
+
model_name = data["model"]
|
| 79 |
+
for result in data.get("results", []):
|
| 80 |
+
if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
|
| 81 |
+
# Extract answer with priority: model_answer > validated_answer > raw_output
|
| 82 |
+
model_answer = (
|
| 83 |
+
result.get("model_answer")
|
| 84 |
+
or result.get("validated_answer")
|
| 85 |
+
or result.get("raw_output", "")
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Add default categories for LLaVA results
|
| 89 |
+
prediction = {
|
| 90 |
+
"case_id": result["case_id"],
|
| 91 |
+
"question_id": result["question_id"],
|
| 92 |
+
"model_answer": model_answer,
|
| 93 |
+
"correct_answer": result["correct_answer"],
|
| 94 |
+
"input": {
|
| 95 |
+
"question_data": {
|
| 96 |
+
"metadata": {
|
| 97 |
+
"categories": [
|
| 98 |
+
"detection",
|
| 99 |
+
"classification",
|
| 100 |
+
"localization",
|
| 101 |
+
"comparison",
|
| 102 |
+
"relationship",
|
| 103 |
+
"diagnosis",
|
| 104 |
+
"characterization",
|
| 105 |
+
]
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
},
|
| 109 |
+
}
|
| 110 |
+
valid_predictions.append(prediction)
|
| 111 |
+
return model_name, valid_predictions
|
| 112 |
+
except (json.JSONDecodeError, KeyError):
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
# If not LLaVA format, process as original format
|
| 116 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 117 |
+
for line in f:
|
| 118 |
+
if line.startswith("HTTP Request:"):
|
| 119 |
+
continue
|
| 120 |
+
try:
|
| 121 |
+
data = json.loads(line.strip())
|
| 122 |
+
if "model" in data:
|
| 123 |
+
model_name = data["model"]
|
| 124 |
+
if all(
|
| 125 |
+
k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
|
| 126 |
+
):
|
| 127 |
+
valid_predictions.append(data)
|
| 128 |
+
except json.JSONDecodeError:
|
| 129 |
+
continue
|
| 130 |
+
|
| 131 |
+
return model_name if model_name else file_path, valid_predictions
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def filter_common_questions(
|
| 135 |
+
predictions_list: List[List[Dict[str, Any]]]
|
| 136 |
+
) -> List[List[Dict[str, Any]]]:
|
| 137 |
+
"""Ensure only questions that exist across all models are evaluated.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
predictions_list: List of prediction lists from different models
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
|
| 144 |
+
"""
|
| 145 |
+
question_sets = [
|
| 146 |
+
set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
|
| 147 |
+
]
|
| 148 |
+
common_questions = set.intersection(*question_sets)
|
| 149 |
+
|
| 150 |
+
return [
|
| 151 |
+
[p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
|
| 152 |
+
for preds in predictions_list
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def calculate_accuracy(
|
| 157 |
+
predictions: List[Dict[str, Any]]
|
| 158 |
+
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
|
| 159 |
+
"""Compute overall and category-level accuracy.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
predictions: List of prediction entries to analyze
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Tuple containing:
|
| 166 |
+
- float: Overall accuracy percentage
|
| 167 |
+
- int: Number of correct predictions
|
| 168 |
+
- int: Total number of predictions
|
| 169 |
+
- Dict[str, Dict[str, float]]: Category-level accuracy statistics
|
| 170 |
+
"""
|
| 171 |
+
if not predictions:
|
| 172 |
+
return 0.0, 0, 0, {}
|
| 173 |
+
|
| 174 |
+
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
|
| 175 |
+
correct = 0
|
| 176 |
+
total = 0
|
| 177 |
+
sample_size = min(5, len(predictions))
|
| 178 |
+
sampled_indices = random.sample(range(len(predictions)), sample_size)
|
| 179 |
+
|
| 180 |
+
print("\nSample extracted answers:")
|
| 181 |
+
for i in sampled_indices:
|
| 182 |
+
pred = predictions[i]
|
| 183 |
+
model_ans = extract_letter_answer(pred["model_answer"])
|
| 184 |
+
correct_ans = extract_letter_answer(pred["correct_answer"])
|
| 185 |
+
print(f"QID: {pred['question_id']}")
|
| 186 |
+
print(f" Raw Model Answer: {pred['model_answer']}")
|
| 187 |
+
print(f" Extracted Model Answer: {model_ans}")
|
| 188 |
+
print(f" Raw Correct Answer: {pred['correct_answer']}")
|
| 189 |
+
print(f" Extracted Correct Answer: {correct_ans}")
|
| 190 |
+
print("-" * 80)
|
| 191 |
+
|
| 192 |
+
for pred in predictions:
|
| 193 |
+
try:
|
| 194 |
+
model_ans = extract_letter_answer(pred["model_answer"])
|
| 195 |
+
correct_ans = extract_letter_answer(pred["correct_answer"])
|
| 196 |
+
categories = (
|
| 197 |
+
pred.get("input", {})
|
| 198 |
+
.get("question_data", {})
|
| 199 |
+
.get("metadata", {})
|
| 200 |
+
.get("categories", [])
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if model_ans and correct_ans:
|
| 204 |
+
total += 1
|
| 205 |
+
is_correct = model_ans == correct_ans
|
| 206 |
+
if is_correct:
|
| 207 |
+
correct += 1
|
| 208 |
+
|
| 209 |
+
for category in categories:
|
| 210 |
+
category_performance[category]["total"] += 1
|
| 211 |
+
if is_correct:
|
| 212 |
+
category_performance[category]["correct"] += 1
|
| 213 |
+
|
| 214 |
+
except KeyError:
|
| 215 |
+
continue
|
| 216 |
+
|
| 217 |
+
category_accuracies = {
|
| 218 |
+
category: {
|
| 219 |
+
"accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
|
| 220 |
+
"total": stats["total"],
|
| 221 |
+
"correct": stats["correct"],
|
| 222 |
+
}
|
| 223 |
+
for category, stats in category_performance.items()
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def compare_models(file_paths: List[str]) -> None:
|
| 230 |
+
"""Compare accuracy between multiple model prediction files.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
file_paths: List of paths to model prediction files to compare
|
| 234 |
+
"""
|
| 235 |
+
# Parse all files
|
| 236 |
+
parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
|
| 237 |
+
model_names, predictions_list = zip(*parsed_results)
|
| 238 |
+
|
| 239 |
+
# Get initial stats
|
| 240 |
+
print(f"\n📊 **Initial Accuracy**:")
|
| 241 |
+
results = []
|
| 242 |
+
category_results = []
|
| 243 |
+
|
| 244 |
+
for preds, name in zip(predictions_list, model_names):
|
| 245 |
+
acc, correct, total, category_acc = calculate_accuracy(preds)
|
| 246 |
+
results.append((acc, correct, total, name))
|
| 247 |
+
category_results.append(category_acc)
|
| 248 |
+
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
|
| 249 |
+
|
| 250 |
+
# Get common questions across all models
|
| 251 |
+
filtered_predictions = filter_common_questions(predictions_list)
|
| 252 |
+
print(
|
| 253 |
+
f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Compute accuracy on common questions
|
| 257 |
+
print(f"\n📊 **Accuracy on Common Questions**:")
|
| 258 |
+
filtered_results = []
|
| 259 |
+
filtered_category_results = []
|
| 260 |
+
|
| 261 |
+
for preds, name in zip(filtered_predictions, model_names):
|
| 262 |
+
acc, correct, total, category_acc = calculate_accuracy(preds)
|
| 263 |
+
filtered_results.append((acc, correct, total, name))
|
| 264 |
+
filtered_category_results.append(category_acc)
|
| 265 |
+
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
|
| 266 |
+
|
| 267 |
+
# Print category-wise accuracy
|
| 268 |
+
print("\nCategory Performance (Common Questions):")
|
| 269 |
+
for category in CATEGORY_ORDER:
|
| 270 |
+
print(f"\n{category.capitalize()}:")
|
| 271 |
+
for model_name, category_acc in zip(model_names, filtered_category_results):
|
| 272 |
+
stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
|
| 273 |
+
print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def main():
|
| 277 |
+
parser = argparse.ArgumentParser(
|
| 278 |
+
description="Compare accuracy across multiple model prediction files"
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument("files", nargs="+", help="Paths to model prediction files")
|
| 281 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
|
| 282 |
+
|
| 283 |
+
args = parser.parse_args()
|
| 284 |
+
random.seed(args.seed)
|
| 285 |
+
|
| 286 |
+
compare_models(args.files)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|
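Usage note for compare_runs.py: it accepts any number of prediction log files and reports overall and per-category accuracy, first on each file's own questions and then restricted to the questions common to all runs. A sketch (the result file names are hypothetical):

python experiments/compare_runs.py results/gpt4o_log.json results/llava_log.json --seed 42

The --seed flag only controls which five sample answers are printed for inspection; it does not affect the accuracy numbers.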
experiments/inspect_logs.py
ADDED
|
@@ -0,0 +1,210 @@
|
| 1 |
+
from typing import Optional, List
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import glob
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_latest_log() -> str:
|
| 10 |
+
"""Find the most recently modified log file in the current directory.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
str: Path to the most recently modified log file
|
| 14 |
+
|
| 15 |
+
Raises:
|
| 16 |
+
FileNotFoundError: If no log files are found in the current directory
|
| 17 |
+
"""
|
| 18 |
+
logs = list(Path(".").glob("api_usage_*.json"))
|
| 19 |
+
if not logs:
|
| 20 |
+
raise FileNotFoundError("No log files found in the current directory.")
|
| 21 |
+
return str(max(logs, key=lambda p: p.stat().st_mtime))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def format_cost(entry: dict) -> str:
|
| 25 |
+
"""Format cost if available, otherwise return 'N/A'
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
entry: Log entry dictionary containing cost information
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
str: Formatted cost string with $ and 4 decimal places, or 'N/A' if cost not found
|
| 32 |
+
"""
|
| 33 |
+
return f"${entry.get('cost', 'N/A'):.4f}" if "cost" in entry else "N/A"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def print_gpt4_entry(entry: dict) -> None:
|
| 37 |
+
"""Print entry for GPT-4 format
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
entry: Log entry dictionary in GPT-4 format containing model info, inputs and outputs
|
| 41 |
+
"""
|
| 42 |
+
print("\n=== Log Entry ===")
|
| 43 |
+
print(f"Model: {entry['model']}")
|
| 44 |
+
print(f"Case ID: {entry['case_id']}")
|
| 45 |
+
print(f"Question ID: {entry['question_id']}")
|
| 46 |
+
|
| 47 |
+
print("\n=== Model Input ===")
|
| 48 |
+
messages = entry["input"]["messages"]
|
| 49 |
+
print("System message:", messages[0]["content"])
|
| 50 |
+
user_content = messages[1]["content"]
|
| 51 |
+
print("\nUser prompt:", user_content[0]["text"])
|
| 52 |
+
print("\nImages provided:")
|
| 53 |
+
for content in user_content[1:]:
|
| 54 |
+
print(f" - {content['image_url']['url']}")
|
| 55 |
+
|
| 56 |
+
print("\n=== Model Output ===")
|
| 57 |
+
print(f"Answer: {entry['model_answer']}")
|
| 58 |
+
print(f"Correct: {entry['correct_answer']}")
|
| 59 |
+
|
| 60 |
+
print("\n=== Usage Stats ===")
|
| 61 |
+
print(f"Duration: {entry['duration']}s")
|
| 62 |
+
print(f"Cost: {format_cost(entry)}")
|
| 63 |
+
print(
|
| 64 |
+
f"Tokens: {entry['usage']['total_tokens']}",
|
| 65 |
+
f"(prompt: {entry['usage']['prompt_tokens']},",
|
| 66 |
+
f"completion: {entry['usage']['completion_tokens']})",
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def print_llama_entry(entry: dict) -> None:
|
| 71 |
+
"""Print entry for Llama-3.2 format
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
entry: Log entry dictionary in Llama format containing model info, inputs and outputs
|
| 75 |
+
"""
|
| 76 |
+
print("\n=== Log Entry ===")
|
| 77 |
+
print(f"Model: {entry['model']}")
|
| 78 |
+
print(f"Case ID: {entry['case_id']}")
|
| 79 |
+
print(f"Question ID: {entry['question_id']}")
|
| 80 |
+
|
| 81 |
+
print("\n=== Model Input ===")
|
| 82 |
+
print(f"Question: {entry['input']['question_data']['question']}")
|
| 83 |
+
print("\nImages provided:")
|
| 84 |
+
for url in entry["input"]["image_urls"]:
|
| 85 |
+
print(f" - {url}")
|
| 86 |
+
if entry["input"]["image_captions"]:
|
| 87 |
+
print("\nImage captions:")
|
| 88 |
+
for caption in entry["input"]["image_captions"]:
|
| 89 |
+
if caption:
|
| 90 |
+
print(f" - {caption}")
|
| 91 |
+
|
| 92 |
+
print("\n=== Model Output ===")
|
| 93 |
+
print(f"Answer: {entry['model_answer']}")
|
| 94 |
+
print(f"Correct: {entry['correct_answer']}")
|
| 95 |
+
|
| 96 |
+
print("\n=== Usage Stats ===")
|
| 97 |
+
print(f"Duration: {entry['duration']}s")
|
| 98 |
+
if "usage" in entry:
|
| 99 |
+
print(
|
| 100 |
+
f"Tokens: {entry['usage']['total_tokens']}",
|
| 101 |
+
f"(prompt: {entry['usage']['prompt_tokens']},",
|
| 102 |
+
f"completion: {entry['usage']['completion_tokens']})",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def determine_model_type(entry: dict) -> str:
|
| 107 |
+
"""Determine the model type from the entry
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
entry: Log entry dictionary containing model information
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
str: Model type - 'gpt4', 'llama', 'chexagent', 'medrax', or 'unknown'
|
| 114 |
+
"""
|
| 115 |
+
model = entry.get("model", "").lower()
|
| 116 |
+
if "gpt-4" in model:
|
| 117 |
+
return "gpt4"
|
| 118 |
+
elif "llama" in model:
|
| 119 |
+
return "llama"
|
| 120 |
+
elif "chexagent" in model:
|
| 121 |
+
return "chexagent"
|
| 122 |
+
elif "medrax" in model:
|
| 123 |
+
return "medrax"
|
| 124 |
+
else:
|
| 125 |
+
return "unknown"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def print_log_entry(
|
| 129 |
+
log_file: Optional[str] = None,
|
| 130 |
+
num_entries: Optional[int] = None,
|
| 131 |
+
model_filter: Optional[str] = None,
|
| 132 |
+
) -> None:
|
| 133 |
+
"""Print log entries from the specified log file or the latest log file.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
log_file: Path to the log file. If None, uses the latest log file.
|
| 137 |
+
num_entries: Number of entries to print. If None, prints all entries.
|
| 138 |
+
model_filter: Filter entries by model type ('gpt4' or 'llama'). If None, prints all.
|
| 139 |
+
"""
|
| 140 |
+
if log_file is None:
|
| 141 |
+
log_file = get_latest_log()
|
| 142 |
+
print(f"Using latest log file: {log_file}")
|
| 143 |
+
|
| 144 |
+
entries_printed = 0
|
| 145 |
+
total_entries = 0
|
| 146 |
+
filtered_entries = 0
|
| 147 |
+
|
| 148 |
+
with open(log_file, "r") as f:
|
| 149 |
+
for line in f:
|
| 150 |
+
if line.startswith("HTTP"):
|
| 151 |
+
continue
|
| 152 |
+
try:
|
| 153 |
+
total_entries += 1
|
| 154 |
+
entry = json.loads(line)
|
| 155 |
+
|
| 156 |
+
# Apply model filter if specified
|
| 157 |
+
model_type = determine_model_type(entry)
|
| 158 |
+
if model_filter and model_type != model_filter:
|
| 159 |
+
filtered_entries += 1
|
| 160 |
+
continue
|
| 161 |
+
|
| 162 |
+
if model_type == "gpt4":
|
| 163 |
+
print_gpt4_entry(entry)
|
| 164 |
+
elif model_type == "llama":
|
| 165 |
+
print_llama_entry(entry)
|
| 166 |
+
else:
|
| 167 |
+
print(f"Unknown model type in entry: {entry['model']}")
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
print("=" * 50)
|
| 171 |
+
entries_printed += 1
|
| 172 |
+
if num_entries and entries_printed >= num_entries:
|
| 173 |
+
break
|
| 174 |
+
|
| 175 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 176 |
+
print(f"Error processing entry: {e}")
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
print(f"\nSummary:")
|
| 180 |
+
print(f"Total entries: {total_entries}")
|
| 181 |
+
print(f"Entries printed: {entries_printed}")
|
| 182 |
+
if model_filter:
|
| 183 |
+
print(f"Entries filtered: {filtered_entries}")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def main() -> None:
|
| 187 |
+
"""Main entry point for the script"""
|
| 188 |
+
parser = argparse.ArgumentParser(
|
| 189 |
+
description="Parse and display log entries from API usage logs."
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument("-l", "--log_file", nargs="?", help="Path to the log file (optional)")
|
| 192 |
+
parser.add_argument("-n", "--num_entries", type=int, help="Number of entries to display")
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"-m",
|
| 195 |
+
"--model",
|
| 196 |
+
choices=["gpt4", "llama"],
|
| 197 |
+
default="gpt4",
|
| 198 |
+
help="Model type to display (default: gpt4)",
|
| 199 |
+
)
|
| 200 |
+
args = parser.parse_args()
|
| 201 |
+
|
| 202 |
+
try:
|
| 203 |
+
print_log_entry(args.log_file, args.num_entries, args.model)
|
| 204 |
+
except FileNotFoundError as e:
|
| 205 |
+
print(f"Error: {e}")
|
| 206 |
+
exit(1)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
main()
|
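Usage note for inspect_logs.py: with no arguments it pretty-prints GPT-4 entries from the most recently modified api_usage_*.json log in the current directory. A sketch of a filtered run (the log file name is hypothetical):

python experiments/inspect_logs.py -l api_usage_20250101.json -n 5 -m llama

This prints the first five Llama-format entries and finishes with a summary of total, printed, and filtered entries.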
experiments/validate_logs.py
ADDED
|
@@ -0,0 +1,162 @@
|
| 1 |
+
from typing import Dict, List, Tuple, Optional
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
import glob
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_latest_log() -> str:
|
| 10 |
+
"""Find the most recently modified log file in the current directory.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
str: Path to the most recently modified log file
|
| 14 |
+
|
| 15 |
+
Raises:
|
| 16 |
+
SystemExit: If no log files are found in current directory
|
| 17 |
+
"""
|
| 18 |
+
log_pattern = "api_usage_*.json"
|
| 19 |
+
logs = list(Path(".").glob(log_pattern))
|
| 20 |
+
if not logs:
|
| 21 |
+
print(f"No files matching pattern '{log_pattern}' found in current directory")
|
| 22 |
+
sys.exit(1)
|
| 23 |
+
return str(max(logs, key=lambda p: p.stat().st_mtime))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def analyze_log_file(filename: str) -> Tuple[List[Dict], List[Dict], Dict[str, List[str]]]:
|
| 27 |
+
"""Analyze a log file for entries missing images and errors.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
filename: Path to the log file to analyze
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Tuple containing:
|
| 34 |
+
- List of entries with no images
|
| 35 |
+
- List of skipped/error entries
|
| 36 |
+
- Dict of processing errors by type
|
| 37 |
+
|
| 38 |
+
Raises:
|
| 39 |
+
SystemExit: If file cannot be found or read
|
| 40 |
+
"""
|
| 41 |
+
no_images = []
|
| 42 |
+
errors = defaultdict(list)
|
| 43 |
+
skipped = []
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
with open(filename, "r") as f:
|
| 47 |
+
for line_num, line in enumerate(f, 1):
|
| 48 |
+
# Skip HTTP request logs
|
| 49 |
+
if line.startswith("HTTP Request:") or line.strip() == "":
|
| 50 |
+
continue
|
| 51 |
+
try:
|
| 52 |
+
# Try to parse the JSON line
|
| 53 |
+
if not line.strip().startswith("{"):
|
| 54 |
+
continue
|
| 55 |
+
entry = json.loads(line.strip())
|
| 56 |
+
case_id = entry.get("case_id")
|
| 57 |
+
question_id = entry.get("question_id")
|
| 58 |
+
|
| 59 |
+
# Skip if we can't identify the question
|
| 60 |
+
if not case_id or not question_id:
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
# Check for explicit skip/error status
|
| 64 |
+
if entry.get("status") in ["skipped", "error"]:
|
| 65 |
+
skipped.append(
|
| 66 |
+
{
|
| 67 |
+
"case_id": case_id,
|
| 68 |
+
"question_id": question_id,
|
| 69 |
+
"reason": entry.get("reason"),
|
| 70 |
+
"status": entry.get("status"),
|
| 71 |
+
}
|
| 72 |
+
)
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
# Check user content for images
|
| 76 |
+
messages = entry.get("input", {}).get("messages", [])
|
| 77 |
+
has_image = False
|
| 78 |
+
for msg in messages:
|
| 79 |
+
content = msg.get("content", [])
|
| 80 |
+
if isinstance(content, list):
|
| 81 |
+
for item in content:
|
| 82 |
+
if isinstance(item, dict) and item.get("type") == "image_url":
|
| 83 |
+
has_image = True
|
| 84 |
+
break
|
| 85 |
+
if not has_image:
|
| 86 |
+
no_images.append(
|
| 87 |
+
{
|
| 88 |
+
"case_id": case_id,
|
| 89 |
+
"question_id": question_id,
|
| 90 |
+
"question": entry.get("input", {})
|
| 91 |
+
.get("question_data", {})
|
| 92 |
+
.get("question", "")[:100]
|
| 93 |
+
+ "...", # First 100 chars of question
|
| 94 |
+
}
|
| 95 |
+
)
|
| 96 |
+
except json.JSONDecodeError:
|
| 97 |
+
errors["json_decode"].append(f"Line {line_num}: Invalid JSON")
|
| 98 |
+
continue
|
| 99 |
+
except Exception as e:
|
| 100 |
+
errors["other"].append(f"Line {line_num}: Error processing entry: {str(e)}")
|
| 101 |
+
except FileNotFoundError:
|
| 102 |
+
print(f"Error: Could not find log file: {filename}")
|
| 103 |
+
sys.exit(1)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"Error reading file {filename}: {str(e)}")
|
| 106 |
+
sys.exit(1)
|
| 107 |
+
|
| 108 |
+
return no_images, skipped, errors
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def print_results(
|
| 112 |
+
filename: str, no_images: List[Dict], skipped: List[Dict], errors: Dict[str, List[str]]
|
| 113 |
+
) -> None:
|
| 114 |
+
"""Print analysis results.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
filename: Name of the analyzed log file
|
| 118 |
+
no_images: List of entries with no images
|
| 119 |
+
skipped: List of skipped/error entries
|
| 120 |
+
errors: Dict of processing errors by type
|
| 121 |
+
"""
|
| 122 |
+
print(f"\nAnalyzing log file: {filename}")
|
| 123 |
+
print("\n=== Questions with No Images ===")
|
| 124 |
+
if no_images:
|
| 125 |
+
for entry in no_images:
|
| 126 |
+
print(f"\nCase ID: {entry['case_id']}")
|
| 127 |
+
print(f"Question ID: {entry['question_id']}")
|
| 128 |
+
print(f"Question Preview: {entry['question']}")
|
| 129 |
+
print(f"\nTotal questions without images: {len(no_images)}")
|
| 130 |
+
|
| 131 |
+
print("\n=== Skipped/Error Questions ===")
|
| 132 |
+
if skipped:
|
| 133 |
+
for entry in skipped:
|
| 134 |
+
print(f"\nCase ID: {entry['case_id']}")
|
| 135 |
+
print(f"Question ID: {entry['question_id']}")
|
| 136 |
+
print(f"Status: {entry['status']}")
|
| 137 |
+
print(f"Reason: {entry.get('reason', 'unknown')}")
|
| 138 |
+
print(f"\nTotal skipped/error questions: {len(skipped)}")
|
| 139 |
+
|
| 140 |
+
if errors:
|
| 141 |
+
print("\n=== Processing Errors ===")
|
| 142 |
+
for error_type, messages in errors.items():
|
| 143 |
+
if messages:
|
| 144 |
+
print(f"\n{error_type}:")
|
| 145 |
+
for msg in messages:
|
| 146 |
+
print(f" {msg}")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def main() -> None:
|
| 150 |
+
"""Main entry point for log validation script."""
|
| 151 |
+
# If a file is specified as an argument, use it; otherwise find the latest log
|
| 152 |
+
if len(sys.argv) > 1:
|
| 153 |
+
log_file = sys.argv[1]
|
| 154 |
+
else:
|
| 155 |
+
log_file = get_latest_log()
|
| 156 |
+
|
| 157 |
+
no_images, skipped, errors = analyze_log_file(log_file)
|
| 158 |
+
print_results(log_file, no_images, skipped, errors)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
main()
|
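Usage note for validate_logs.py: it complements inspect_logs.py by flagging log entries that were sent without any image, entries explicitly marked skipped or error, and malformed JSON lines.

python experiments/validate_logs.py              # analyzes the latest api_usage_*.json in the current directory
python experiments/validate_logs.py my_run.json  # or a specific log file (name hypothetical)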
interface.py
ADDED
|
@@ -0,0 +1,259 @@
|
| 1 |
+
import re
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import time
|
| 5 |
+
import shutil
|
| 6 |
+
from typing import AsyncGenerator, List, Optional, Tuple
|
| 7 |
+
from gradio import ChatMessage
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ChatInterface:
|
| 11 |
+
"""
|
| 12 |
+
A chat interface for interacting with a medical AI agent through Gradio.
|
| 13 |
+
|
| 14 |
+
Handles file uploads, message processing, and chat history management.
|
| 15 |
+
Supports both regular image files and DICOM medical imaging files.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, agent, tools_dict):
|
| 19 |
+
"""
|
| 20 |
+
Initialize the chat interface.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
agent: The medical AI agent to handle requests
|
| 24 |
+
tools_dict (dict): Dictionary of available tools for image processing
|
| 25 |
+
"""
|
| 26 |
+
self.agent = agent
|
| 27 |
+
self.tools_dict = tools_dict
|
| 28 |
+
self.upload_dir = Path("temp")
|
| 29 |
+
self.upload_dir.mkdir(exist_ok=True)
|
| 30 |
+
self.current_thread_id = None
|
| 31 |
+
# Separate storage for original and display paths
|
| 32 |
+
self.original_file_path = None # For LLM (.dcm or other)
|
| 33 |
+
self.display_file_path = None # For UI (always viewable format)
|
| 34 |
+
|
| 35 |
+
def handle_upload(self, file_path: str) -> str:
|
| 36 |
+
"""
|
| 37 |
+
Handle new file upload and set appropriate paths.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
file_path (str): Path to the uploaded file
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
str: Display path for UI, or None if no file uploaded
|
| 44 |
+
"""
|
| 45 |
+
if not file_path:
|
| 46 |
+
return None
|
| 47 |
+
|
| 48 |
+
source = Path(file_path)
|
| 49 |
+
timestamp = int(time.time())
|
| 50 |
+
|
| 51 |
+
# Save original file with proper suffix
|
| 52 |
+
suffix = source.suffix.lower()
|
| 53 |
+
saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
|
| 54 |
+
shutil.copy2(file_path, saved_path) # Use file_path directly instead of source
|
| 55 |
+
self.original_file_path = str(saved_path)
|
| 56 |
+
|
| 57 |
+
# Handle DICOM conversion for display only
|
| 58 |
+
if suffix == ".dcm":
|
| 59 |
+
output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
|
| 60 |
+
self.display_file_path = output["image_path"]
|
| 61 |
+
else:
|
| 62 |
+
self.display_file_path = str(saved_path)
|
| 63 |
+
|
| 64 |
+
return self.display_file_path
|
| 65 |
+
|
| 66 |
+
def add_message(
|
| 67 |
+
self, message: str, display_image: str, history: List[dict]
|
| 68 |
+
) -> Tuple[List[dict], gr.Textbox]:
|
| 69 |
+
"""
|
| 70 |
+
Add a new message to the chat history.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
message (str): Text message to add
|
| 74 |
+
display_image (str): Path to image being displayed
|
| 75 |
+
history (List[dict]): Current chat history
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Tuple[List[dict], gr.Textbox]: Updated history and textbox component
|
| 79 |
+
"""
|
| 80 |
+
image_path = self.original_file_path or display_image
|
| 81 |
+
if image_path is not None:
|
| 82 |
+
history.append({"role": "user", "content": {"path": image_path}})
|
| 83 |
+
if message is not None:
|
| 84 |
+
history.append({"role": "user", "content": message})
|
| 85 |
+
return history, gr.Textbox(value=message, interactive=False)
|
| 86 |
+
|
| 87 |
+
async def process_message(
|
| 88 |
+
self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
|
| 89 |
+
) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
|
| 90 |
+
"""
|
| 91 |
+
Process a message and generate responses.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
message (str): User message to process
|
| 95 |
+
display_image (Optional[str]): Path to currently displayed image
|
| 96 |
+
chat_history (List[ChatMessage]): Current chat history
|
| 97 |
+
|
| 98 |
+
Yields:
|
| 99 |
+
Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
|
| 100 |
+
"""
|
| 101 |
+
chat_history = chat_history or []
|
| 102 |
+
|
| 103 |
+
# Initialize thread if needed
|
| 104 |
+
if not self.current_thread_id:
|
| 105 |
+
self.current_thread_id = str(time.time())
|
| 106 |
+
|
| 107 |
+
messages = []
|
| 108 |
+
image_path = self.original_file_path or display_image
|
| 109 |
+
if image_path is not None:
|
| 110 |
+
messages.append({"role": "user", "content": f"path: {image_path}"})
|
| 111 |
+
if message is not None:
|
| 112 |
+
messages.append({"role": "user", "content": message})
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
for event in self.agent.workflow.stream(
|
| 116 |
+
{"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
|
| 117 |
+
):
|
| 118 |
+
if isinstance(event, dict):
|
| 119 |
+
if "process" in event:
|
| 120 |
+
content = event["process"]["messages"][-1].content
|
| 121 |
+
if content:
|
| 122 |
+
content = re.sub(r"temp/[^\s]*", "", content)
|
| 123 |
+
chat_history.append(ChatMessage(role="assistant", content=content))
|
| 124 |
+
yield chat_history, self.display_file_path, ""
|
| 125 |
+
|
| 126 |
+
elif "execute" in event:
|
| 127 |
+
for message in event["execute"]["messages"]:
|
| 128 |
+
tool_name = message.name
|
| 129 |
+
tool_result = eval(message.content)[0]
|
| 130 |
+
|
| 131 |
+
if tool_result:
|
| 132 |
+
metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
|
| 133 |
+
formatted_result = " ".join(
|
| 134 |
+
line.strip() for line in str(tool_result).splitlines()
|
| 135 |
+
).strip()
|
| 136 |
+
metadata["description"] = formatted_result
|
| 137 |
+
chat_history.append(
|
| 138 |
+
ChatMessage(
|
| 139 |
+
role="assistant",
|
| 140 |
+
content=formatted_result,
|
| 141 |
+
metadata=metadata,
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# For image_visualizer, use display path
|
| 146 |
+
if tool_name == "image_visualizer":
|
| 147 |
+
self.display_file_path = tool_result["image_path"]
|
| 148 |
+
chat_history.append(
|
| 149 |
+
ChatMessage(
|
| 150 |
+
role="assistant",
|
| 151 |
+
# content=gr.Image(value=self.display_file_path),
|
| 152 |
+
content={"path": self.display_file_path},
|
| 153 |
+
)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
yield chat_history, self.display_file_path, ""
|
| 157 |
+
|
| 158 |
+
except Exception as e:
|
| 159 |
+
chat_history.append(
|
| 160 |
+
ChatMessage(
|
| 161 |
+
role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
yield chat_history, self.display_file_path, ""
|
| 165 |
+
|
| 166 |
+
|
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    Args:
        agent: The medical AI agent to handle requests
        tools_dict (dict): Dictionary of available tools for image processing

    Returns:
        gr.Blocks: Gradio Blocks interface
    """
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Column():
            gr.Markdown(
                """
                # 🏥 MedRAX
                Medical Reasoning Agent for Chest X-ray
                """
            )

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    height=800,
                    container=True,
                    show_label=True,
                    elem_classes="chat-box",
                    type="messages",
                    label="Agent",
                    avatar_images=(
                        None,
                        "assets/medrax_logo.jpg",
                    ),
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Ask about the X-ray...",
                            container=False,
                        )

            with gr.Column(scale=3):
                image_display = gr.Image(
                    label="Image", type="filepath", height=700, container=True
                )
                with gr.Row():
                    upload_button = gr.UploadButton(
                        "📎 Upload X-Ray",
                        file_types=["image"],
                    )
                    dicom_upload = gr.UploadButton(
                        "📄 Upload DICOM",
                        file_types=["file"],
                    )
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
                    new_thread_btn = gr.Button("New Thread")

        # Event handlers
        def clear_chat():
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            interface.current_thread_id = str(time.time())
            return [], interface.display_file_path

        def handle_file_upload(file):
            return interface.handle_upload(file.name)

        chat_msg = txt.submit(
            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
        )
        bot_msg = chat_msg.then(
            interface.process_message,
            inputs=[txt, image_display, chatbot],
            outputs=[chatbot, image_display, txt],
        )
        bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])

        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)

        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(new_thread, outputs=[chatbot, image_display])

    return demo
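The `txt.submit(...).then(...)` chain above is what lets `process_message` stream partial responses: because the handler is an async generator, every `yield` refreshes the chatbot, image, and textbox outputs in place. Below is a minimal, self-contained sketch of that same streaming pattern; the `slow_echo` handler and its timing are illustrative only and are not part of MedRAX.

```python
# Minimal sketch of the streaming chat pattern used above (assumes gradio is installed).
# The handler name and behavior are hypothetical, not part of MedRAX.
import asyncio
import gradio as gr


async def slow_echo(message, history):
    history = (history or []) + [{"role": "user", "content": message}]
    partial = ""
    for token in message.split():
        partial += token + " "
        await asyncio.sleep(0.05)
        # Each yield pushes the updated transcript to the Chatbot component
        yield history + [{"role": "assistant", "content": partial.strip()}], message
    # Final yield clears the textbox once the response is complete
    yield history + [{"role": "assistant", "content": partial.strip()}], ""


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    txt = gr.Textbox(placeholder="Type a message...")
    txt.submit(slow_echo, inputs=[txt, chatbot], outputs=[chatbot, txt])

if __name__ == "__main__":
    demo.launch()
```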
main.py
ADDED
@@ -0,0 +1,63 @@
import warnings
from typing import *
from dotenv import load_dotenv
from transformers import logging

from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI

from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

warnings.filterwarnings("ignore")
logging.set_verbosity_error()
_ = load_dotenv()


def initialize_agent(prompt_file, model_dir="/model-weights", temp_dir="temp", device="cuda"):
    """Load the system prompt, build the MedRAX tool suite, and wrap them in an Agent."""
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    tools_dict = {
        "ChestXRayClassifierTool": ChestXRayClassifierTool(device=device),
        "ChestXRayReportGeneratorTool": ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "ChestXRaySegmentationTool": ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "XRayVQATool": XRayVQATool(cache_dir=model_dir, device=device),
        "ImageVisualizerTool": ImageVisualizerTool(),
        "XRayPhraseGroundingTool": XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "DicomProcessorTool": DicomProcessorTool(temp_dir=temp_dir),
    }

    checkpointer = MemorySaver()
    model = ChatOpenAI(model="gpt-4o", temperature=0.7, top_p=0.95)
    agent = Agent(
        model,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )

    print("Agent initialized")
    return agent, tools_dict


if __name__ == "__main__":
    print("Starting server...")
    agent, tools_dict = initialize_agent("medrax/docs/system_prompts.txt")
    demo = create_demo(agent, tools_dict)

    demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
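Note that `initialize_agent` defaults to GPU-backed tools with weights under `/model-weights`, and `ChatOpenAI` reads `OPENAI_API_KEY` from the `.env` file loaded by `load_dotenv()`. The sketch below shows one way the same entry points could be reused for a local, non-shared launch; the weights path and the CPU device choice are assumptions, and some of the tools above may in practice require a GPU.

```python
# Hedged usage sketch: a local launch without a public share link.
# The weights directory and CPU device are assumptions, not MedRAX defaults.
from main import initialize_agent
from interface import create_demo

agent, tools_dict = initialize_agent(
    "medrax/docs/system_prompts.txt",
    model_dir="./weights",  # hypothetical local weights directory
    temp_dir="temp",
    device="cpu",           # assumption: not every tool necessarily supports CPU inference
)
demo = create_demo(agent, tools_dict)
demo.launch(server_name="127.0.0.1", server_port=7860)
```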
medrax/__init__.py
ADDED
File without changes
medrax/agent/__init__.py
ADDED
@@ -0,0 +1 @@
from .agent import AgentState, Agent
medrax/agent/agent.py
ADDED
@@ -0,0 +1,193 @@
import json
import operator
from pathlib import Path
from dotenv import load_dotenv
from datetime import datetime
from typing import List, Dict, Any, TypedDict, Annotated, Optional

from langgraph.graph import StateGraph, END
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool

_ = load_dotenv()


class ToolCallLog(TypedDict):
    """
    A TypedDict representing a log entry for a tool call.

    Attributes:
        timestamp (str): The timestamp of when the tool call was made.
        tool_call_id (str): The unique identifier for the tool call.
        name (str): The name of the tool that was called.
        args (Any): The arguments passed to the tool.
        content (str): The content or result of the tool call.
    """

    timestamp: str
    tool_call_id: str
    name: str
    args: Any
    content: str


class AgentState(TypedDict):
    """
    A TypedDict representing the state of an agent.

    Attributes:
        messages (Annotated[List[AnyMessage], operator.add]): A list of messages
            representing the conversation history. The operator.add annotation
            indicates that new messages should be appended to this list.
    """

    messages: Annotated[List[AnyMessage], operator.add]


class Agent:
    """
    A class representing an agent that processes requests and executes tools based on
    language model responses.

    Attributes:
        model (BaseLanguageModel): The language model used for processing.
        tools (Dict[str, BaseTool]): A dictionary of available tools.
        checkpointer (Any): Manages and persists the agent's state.
        system_prompt (str): The system instructions for the agent.
        workflow (StateGraph): The compiled workflow for the agent's processing.
        log_tools (bool): Whether to log tool calls.
        log_path (Path): Path to save tool call logs.
    """

    def __init__(
        self,
        model: BaseLanguageModel,
        tools: List[BaseTool],
        checkpointer: Any = None,
        system_prompt: str = "",
        log_tools: bool = True,
        log_dir: Optional[str] = "logs",
    ):
        """
        Initialize the Agent.

        Args:
            model (BaseLanguageModel): The language model to use.
            tools (List[BaseTool]): A list of available tools.
            checkpointer (Any, optional): State persistence manager. Defaults to None.
            system_prompt (str, optional): System instructions. Defaults to "".
            log_tools (bool, optional): Whether to log tool calls. Defaults to True.
            log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
        """
        self.system_prompt = system_prompt
        self.log_tools = log_tools

        if self.log_tools:
            self.log_path = Path(log_dir or "logs")
            self.log_path.mkdir(exist_ok=True)

        # Define the agent workflow
        workflow = StateGraph(AgentState)
        workflow.add_node("process", self.process_request)
        workflow.add_node("execute", self.execute_tools)
        workflow.add_conditional_edges(
            "process", self.has_tool_calls, {True: "execute", False: END}
        )
        workflow.add_edge("execute", "process")
        workflow.set_entry_point("process")

        self.workflow = workflow.compile(checkpointer=checkpointer)
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)

    def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
        """
        Process the request using the language model.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
        """
        messages = state["messages"]
        if self.system_prompt:
            messages = [SystemMessage(content=self.system_prompt)] + messages
        response = self.model.invoke(messages)
        return {"messages": [response]}

    def has_tool_calls(self, state: AgentState) -> bool:
        """
        Check if the response contains any tool calls.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            bool: True if tool calls exist, False otherwise.
        """
        response = state["messages"][-1]
        return len(response.tool_calls) > 0

    def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
        """
        Execute tool calls from the model's response.

        Args:
            state (AgentState): The current state of the agent.

        Returns:
            Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
        """
        tool_calls = state["messages"][-1].tool_calls
        results = []

        for call in tool_calls:
            print(f"Executing tool: {call}")
            if call["name"] not in self.tools:
                print("\n....invalid tool....")
                result = "invalid tool, please retry"
            else:
                result = self.tools[call["name"]].invoke(call["args"])

            results.append(
                ToolMessage(
                    tool_call_id=call["id"],
                    name=call["name"],
                    args=call["args"],
                    content=str(result),
                )
            )

        self._save_tool_calls(results)
        print("Returning to model processing!")

        return {"messages": results}

    def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
        """
        Save tool calls to a JSON file with timestamp-based naming.

        Args:
            tool_calls (List[ToolMessage]): List of tool calls to save.
        """
        if not self.log_tools:
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = self.log_path / f"tool_calls_{timestamp}.json"

        logs: List[ToolCallLog] = []
        for call in tool_calls:
            log_entry = {
                "tool_call_id": call.tool_call_id,
                "name": call.name,
                "args": call.args,
                "content": call.content,
                "timestamp": datetime.now().isoformat(),
            }
            logs.append(log_entry)

        with open(filename, "w") as f:
            json.dump(logs, f, indent=4)
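The compiled graph above alternates between `process` (the language model) and `execute` (the tools) until a response contains no tool calls, and `MemorySaver` keys each conversation's state by `thread_id`. The sketch below shows how the workflow could be driven directly, without the Gradio interface; it assumes an `agent` built by `initialize_agent()` as in `main.py`, mirrors the same `stream()` call and event shape used in `interface.py`, and uses a sample image bundled with the repo.

```python
# Hedged sketch: invoking the compiled LangGraph workflow directly.
# Assumes `agent` was created by initialize_agent() as shown in main.py.
config = {"configurable": {"thread_id": "demo-thread"}}  # MemorySaver keys state by thread_id

events = agent.workflow.stream(
    {
        "messages": [
            {"role": "user", "content": "path: demo/chest/normal1.jpg"},
            {"role": "user", "content": "Is there any evidence of pneumonia?"},
        ]
    },
    config,
)

for event in events:
    # Each event is keyed by the node that just ran: "process" (LLM) or "execute" (tools)
    for node, update in event.items():
        print(node, update["messages"][-1])
```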