Adibvafa committed on
Commit 7393de8 · 1 Parent(s): 9a2c640

Improve style

Files changed (38)
  1. medrax/agent/agent.py +1 -4
  2. medrax/llava/conversation.py +1 -3
  3. medrax/llava/eval/eval_multimodal_chat_gpt_score.py +3 -6
  4. medrax/llava/eval/llm.py +8 -23
  5. medrax/llava/eval/model_vqa.py +2 -8
  6. medrax/llava/eval/summarize_gpt_review.py +3 -7
  7. medrax/llava/mm_utils.py +4 -14
  8. medrax/llava/model/builder.py +4 -12
  9. medrax/llava/model/language_model/llava_mistral.py +1 -3
  10. medrax/llava/model/llava_arch.py +13 -39
  11. medrax/llava/model/multimodal_encoder/builder.py +2 -8
  12. medrax/llava/model/multimodal_projector/builder.py +1 -3
  13. medrax/llava/serve/cli.py +1 -3
  14. medrax/llava/serve/controller.py +3 -6
  15. medrax/llava/serve/gradio_web_server.py +4 -12
  16. medrax/llava/serve/model_worker.py +6 -14
  17. medrax/llava/serve/test_message.py +2 -6
  18. medrax/llava/utils.py +1 -3
  19. medrax/models/model_factory.py +6 -15
  20. medrax/rag/rag.py +3 -9
  21. medrax/tools/browsing/__init__.py +3 -3
  22. medrax/tools/browsing/duckduckgo.py +12 -33
  23. medrax/tools/browsing/web_browser.py +3 -9
  24. medrax/tools/classification/__init__.py +1 -6
  25. medrax/tools/classification/arcplus.py +5 -17
  26. medrax/tools/classification/torchxrayvision.py +1 -3
  27. medrax/tools/dicom.py +1 -3
  28. medrax/tools/grounding.py +4 -13
  29. medrax/tools/rag.py +1 -1
  30. medrax/tools/report_generation.py +4 -14
  31. medrax/tools/segmentation/__init__.py +1 -7
  32. medrax/tools/segmentation/medsam2.py +70 -78
  33. medrax/tools/segmentation/segmentation.py +10 -30
  34. medrax/tools/utils.py +5 -15
  35. medrax/tools/vqa/__init__.py +4 -4
  36. medrax/tools/vqa/llava_med.py +4 -12
  37. medrax/tools/vqa/xray_vqa.py +6 -12
  38. medrax/tools/xray_generation.py +12 -23
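All 38 files follow the same pattern: statements that were wrapped across several lines are collapsed onto a single line, leaving runtime behavior unchanged. A minimal before/after sketch of that pattern, adapted from the medrax/agent/agent.py hunk shown below (the surrounding StateGraph setup is assumed unchanged):

# Before: the call is wrapped across three lines
workflow.add_conditional_edges(
    "agent", self.has_tool_calls, {True: "tools", False: END}
)

# After: the same call on a single line
workflow.add_conditional_edges("agent", self.has_tool_calls, {True: "tools", False: END})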
medrax/agent/agent.py CHANGED
@@ -62,9 +62,7 @@ class Agent:
  workflow = StateGraph(AgentState)
  workflow.add_node("agent", self.process_request)
  workflow.add_node("tools", self.tool_node)
- workflow.add_conditional_edges(
- "agent", self.has_tool_calls, {True: "tools", False: END}
- )
+ workflow.add_conditional_edges("agent", self.has_tool_calls, {True: "tools", False: END})
  workflow.add_edge("tools", "agent")
  workflow.set_entry_point("agent")

@@ -99,4 +97,3 @@ class Agent:
  """
  response = state["messages"][-1]
  return len(response.tool_calls) > 0
-
medrax/llava/conversation.py CHANGED
@@ -230,9 +230,7 @@ class Conversation:
  buffered = BytesIO()
  image.save(buffered, format="JPEG")
  img_b64_str = base64.b64encode(buffered.getvalue()).decode()
- img_str = (
- f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
- )
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
  msg = img_str + msg.replace("<image>", "").strip()
  ret.append([msg, None])
  else:
medrax/llava/eval/eval_multimodal_chat_gpt_score.py CHANGED
@@ -14,6 +14,7 @@ INSTRUCT_PROMPT = """We would like to request your feedback on the performance o
  Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
  ROLE = "Assistant"

+
  # Generate instruction for GPT-4 to score the two answers.
  def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
  return (
@@ -127,17 +128,13 @@ def main(args):

  if __name__ == "__main__":
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
- parser.add_argument(
- "--answers-file", default="", metavar="FILE", help="path to model answer file"
- )
+ parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
  parser.add_argument(
  "--question-file",
  default="data/questions/llava_med_eval_qa50_qa.jsonl",
  metavar="FILE",
  help="path to multichat questions file",
  )
- parser.add_argument(
- "--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file"
- )
+ parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
  args = parser.parse_args()
  main(args)
medrax/llava/eval/llm.py CHANGED
@@ -21,9 +21,7 @@ class LLM(abc.ABC):
  raise NotImplementedError("Subclasses should implement this!")

  @abstractmethod
- def split_input(
- self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
- ):
+ def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
  raise NotImplementedError("Subclasses should implement this!")


@@ -49,9 +47,7 @@ class GPT(LLM):
  def __init__(self, model_id):
  self.temperature = 0.0
  self.top_k = 1
- self.encoding = tiktoken.encoding_for_model(
- "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
- )
+ self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace("5", ".5"))
  self.openai_api = "default"
  self.model_id = model_id
  self.max_length = self.deployment_max_length_dict[model_id]
@@ -61,9 +57,7 @@
  azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
  )

- def gen_messages(
- self, fixed_instruction, few_shot_examples, input, input_header, output_header
- ):
+ def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
  messages = [
  {
  "role": "system",
@@ -120,18 +114,13 @@ class GPT(LLM):
  ):
  return asyncio.run(self.dispatch_openai_requests(messages_list))

- def split_input(
- self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
- ):
+ def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
  # Tokenize fixed_prompt
  fixed_token_ids = self.encoding.encode(
- fixed_instruction
- + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
+ fixed_instruction + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
  )
  # Calculate remaining token length
- remaining_token_len = math.ceil(
- (self.prompt_percent * self.max_length) - len(fixed_token_ids)
- )
+ remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))

  # Tokenize splittable_input
  split_token_ids = self.encoding.encode(splittable_input)
@@ -141,14 +130,10 @@
  split_token_ids[i : i + remaining_token_len + 10]
  for i in range(0, len(split_token_ids), remaining_token_len)
  ]
- split_input_list = [
- self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list
- ]
+ split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]

  # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
  return [
- self.gen_messages(
- fixed_instruction, few_shot_examples, split_input, input_header, output_header
- )
+ self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header)
  for split_input in split_input_list
  ]
medrax/llava/eval/model_vqa.py CHANGED
@@ -45,9 +45,7 @@ def eval_model(args):
  disable_torch_init()
  model_path = os.path.expanduser(args.model_path)
  model_name = get_model_name_from_path(model_path)
- tokenizer, model, image_processor, context_len = load_pretrained_model(
- model_path, args.model_base, model_name
- )
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

  questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
  questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
@@ -69,11 +67,7 @@
  conv.append_message(conv.roles[1], None)
  prompt = conv.get_prompt()

- input_ids = (
- tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
- .unsqueeze(0)
- .cuda()
- )
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

  image = Image.open(os.path.join(args.image_folder, image_file))
  image_tensor = process_images([image], image_processor, model.config)[0]
medrax/llava/eval/summarize_gpt_review.py CHANGED
@@ -14,8 +14,7 @@ def get_domain(x):
  def main(args):
  scores_data = util.load_file_jsonl(args.scores_file)
  predictions = [
- (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" "))
- for x in scores_data
+ (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" ")) for x in scores_data
  ]

  score_type_dict = defaultdict(lambda: defaultdict(list))
@@ -33,8 +32,7 @@ def main(args):
  result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
  result[q_type]["pred_score"] = util.get_avg(score_dict[2])
  result[q_type]["pred_relative_score"] = (
- util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])
- * 100
+ util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])]) * 100
  )
  result[q_type]["data_size"] = len(score_dict[1])

@@ -55,8 +53,6 @@ def main(args):

  if __name__ == "__main__":
  parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
- parser.add_argument(
- "--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file"
- )
+ parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
  args = parser.parse_args()
  main(args)
medrax/llava/mm_utils.py CHANGED
@@ -35,9 +35,7 @@ def process_images(images, image_processor, model_cfg):
  for image in images:
  if image_aspect_ratio == "pad":
  if image.mode == "L":
- background_color = int(
- 255 * sum(image_processor.image_mean) / len(image_processor.image_mean)
- )
+ background_color = int(255 * sum(image_processor.image_mean) / len(image_processor.image_mean))
  else:
  background_color = tuple(int(x * 255) for x in image_processor.image_mean)
  image = expand2square(image, background_color)
@@ -48,9 +46,7 @@
  return new_images


- def tokenizer_image_token(
- prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
- ):
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
  prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

  def insert_separator(X, sep):
@@ -58,11 +54,7 @@ def tokenizer_image_token(

  input_ids = []
  offset = 0
- if (
- len(prompt_chunks) > 0
- and len(prompt_chunks[0]) > 0
- and prompt_chunks[0][0] == tokenizer.bos_token_id
- ):
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
  offset = 1
  input_ids.append(prompt_chunks[0][0])

@@ -100,9 +92,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
  self.tokenizer = tokenizer
  self.start_len = input_ids.shape[1]

- def call_for_batch(
- self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
- ) -> bool:
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
  offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
  self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
  for keyword_id in self.keyword_ids:
medrax/llava/model/builder.py CHANGED
@@ -59,9 +59,7 @@ def load_pretrained_model(
  # PEFT model
  from peft import PeftModel

- tokenizer = AutoTokenizer.from_pretrained(
- model_base, use_fast=False, cache_dir=cache_dir
- )
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, cache_dir=cache_dir)
  model = AutoModelForCausalLM.from_pretrained(
  model_base,
  low_cpu_mem_usage=True,
@@ -78,9 +76,7 @@
  else:
  use_fast = False
  if "mpt" in model_name.lower():
- tokenizer = AutoTokenizer.from_pretrained(
- model_path, use_fast=True, cache_dir=cache_dir
- )
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, cache_dir=cache_dir)
  model = AutoModelForCausalLM.from_pretrained(
  model_path,
  low_cpu_mem_usage=True,
@@ -90,9 +86,7 @@
  **kwargs,
  )
  else:
- tokenizer = AutoTokenizer.from_pretrained(
- model_path, use_fast=False, cache_dir=cache_dir
- )
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, cache_dir=cache_dir)
  model = AutoModelForCausalLM.from_pretrained(
  model_path,
  low_cpu_mem_usage=True,
@@ -109,9 +103,7 @@
  if mm_use_im_patch_token:
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
  if mm_use_im_start_end:
- tokenizer.add_tokens(
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
- )
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
  model.resize_token_embeddings(len(tokenizer))

  vision_tower = model.get_vision_tower()
medrax/llava/model/language_model/llava_mistral.py CHANGED
@@ -125,9 +125,7 @@ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
  **kwargs,
  )

- def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
- ):
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
  images = kwargs.pop("images", None)
  image_sizes = kwargs.pop("image_sizes", None)
  inputs = super().prepare_inputs_for_generation(
medrax/llava/model/llava_arch.py CHANGED
@@ -104,9 +104,7 @@ class LlavaMetaModel:
  checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
  ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
  if len(ckpts) > 0:
- vision_module_weights = torch.load(
- f"{ckpts[-1]}/mm_projector.bin", map_location="cpu"
- )
+ vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location="cpu")
  model_dict = get_w(vision_module_weights, "vision_tower")
  print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
  # print keys in model_dict
@@ -170,9 +168,7 @@ class LlavaMetaForCausalLM(ABC):
  image_features = self.encode_images(images).to(self.device)

  # TODO: image start / end is not implemented here to support pretraining.
- if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
- self.config, "mm_use_im_start_end", False
- ):
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
  raise NotImplementedError

  # Let's just add dummy tensors if they do not exist,
@@ -188,21 +184,15 @@ class LlavaMetaForCausalLM(ABC):
  else:
  attention_mask = attention_mask.bool()
  if position_ids is None:
- position_ids = torch.arange(
- 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
- )
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

  if labels is None:
  labels = torch.full_like(input_ids, IGNORE_INDEX)

  input_ids = [
- cur_input_ids[cur_attention_mask]
- for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
- ]
- labels = [
- cur_labels[cur_attention_mask]
- for cur_labels, cur_attention_mask in zip(labels, attention_mask)
+ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
  ]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

  new_input_embeds = []
  new_labels = []
@@ -219,20 +209,14 @@ class LlavaMetaForCausalLM(ABC):
  continue

  image_token_indices = (
- [-1]
- + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
- + [cur_input_ids.shape[0]]
+ [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
  )
  cur_input_ids_noim = []
  cur_labels = labels[batch_idx]
  cur_labels_noim = []
  for i in range(len(image_token_indices) - 1):
- cur_input_ids_noim.append(
- cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
- )
- cur_labels_noim.append(
- cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
- )
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])

  split_sizes = [x.shape[0] for x in cur_labels_noim]
  cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
@@ -279,12 +263,8 @@ class LlavaMetaForCausalLM(ABC):
  dtype=new_labels[0].dtype,
  device=new_labels[0].device,
  )
- attention_mask = torch.zeros(
- (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
- )
- position_ids = torch.zeros(
- (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
- )
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

  for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
  cur_len = cur_new_embed.shape[0]
@@ -351,9 +331,7 @@ class LlavaMetaForCausalLM(ABC):
  self.resize_token_embeddings(len(tokenizer))

  if model_args.mm_use_im_start_end:
- num_new_tokens = tokenizer.add_tokens(
- [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
- )
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
  self.resize_token_embeddings(len(tokenizer))

  if num_new_tokens > 0:
@@ -361,9 +339,7 @@ class LlavaMetaForCausalLM(ABC):
  output_embeddings = self.get_output_embeddings().weight.data

  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True
- )
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

  input_embeddings[-num_new_tokens:] = input_embeddings_avg
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
@@ -375,9 +351,7 @@ class LlavaMetaForCausalLM(ABC):
  p.requires_grad = False

  if model_args.pretrain_mm_mlp_adapter:
- mm_projector_weights = torch.load(
- model_args.pretrain_mm_mlp_adapter, map_location="cpu"
- )
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
  embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
  assert num_new_tokens == 2
  if input_embeddings.shape == embed_tokens_weight.shape:
medrax/llava/model/multimodal_encoder/builder.py CHANGED
@@ -3,13 +3,7 @@ from .clip_encoder import CLIPVisionTower


  def build_vision_tower(vision_tower_cfg, **kwargs):
- vision_tower = getattr(
- vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
- )
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
  is_absolute_path_exists = os.path.exists(vision_tower)
- if (
- is_absolute_path_exists
- or vision_tower.startswith("openai")
- or vision_tower.startswith("laion")
- ):
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
  return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
medrax/llava/model/multimodal_projector/builder.py CHANGED
@@ -19,9 +19,7 @@ class SimpleResBlock(nn.Module):
  super().__init__()
  self.pre_norm = nn.LayerNorm(channels)

- self.proj = nn.Sequential(
- nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
- )
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))

  def forward(self, x):
  x = self.pre_norm(x)
medrax/llava/serve/cli.py CHANGED
@@ -94,9 +94,7 @@ def main(args):
  if image is not None:
  # first message
  if model.config.mm_use_im_start_end:
- inp = (
- DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
- )
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
  else:
  inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
  conv.append_message(conv.roles[0], inp)
medrax/llava/serve/controller.py CHANGED
@@ -2,6 +2,7 @@
  A controller manages distributed workers.
  It sends worker addresses to clients.
  """
+
  import argparse
  import dataclasses
  from enum import Enum, auto
@@ -199,9 +200,7 @@ class Controller:
  yield json.dumps(ret).encode() + b"\0"

  try:
- response = requests.post(
- worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5
- )
+ response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
  if chunk:
  yield chunk + b"\0"
@@ -240,9 +239,7 @@ app = FastAPI()
  @app.post("/register_worker")
  async def register_worker(request: Request):
  data = await request.json()
- controller.register_worker(
- data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
- )
+ controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))


  @app.post("/refresh_all_workers")
medrax/llava/serve/gradio_web_server.py CHANGED
@@ -216,9 +216,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
  for image, hash in zip(all_images, all_image_hash):
  t = datetime.datetime.now()
- filename = os.path.join(
- LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
- )
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
  if not os.path.isfile(filename):
  os.makedirs(os.path.dirname(filename), exist_ok=True)
  image.save(filename)
@@ -230,9 +228,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
  "temperature": float(temperature),
  "top_p": float(top_p),
  "max_new_tokens": min(int(max_new_tokens), 1536),
- "stop": state.sep
- if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
- else state.sep2,
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
  "images": f"List of {len(state.get_images())} images: {all_image_hash}",
  }
  logger.info(f"==== request ====\n{pload}")
@@ -330,9 +326,7 @@ block_css = """


  def build_demo(embed_mode):
- textbox = gr.Textbox(
- show_label=False, placeholder="Enter text and press ENTER", container=False
- )
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
  with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
  state = gr.State()

@@ -468,9 +462,7 @@ def build_demo(embed_mode):
  [state, chatbot] + btn_list,
  )

- clear_btn.click(
- clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False
- )
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)

  textbox.submit(
  add_text,
medrax/llava/serve/model_worker.py CHANGED
@@ -1,6 +1,7 @@
  """
  A model worker executes the model.
  """
+
  import argparse
  import asyncio
  import json
@@ -155,9 +156,7 @@ class ModelWorker:
  if images is not None and len(images) > 0 and self.is_multimodal:
  if len(images) > 0:
  if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
- raise ValueError(
- "Number of images does not match number of <image> tokens in prompt"
- )
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")

  images = [load_image_from_base64(image) for image in images]
  images = process_images(images, image_processor, model.config)
@@ -172,9 +171,7 @@
  replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

- num_image_tokens = (
- prompt.count(replace_token) * model.get_vision_tower().num_patches
- )
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
  else:
  images = None
  image_args = {"images": images}
@@ -196,19 +193,14 @@
  )
  keywords = [stop_str]
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
- streamer = TextIteratorStreamer(
- tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
- )
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

- max_new_tokens = min(
- max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
- )
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

  if max_new_tokens < 1:
  yield json.dumps(
  {
- "text": ori_prompt
- + "Exceeds max token length. Please start a new conversation, thanks.",
+ "text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
  "error_code": 0,
  }
  ).encode() + b"\0"
medrax/llava/serve/test_message.py CHANGED
@@ -17,9 +17,7 @@ def main():
  models.sort()
  print(f"Models: {models}")

- ret = requests.post(
- controller_addr + "/get_worker_address", json={"model": args.model_name}
- )
+ ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
  worker_addr = ret.json()["address"]
  print(f"worker_addr: {worker_addr}")

@@ -38,9 +36,7 @@ def main():
  "temperature": 0.7,
  "stop": conv.sep2,
  }
- response = requests.post(
- worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True
- )
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)

  print(prompt, end="")
  for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
medrax/llava/utils.py CHANGED
@@ -45,9 +45,7 @@ def build_logger(logger_name, logger_filename):
  if handler is None:
  os.makedirs(LOGDIR, exist_ok=True)
  filename = os.path.join(LOGDIR, logger_filename)
- handler = logging.handlers.TimedRotatingFileHandler(
- filename, when="D", utc=True, encoding="UTF-8"
- )
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True, encoding="UTF-8")
  handler.setFormatter(formatter)

  for name, item in logging.root.manager.loggerDict.items():
medrax/models/model_factory.py CHANGED
@@ -29,7 +29,7 @@ class ModelFactory:
  "base_url_key": "OPENAI_BASE_URL",
  },
  "gemini": {
- "class": ChatGoogleGenerativeAI,
+ "class": ChatGoogleGenerativeAI,
  "env_key": "GOOGLE_API_KEY",
  "base_url_key": "GOOGLE_BASE_URL",
  },
@@ -42,14 +42,12 @@ class ModelFactory:
  "grok": {
  "class": ChatXAI,
  "env_key": "XAI_API_KEY",
- }
+ },
  # Add more providers with default configurations here
  }

  @classmethod
- def register_provider(
- cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs
- ) -> None:
+ def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs) -> None:
  """Register a new model provider.

  Args:
@@ -61,9 +59,7 @@ class ModelFactory:
  cls._model_providers[prefix] = {"class": model_class, "env_key": env_key, **kwargs}

  @classmethod
- def create_model(
- cls, model_name: str, temperature: float = 0.7, **kwargs
- ) -> BaseLanguageModel:
+ def create_model(cls, model_name: str, temperature: float = 0.7, **kwargs) -> BaseLanguageModel:
  """Create and return an instance of the appropriate language model.

  Args:
@@ -79,9 +75,7 @@ class ModelFactory:
  ValueError: If the required API key is missing
  """
  # Find the matching provider based on model name prefix
- provider_prefix = next(
- (prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None
- )
+ provider_prefix = next((prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None)

  if not provider_prefix:
  raise ValueError(
@@ -138,7 +132,4 @@ class ModelFactory:
  Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
  """
  # Return a copy to prevent accidental modification
- return {
- k: {kk: vv for kk, vv in v.items() if kk != "class"}
- for k, v in cls._model_providers.items()
- }
+ return {k: {kk: vv for kk, vv in v.items() if kk != "class"} for k, v in cls._model_providers.items()}
medrax/rag/rag.py CHANGED
@@ -107,9 +107,7 @@ class CohereRAG:
  # Initialize Pinecone
  self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
  if not self.pinecone_api_key:
- raise ValueError(
- "PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io"
- )
+ raise ValueError("PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io")
  self.pinecone = Pinecone(api_key=self.pinecone_api_key)
  self.index_name = self.config.pinecone_index_name

@@ -161,9 +159,7 @@ class CohereRAG:
  )

  print(f"Connecting to existing Pinecone index: {self.index_name}")
- vectorstore = PineconeVectorStore.from_existing_index(
- index_name=self.index_name, embedding=self.embeddings
- )
+ vectorstore = PineconeVectorStore.from_existing_index(index_name=self.index_name, embedding=self.embeddings)

  # Check if the index is empty and needs to be populated
  try:
@@ -329,9 +325,7 @@ class CohereRAG:
  )
  documents.append(doc)

- print(
- f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}"
- )
+ print(f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}")
  return documents

  except Exception as e:
medrax/tools/browsing/__init__.py CHANGED
@@ -6,8 +6,8 @@ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, Vi
  __all__ = [
  "DuckDuckGoSearchTool",
  "WebSearchInput",
- "WebBrowserTool",
+ "WebBrowserTool",
  "WebBrowserSchema",
  "SearchQuerySchema",
- "VisitUrlSchema"
- ]
+ "VisitUrlSchema",
+ ]
medrax/tools/browsing/duckduckgo.py CHANGED
@@ -95,18 +95,12 @@ class DuckDuckGoSearchTool(BaseTool):
  super().__init__(**kwargs)

  if DDGS is None:
- logger.error(
- "duckduckgo-search package not installed. Install with: pip install duckduckgo-search"
- )
- raise ImportError(
- "duckduckgo-search package is required for web search functionality"
- )
+ logger.error("duckduckgo-search package not installed. Install with: pip install duckduckgo-search")
+ raise ImportError("duckduckgo-search package is required for web search functionality")

  logger.info("DuckDuckGo search tool initialized successfully")

- def _perform_search_sync(
- self, query: str, max_results: int = 5, region: str = "us-en"
- ) -> Dict[str, Any]:
+ def _perform_search_sync(self, query: str, max_results: int = 5, region: str = "us-en") -> Dict[str, Any]:
  """
  Perform the actual web search using DuckDuckGo synchronously.

@@ -118,9 +112,7 @@ class DuckDuckGoSearchTool(BaseTool):
  Returns:
  Dict[str, Any]: Structured search results.
  """
- logger.info(
- f"Performing web search: '{query}' (max_results={max_results}, region={region})"
- )
+ logger.info(f"Performing web search: '{query}' (max_results={max_results}, region={region})")

  try:
  # Initialize DDGS with error handling
@@ -158,9 +150,7 @@ class DuckDuckGoSearchTool(BaseTool):
  summary = f"No results found for '{query}'"

  # Log successful completion
- logger.info(
- f"Web search completed successfully: {len(formatted_results)} results"
- )
+ logger.info(f"Web search completed successfully: {len(formatted_results)} results")

  return {
  "query": query,
@@ -217,7 +207,7 @@ class DuckDuckGoSearchTool(BaseTool):

  try:
  result = self._perform_search_sync(query, max_results, region)
-
+
  # Check if search was successful
  if "error" in result:
  metadata["analysis_status"] = "failed"
@@ -239,7 +229,7 @@ class DuckDuckGoSearchTool(BaseTool):
  }
  metadata["analysis_status"] = "failed"
  metadata["error_details"] = str(e)
-
+
  return error_result, metadata

  async def _arun(
@@ -296,9 +286,7 @@ class DuckDuckGoSearchTool(BaseTool):

  # Use asyncio to run sync search in executor
  loop = asyncio.get_event_loop()
- result, metadata = await loop.run_in_executor(
- None, self._run, query, max_results, region
- )
+ result, metadata = await loop.run_in_executor(None, self._run, query, max_results, region)

  if writer:
  # Parse result to get count for progress update
@@ -333,7 +321,7 @@ class DuckDuckGoSearchTool(BaseTool):
  "search_engine": "DuckDuckGo",
  "timestamp": datetime.now().isoformat(),
  }
-
+
  metadata = {
  "query": query,
  "max_results": max_results,
@@ -344,12 +332,10 @@ class DuckDuckGoSearchTool(BaseTool):
  "analysis_status": "failed",
  "error_details": str(e),
  }
-
+
  return error_result, metadata

- def get_search_summary(
- self, query: str, max_results: int = 3
- ) -> dict[str, str | list[str]]:
+ def get_search_summary(self, query: str, max_results: int = 3) -> dict[str, str | list[str]]:
  """
  Get a quick summary of search results for a given query.

@@ -375,14 +361,7 @@ class DuckDuckGoSearchTool(BaseTool):
  results = result.get("results", [])
  titles = [r["title"] for r in results]
  urls = [r["url"] for r in results]
- snippets = [
- (
- r["snippet"][:100] + "..."
- if len(r["snippet"]) > 100
- else r["snippet"]
- )
- for r in results
- ]
+ snippets = [(r["snippet"][:100] + "..." if len(r["snippet"]) > 100 else r["snippet"]) for r in results]

  return {
  "query": query,
medrax/tools/browsing/web_browser.py CHANGED
@@ -78,9 +78,7 @@ class WebBrowserTool(BaseTool):
  max_results: int = 5
  args_schema: Type[BaseModel] = WebBrowserSchema

- def __init__(
- self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs
- ):
+ def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
  """Initialize the web browser tool with optional search API credentials.

  Args:
@@ -145,9 +143,7 @@ class WebBrowserTool(BaseTool):
  except Exception as e:
  return {"error": f"Search failed: {str(e)}"}

- def visit_url(
- self, url: str, max_content_length: int = 5000, max_links: int = 5
- ) -> Dict[str, Any]:
+ def visit_url(self, url: str, max_content_length: int = 5000, max_links: int = 5) -> Dict[str, Any]:
  """Visit a URL and extract its content with comprehensive parsing.

  Args:
@@ -218,9 +214,7 @@ class WebBrowserTool(BaseTool):
  return {
  "title": title,
  "content": (
- text_content[:max_content_length]
- if len(text_content) > max_content_length
- else text_content
+ text_content[:max_content_length] if len(text_content) > max_content_length else text_content
  ),
  "url": url,
  "links": links[:max_links],  # Limit to max_links
medrax/tools/classification/__init__.py CHANGED
@@ -3,9 +3,4 @@
  from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
  from .arcplus import ArcPlusClassifierTool, ArcPlusInput

- __all__ = [
- "TorchXRayVisionClassifierTool",
- "TorchXRayVisionInput",
- "ArcPlusClassifierTool",
- "ArcPlusInput"
- ]
+ __all__ = ["TorchXRayVisionClassifierTool", "TorchXRayVisionInput", "ArcPlusClassifierTool", "ArcPlusInput"]
medrax/tools/classification/arcplus.py CHANGED
@@ -38,9 +38,7 @@ class OmniSwinTransformer(SwinTransformer):

  self.omni_heads = []
  for num_classes in num_classes_list:
- self.omni_heads.append(
- nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
- )
+ self.omni_heads.append(nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
  self.omni_heads = nn.ModuleList(self.omni_heads)

  def forward(self, x, head_n=None):
@@ -62,9 +60,7 @@
  class ArcPlusInput(BaseModel):
  """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""

- image_path: str = Field(
- ..., description="Path to the radiology image file, only supports JPG or PNG images"
- )
+ image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")


  class ArcPlusClassifierTool(BaseTool):
@@ -249,11 +245,7 @@ class ArcPlusClassifierTool(BaseTool):

  # Remove "module." prefix if present (improved logic from example)
  if any([True if "module." in k else False for k in state_dict.keys()]):
- state_dict = {
- k.replace("module.", ""): v
- for k, v in state_dict.items()
- if k.startswith("module.")
- }
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if k.startswith("module.")}

  # Load the model weights
  msg = self.model.load_state_dict(state_dict, strict=False)
@@ -333,14 +325,10 @@ class ArcPlusClassifierTool(BaseTool):

  # Map predictions to disease names
  if len(predictions) != len(self.disease_list):
- print(
- f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
- )
+ print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
  # Pad or truncate as needed
  if len(predictions) < len(self.disease_list):
- predictions = np.pad(
- predictions, (0, len(self.disease_list) - len(predictions))
- )
+ predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
  else:
  predictions = predictions[: len(self.disease_list)]

medrax/tools/classification/torchxrayvision.py CHANGED
@@ -16,9 +16,7 @@ from langchain_core.tools import BaseTool
16
  class TorchXRayVisionInput(BaseModel):
17
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
18
 
19
- image_path: str = Field(
20
- ..., description="Path to the radiology image file, only supports JPG or PNG images"
21
- )
22
 
23
 
24
  class TorchXRayVisionClassifierTool(BaseTool):
 
16
  class TorchXRayVisionInput(BaseModel):
17
  """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""
18
 
19
+ image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
20
 
21
 
22
  class TorchXRayVisionClassifierTool(BaseTool):
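
A minimal sketch of the single-line pydantic Field style this commit standardizes on, with a hypothetical model name and the same description string used above:

from pydantic import BaseModel, Field

class ExampleImageInput(BaseModel):
    # Required field (the Ellipsis) with its description kept on one line.
    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")

print(ExampleImageInput(image_path="chest.png").image_path)  # chest.png
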
medrax/tools/dicom.py CHANGED
@@ -14,9 +14,7 @@ class DicomProcessorInput(BaseModel):
14
  """Input schema for the DICOM Processor Tool."""
15
 
16
  dicom_path: str = Field(..., description="Path to the DICOM file")
17
- window_center: Optional[float] = Field(
18
- None, description="Window center for contrast adjustment"
19
- )
20
  window_width: Optional[float] = Field(None, description="Window width for contrast adjustment")
21
 
22
 
 
14
  """Input schema for the DICOM Processor Tool."""
15
 
16
  dicom_path: str = Field(..., description="Path to the DICOM file")
17
+ window_center: Optional[float] = Field(None, description="Window center for contrast adjustment")
 
 
18
  window_width: Optional[float] = Field(None, description="Window width for contrast adjustment")
19
 
20
 
medrax/tools/grounding.py CHANGED
@@ -89,11 +89,8 @@ class XRayPhraseGroundingTool(BaseTool):
89
  trust_remote_code=True,
90
  quantization_config=quantization_config,
91
  )
92
- self.processor = AutoProcessor.from_pretrained(
93
- model_path, cache_dir=cache_dir, trust_remote_code=True
94
- )
95
 
96
-
97
  self.model = self.model.eval()
98
 
99
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
@@ -167,12 +164,8 @@ class XRayPhraseGroundingTool(BaseTool):
167
  )
168
 
169
  prompt_length = inputs["input_ids"].shape[-1]
170
- decoded_text = self.processor.decode(
171
- output[0][prompt_length:], skip_special_tokens=True
172
- )
173
- predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
174
- decoded_text
175
- )
176
 
177
  metadata = {
178
  "image_path": image_path,
@@ -199,9 +192,7 @@ class XRayPhraseGroundingTool(BaseTool):
199
  # Convert model bboxes to list format and get original image bboxes
200
  model_bboxes = [list(bbox) for bbox in pred_bboxes]
201
  original_bboxes = [
202
- self.processor.adjust_box_for_original_image_size(
203
- bbox, width=image.size[0], height=image.size[1]
204
- )
205
  for bbox in model_bboxes
206
  ]
207
 
 
89
  trust_remote_code=True,
90
  quantization_config=quantization_config,
91
  )
92
+ self.processor = AutoProcessor.from_pretrained(model_path, cache_dir=cache_dir, trust_remote_code=True)
 
 
93
 
 
94
  self.model = self.model.eval()
95
 
96
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
 
164
  )
165
 
166
  prompt_length = inputs["input_ids"].shape[-1]
167
+ decoded_text = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)
168
+ predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
 
 
 
 
169
 
170
  metadata = {
171
  "image_path": image_path,
 
192
  # Convert model bboxes to list format and get original image bboxes
193
  model_bboxes = [list(bbox) for bbox in pred_bboxes]
194
  original_bboxes = [
195
+ self.processor.adjust_box_for_original_image_size(bbox, width=image.size[0], height=image.size[1])
 
 
196
  for bbox in model_bboxes
197
  ]
198
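
As an aside, a tiny sketch of the "decode only the newly generated tokens" pattern behind the reflowed processor.decode call, with plain lists standing in for tensors and made-up token ids:

prompt_tokens = [101, 7592, 2088]                 # stand-in ids for the tokenized prompt
model_output = prompt_tokens + [2023, 2003, 102]  # generation echoes the prompt, then continues
prompt_length = len(prompt_tokens)
new_tokens = model_output[prompt_length:]         # only this slice is passed to the decoder
print(new_tokens)  # [2023, 2003, 102]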
 
medrax/tools/rag.py CHANGED
@@ -14,7 +14,7 @@ class RAGTool(BaseTool):
14
 
15
  The knowledge base includes:
16
  - Medical textbooks and reference materials
17
- - Research papers and clinical studies
18
  - Medical manuals and guidelines
19
  - Specialized medical literature
20
 
 
14
 
15
  The knowledge base includes:
16
  - Medical textbooks and reference materials
17
+ - Research papers and clinical studies
18
  - Medical manuals and guidelines
19
  - Specialized medical literature
20
 
medrax/tools/report_generation.py CHANGED
@@ -22,9 +22,7 @@ from transformers import (
22
  class ChestXRayInput(BaseModel):
23
  """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
24
 
25
- image_path: str = Field(
26
- ..., description="Path to the radiology image file, only supports JPG or PNG images"
27
- )
28
 
29
 
30
  class ChestXRayReportGeneratorTool(BaseTool):
@@ -170,12 +168,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
170
  """
171
  try:
172
  # Process image for both models
173
- findings_pixels = self._process_image(
174
- image_path, self.findings_processor, self.findings_model
175
- )
176
- impression_pixels = self._process_image(
177
- image_path, self.impression_processor, self.impression_model
178
- )
179
 
180
  # Generate both sections
181
  with torch.inference_mode():
@@ -187,11 +181,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
187
  )
188
 
189
  # Combine into formatted report
190
- report = (
191
- "CHEST X-RAY REPORT\n\n"
192
- f"FINDINGS:\n{findings_text}\n\n"
193
- f"IMPRESSION:\n{impression_text}"
194
- )
195
 
196
  output = {
197
  "report": report,
 
22
  class ChestXRayInput(BaseModel):
23
  """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
24
 
25
+ image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
 
 
26
 
27
 
28
  class ChestXRayReportGeneratorTool(BaseTool):
 
168
  """
169
  try:
170
  # Process image for both models
171
+ findings_pixels = self._process_image(image_path, self.findings_processor, self.findings_model)
172
+ impression_pixels = self._process_image(image_path, self.impression_processor, self.impression_model)
 
 
 
 
173
 
174
  # Generate both sections
175
  with torch.inference_mode():
 
181
  )
182
 
183
  # Combine into formatted report
184
+ report = "CHEST X-RAY REPORT\n\n" f"FINDINGS:\n{findings_text}\n\n" f"IMPRESSION:\n{impression_text}"
 
 
 
 
185
 
186
  output = {
187
  "report": report,
medrax/tools/segmentation/__init__.py CHANGED
@@ -3,10 +3,4 @@
3
  from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
4
  from .medsam2 import MedSAM2Tool, MedSAM2Input
5
 
6
- __all__ = [
7
- "ChestXRaySegmentationTool",
8
- "ChestXRaySegmentationInput",
9
- "OrganMetrics",
10
- "MedSAM2Tool",
11
- "MedSAM2Input"
12
- ]
 
3
  from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
4
  from .medsam2 import MedSAM2Tool, MedSAM2Input
5
 
6
+ __all__ = ["ChestXRaySegmentationTool", "ChestXRaySegmentationInput", "OrganMetrics", "MedSAM2Tool", "MedSAM2Input"]
 
 
 
 
 
 
medrax/tools/segmentation/medsam2.py CHANGED
@@ -26,7 +26,6 @@ from hydra import initialize_config_dir
26
  from hydra.core.global_hydra import GlobalHydra
27
 
28
 
29
-
30
  class MedSAM2Input(BaseModel):
31
  """Input schema for the MedSAM2 Tool."""
32
 
@@ -47,7 +46,7 @@ class MedSAM2Input(BaseModel):
47
 
48
  class MedSAM2Tool(BaseTool):
49
  """Advanced medical image segmentation tool using MedSAM2.
50
-
51
  This tool provides state-of-the-art medical image segmentation capabilities using
52
  the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
53
  Supports interactive prompting with boxes, points, or automatic segmentation.
@@ -92,18 +91,15 @@ class MedSAM2Tool(BaseTool):
92
  # This works around the issue with initialize_config_module in sam2
93
  if GlobalHydra.instance().is_initialized():
94
  GlobalHydra.instance().clear()
95
-
96
  config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
97
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
98
-
99
  hf_hub_download(
100
- repo_id=model_path,
101
- filename=model_file,
102
- local_dir=self.cache_dir,
103
- local_dir_use_symlinks=False
104
  )
105
 
106
- config_path = model_cfg.replace('.yaml', '')
107
  sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
108
  self.predictor = SAM2ImagePredictor(sam2_model)
109
 
@@ -114,37 +110,37 @@ class MedSAM2Tool(BaseTool):
114
  """Load and preprocess image for medical analysis."""
115
  try:
116
  # Handle different image formats
117
- if image_path.lower().endswith('.dcm'):
118
  # DICOM files - would need DICOM processor
119
  raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")
120
-
121
  # Load standard image formats
122
  image = Image.open(image_path)
123
-
124
  # For medical images, convert to grayscale first if needed, then to RGB
125
- if image.mode == 'L': # Grayscale
126
  # Convert grayscale to RGB for SAM2
127
- image = image.convert('RGB')
128
- elif image.mode != 'RGB':
129
- if image.mode == 'RGBA':
130
  # Create white background for RGBA
131
- background = Image.new('RGB', image.size, (255, 255, 255))
132
  background.paste(image, mask=image.split()[-1])
133
  image = background
134
  else:
135
- image = image.convert('RGB')
136
-
137
  # Convert to numpy array
138
  image_np = np.array(image)
139
-
140
  # Ensure image is in proper range [0, 255]
141
  if image_np.max() <= 1.0:
142
  image_np = (image_np * 255).astype(np.uint8)
143
  else:
144
  image_np = image_np.astype(np.uint8)
145
-
146
  return image_np
147
-
148
  except Exception as e:
149
  raise ValueError(f"Failed to load image {image_path}: {str(e)}")
150
 
@@ -152,55 +148,53 @@ class MedSAM2Tool(BaseTool):
152
  """Process and validate prompts."""
153
  if prompt_type == "auto":
154
  return None, None, None
155
-
156
  if prompt_coords is None:
157
  if prompt_type != "auto":
158
  raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
159
  return None, None, None
160
-
161
  if prompt_type == "box":
162
  if len(prompt_coords) != 4:
163
  raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")
164
-
165
  x1, y1, x2, y2 = prompt_coords
166
  # Validate coordinates
167
  if x1 >= x2 or y1 >= y2:
168
  raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
169
-
170
  input_box = np.array([[x1, y1, x2, y2]])
171
  return input_box, None, None
172
-
173
  elif prompt_type == "point":
174
  if len(prompt_coords) != 2:
175
  raise ValueError("Point prompt requires 2 coordinates: [x,y]")
176
-
177
  x, y = prompt_coords
178
  input_point = np.array([[x, y]])
179
  input_label = np.array([1]) # Positive point
180
  return None, input_point, input_label
181
-
182
  else:
183
  raise ValueError(f"Unknown prompt type: {prompt_type}")
184
 
185
  def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
186
  """Create visualization of segmentation results."""
187
  plt.figure(figsize=(10, 10))
188
-
189
  # Convert RGB image to grayscale for background display
190
  if len(image.shape) == 3:
191
  # Convert RGB to grayscale using standard luminance formula
192
- gray_image = 0.299 * image[:,:,0] + 0.587 * image[:,:,1] + 0.114 * image[:,:,2]
193
  else:
194
  gray_image = image
195
-
196
  # Display grayscale background
197
- plt.imshow(
198
- gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0]
199
- )
200
-
201
  # Generate color palette for multiple masks
202
  colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
203
-
204
  # Process and overlay each mask
205
  for idx, (mask, color) in enumerate(zip(masks, colors)):
206
  if mask.sum() > 0:
@@ -208,33 +202,31 @@ class MedSAM2Tool(BaseTool):
208
  mask_bool = mask.astype(bool)
209
  colored_mask = np.zeros((*mask_bool.shape, 4))
210
  colored_mask[mask_bool] = (*color[:3], 0.3) # 30% transparency like segmentation tool
211
- plt.imshow(
212
- colored_mask, extent=[0, image.shape[1], image.shape[0], 0]
213
- )
214
-
215
  # Add legend entry for each mask
216
  mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
217
  plt.plot([], [], color=color, label=mask_label, linewidth=3)
218
-
219
  # Add prompt visualization with consistent styling
220
- if prompt_info.get('box') is not None:
221
- box = prompt_info['box'][0]
222
  x1, y1, x2, y2 = box
223
- plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'g-', linewidth=2, label='Box Prompt')
224
-
225
- if prompt_info.get('point') is not None:
226
- point = prompt_info['point'][0]
227
- plt.plot(point[0], point[1], 'go', markersize=10, label='Point Prompt')
228
-
229
  plt.title("Segmentation Overlay")
230
  plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
231
  plt.axis("off")
232
-
233
  # Save visualization with higher DPI like segmentation tool
234
  viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
235
- plt.savefig(viz_path, bbox_inches='tight', dpi=300)
236
  plt.close()
237
-
238
  return str(viz_path)
239
 
240
  def _run(
@@ -249,28 +241,28 @@ class MedSAM2Tool(BaseTool):
249
  try:
250
  # Load image
251
  image = self._load_image(image_path)
252
-
253
  # Set image for predictor
254
  self.predictor.set_image(image)
255
-
256
  # Process prompts
257
- input_box, input_point, input_label = self._process_prompts(
258
- prompt_type, prompt_coords, image.shape[:2]
259
- )
260
-
261
  # Run inference
262
  if prompt_type == "auto":
263
  # For auto segmentation, try multiple approaches and select best result
264
  h, w = image.shape[:2]
265
-
266
  # Try multiple points in key areas for medical images
267
- sample_points = np.array([
268
- [w//3, h//3], # Upper left lung area
269
- [2*w//3, h//3], # Upper right lung area
270
- [w//2, 2*h//3], # Lower center area
271
- ])
 
 
272
  sample_labels = np.array([1, 1, 1]) # All positive points
273
-
274
  masks, scores, logits = self.predictor.predict(
275
  point_coords=sample_points,
276
  point_labels=sample_labels,
@@ -283,29 +275,29 @@ class MedSAM2Tool(BaseTool):
283
  box=input_box,
284
  multimask_output=True,
285
  )
286
-
287
  # Create visualization
288
  prompt_info = {
289
- 'box': input_box,
290
- 'point': input_point,
291
- 'type': prompt_type,
292
- 'scores': scores # Add scores for legend display
293
  }
294
  viz_path = self._create_visualization(image, masks, prompt_info)
295
-
296
  # Create output dictionary (main results)
297
  output = {
298
  "segmentation_image_path": viz_path,
299
- "confidence_scores": scores.tolist() if hasattr(scores, 'tolist') else list(scores),
300
  "num_masks": len(masks),
301
  "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
302
  "mask_summary": {
303
  "total_masks": len(masks),
304
  "mask_shapes": [list(mask.shape) for mask in masks],
305
- "segmented_area_pixels": [int(mask.sum()) for mask in masks]
306
  },
307
  }
308
-
309
  # Create metadata dictionary
310
  metadata = {
311
  "image_path": image_path,
@@ -317,9 +309,9 @@ class MedSAM2Tool(BaseTool):
317
  "num_masks_generated": len(masks),
318
  "analysis_status": "completed",
319
  }
320
-
321
  return output, metadata
322
-
323
  except Exception as e:
324
  error_output = {"error": str(e)}
325
  error_metadata = {
@@ -338,4 +330,4 @@ class MedSAM2Tool(BaseTool):
338
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
339
  ) -> Tuple[Dict[str, Any], Dict]:
340
  """Async version of _run."""
341
- return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
 
26
  from hydra.core.global_hydra import GlobalHydra
27
 
28
 
 
29
  class MedSAM2Input(BaseModel):
30
  """Input schema for the MedSAM2 Tool."""
31
 
 
46
 
47
  class MedSAM2Tool(BaseTool):
48
  """Advanced medical image segmentation tool using MedSAM2.
49
+
50
  This tool provides state-of-the-art medical image segmentation capabilities using
51
  the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
52
  Supports interactive prompting with boxes, points, or automatic segmentation.
 
91
  # This works around the issue with initialize_config_module in sam2
92
  if GlobalHydra.instance().is_initialized():
93
  GlobalHydra.instance().clear()
94
+
95
  config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
96
  initialize_config_dir(config_dir=str(config_dir), version_base="1.2")
97
+
98
  hf_hub_download(
99
+ repo_id=model_path, filename=model_file, local_dir=self.cache_dir, local_dir_use_symlinks=False
 
 
 
100
  )
101
 
102
+ config_path = model_cfg.replace(".yaml", "")
103
  sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
104
  self.predictor = SAM2ImagePredictor(sam2_model)
105
 
 
110
  """Load and preprocess image for medical analysis."""
111
  try:
112
  # Handle different image formats
113
+ if image_path.lower().endswith(".dcm"):
114
  # DICOM files - would need DICOM processor
115
  raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")
116
+
117
  # Load standard image formats
118
  image = Image.open(image_path)
119
+
120
  # For medical images, convert to grayscale first if needed, then to RGB
121
+ if image.mode == "L": # Grayscale
122
  # Convert grayscale to RGB for SAM2
123
+ image = image.convert("RGB")
124
+ elif image.mode != "RGB":
125
+ if image.mode == "RGBA":
126
  # Create white background for RGBA
127
+ background = Image.new("RGB", image.size, (255, 255, 255))
128
  background.paste(image, mask=image.split()[-1])
129
  image = background
130
  else:
131
+ image = image.convert("RGB")
132
+
133
  # Convert to numpy array
134
  image_np = np.array(image)
135
+
136
  # Ensure image is in proper range [0, 255]
137
  if image_np.max() <= 1.0:
138
  image_np = (image_np * 255).astype(np.uint8)
139
  else:
140
  image_np = image_np.astype(np.uint8)
141
+
142
  return image_np
143
+
144
  except Exception as e:
145
  raise ValueError(f"Failed to load image {image_path}: {str(e)}")
146
 
 
148
  """Process and validate prompts."""
149
  if prompt_type == "auto":
150
  return None, None, None
151
+
152
  if prompt_coords is None:
153
  if prompt_type != "auto":
154
  raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
155
  return None, None, None
156
+
157
  if prompt_type == "box":
158
  if len(prompt_coords) != 4:
159
  raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")
160
+
161
  x1, y1, x2, y2 = prompt_coords
162
  # Validate coordinates
163
  if x1 >= x2 or y1 >= y2:
164
  raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
165
+
166
  input_box = np.array([[x1, y1, x2, y2]])
167
  return input_box, None, None
168
+
169
  elif prompt_type == "point":
170
  if len(prompt_coords) != 2:
171
  raise ValueError("Point prompt requires 2 coordinates: [x,y]")
172
+
173
  x, y = prompt_coords
174
  input_point = np.array([[x, y]])
175
  input_label = np.array([1]) # Positive point
176
  return None, input_point, input_label
177
+
178
  else:
179
  raise ValueError(f"Unknown prompt type: {prompt_type}")
180
 
181
  def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
182
  """Create visualization of segmentation results."""
183
  plt.figure(figsize=(10, 10))
184
+
185
  # Convert RGB image to grayscale for background display
186
  if len(image.shape) == 3:
187
  # Convert RGB to grayscale using standard luminance formula
188
+ gray_image = 0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] + 0.114 * image[:, :, 2]
189
  else:
190
  gray_image = image
191
+
192
  # Display grayscale background
193
+ plt.imshow(gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0])
194
+
 
 
195
  # Generate color palette for multiple masks
196
  colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
197
+
198
  # Process and overlay each mask
199
  for idx, (mask, color) in enumerate(zip(masks, colors)):
200
  if mask.sum() > 0:
 
202
  mask_bool = mask.astype(bool)
203
  colored_mask = np.zeros((*mask_bool.shape, 4))
204
  colored_mask[mask_bool] = (*color[:3], 0.3) # 30% transparency like segmentation tool
205
+ plt.imshow(colored_mask, extent=[0, image.shape[1], image.shape[0], 0])
206
+
 
 
207
  # Add legend entry for each mask
208
  mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
209
  plt.plot([], [], color=color, label=mask_label, linewidth=3)
210
+
211
  # Add prompt visualization with consistent styling
212
+ if prompt_info.get("box") is not None:
213
+ box = prompt_info["box"][0]
214
  x1, y1, x2, y2 = box
215
+ plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], "g-", linewidth=2, label="Box Prompt")
216
+
217
+ if prompt_info.get("point") is not None:
218
+ point = prompt_info["point"][0]
219
+ plt.plot(point[0], point[1], "go", markersize=10, label="Point Prompt")
220
+
221
  plt.title("Segmentation Overlay")
222
  plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
223
  plt.axis("off")
224
+
225
  # Save visualization with higher DPI like segmentation tool
226
  viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
227
+ plt.savefig(viz_path, bbox_inches="tight", dpi=300)
228
  plt.close()
229
+
230
  return str(viz_path)
231
 
232
  def _run(
 
241
  try:
242
  # Load image
243
  image = self._load_image(image_path)
244
+
245
  # Set image for predictor
246
  self.predictor.set_image(image)
247
+
248
  # Process prompts
249
+ input_box, input_point, input_label = self._process_prompts(prompt_type, prompt_coords, image.shape[:2])
250
+
 
 
251
  # Run inference
252
  if prompt_type == "auto":
253
  # For auto segmentation, try multiple approaches and select best result
254
  h, w = image.shape[:2]
255
+
256
  # Try multiple points in key areas for medical images
257
+ sample_points = np.array(
258
+ [
259
+ [w // 3, h // 3], # Upper left lung area
260
+ [2 * w // 3, h // 3], # Upper right lung area
261
+ [w // 2, 2 * h // 3], # Lower center area
262
+ ]
263
+ )
264
  sample_labels = np.array([1, 1, 1]) # All positive points
265
+
266
  masks, scores, logits = self.predictor.predict(
267
  point_coords=sample_points,
268
  point_labels=sample_labels,
 
275
  box=input_box,
276
  multimask_output=True,
277
  )
278
+
279
  # Create visualization
280
  prompt_info = {
281
+ "box": input_box,
282
+ "point": input_point,
283
+ "type": prompt_type,
284
+ "scores": scores, # Add scores for legend display
285
  }
286
  viz_path = self._create_visualization(image, masks, prompt_info)
287
+
288
  # Create output dictionary (main results)
289
  output = {
290
  "segmentation_image_path": viz_path,
291
+ "confidence_scores": scores.tolist() if hasattr(scores, "tolist") else list(scores),
292
  "num_masks": len(masks),
293
  "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
294
  "mask_summary": {
295
  "total_masks": len(masks),
296
  "mask_shapes": [list(mask.shape) for mask in masks],
297
+ "segmented_area_pixels": [int(mask.sum()) for mask in masks],
298
  },
299
  }
300
+
301
  # Create metadata dictionary
302
  metadata = {
303
  "image_path": image_path,
 
309
  "num_masks_generated": len(masks),
310
  "analysis_status": "completed",
311
  }
312
+
313
  return output, metadata
314
+
315
  except Exception as e:
316
  error_output = {"error": str(e)}
317
  error_metadata = {
 
330
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
331
  ) -> Tuple[Dict[str, Any], Dict]:
332
  """Async version of _run."""
333
+ return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
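
For orientation, a standalone sketch of the box/point/auto prompt validation that the reformatted _process_prompts hunks implement. It mirrors the logic above with numpy only and does not call the tool itself.

import numpy as np

def process_prompts(prompt_type, prompt_coords):
    if prompt_type == "auto":
        return None, None, None
    if prompt_type == "box":
        x1, y1, x2, y2 = prompt_coords
        if x1 >= x2 or y1 >= y2:
            raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
        return np.array([[x1, y1, x2, y2]]), None, None
    if prompt_type == "point":
        x, y = prompt_coords
        return None, np.array([[x, y]]), np.array([1])  # single positive click
    raise ValueError(f"Unknown prompt type: {prompt_type}")

print(process_prompts("box", [10, 20, 100, 200])[0])  # [[ 10  20 100 200]]
print(process_prompts("point", [64, 64])[1])          # [[64 64]]
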
medrax/tools/segmentation/segmentation.py CHANGED
@@ -41,9 +41,7 @@ class OrganMetrics(BaseModel):
41
  area_pixels: int = Field(..., description="Area in pixels")
42
  area_cm2: float = Field(..., description="Approximate area in cm²")
43
  centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
44
- bbox: Tuple[int, int, int, int] = Field(
45
- ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
46
- )
47
 
48
  # Size metrics
49
  width: int = Field(..., description="Width of the organ in pixels")
@@ -51,9 +49,7 @@ class OrganMetrics(BaseModel):
51
  aspect_ratio: float = Field(..., description="Height/width ratio")
52
 
53
  # Position metrics
54
- relative_position: Dict[str, float] = Field(
55
- ..., description="Position relative to image boundaries (0-1 scale)"
56
- )
57
 
58
  # Analysis metrics
59
  mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
@@ -90,9 +86,7 @@ class ChestXRaySegmentationTool(BaseTool):
90
  self.model = self.model.to(self.device)
91
  self.model.eval()
92
 
93
- self.transform = torchvision.transforms.Compose(
94
- [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
95
- )
96
 
97
  self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
98
  self.temp_dir.mkdir(exist_ok=True)
@@ -115,9 +109,7 @@ class ChestXRaySegmentationTool(BaseTool):
115
  "Spine": 13,
116
  }
117
 
118
- def _align_mask_to_original(
119
- self, mask: np.ndarray, original_shape: Tuple[int, int]
120
- ) -> np.ndarray:
121
  """
122
  Align a mask from the transformed (cropped/resized) space back to the full original image.
123
  Assumes that the transform does a center crop to a square of side = min(original height, width)
@@ -170,23 +162,17 @@ class ChestXRaySegmentationTool(BaseTool):
170
  bbox=tuple(map(int, props.bbox)),
171
  width=int(props.bbox[3] - props.bbox[1]),
172
  height=int(props.bbox[2] - props.bbox[0]),
173
- aspect_ratio=float(
174
- (props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])
175
- ),
176
  relative_position=relative_pos,
177
  mean_intensity=float(mean_intensity),
178
  std_intensity=float(std_intensity),
179
  confidence_score=float(confidence),
180
  )
181
 
182
- def _save_visualization(
183
- self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
184
- ) -> str:
185
  """Save visualization of original image with segmentation masks overlaid."""
186
  plt.figure(figsize=(10, 10))
187
- plt.imshow(
188
- original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0]
189
- )
190
 
191
  # Generate color palette for organs
192
  colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
@@ -202,14 +188,10 @@ class ChestXRaySegmentationTool(BaseTool):
202
  # Create a colored overlay with transparency
203
  colored_mask = np.zeros((*original_img.shape, 4))
204
  colored_mask[mask > 0] = (*color[:3], 0.3)
205
- plt.imshow(
206
- colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0]
207
- )
208
 
209
  # Add legend entry for the organ
210
- organ_name = list(self.organ_map.keys())[
211
- list(self.organ_map.values()).index(organ_idx)
212
- ]
213
  plt.plot([], [], color=color, label=organ_name, linewidth=3)
214
 
215
  plt.title("Segmentation Overlay")
@@ -266,9 +248,7 @@ class ChestXRaySegmentationTool(BaseTool):
266
  for idx, organ_name in zip(organ_indices, organs):
267
  mask = pred_masks[0, idx].cpu().numpy()
268
  if mask.sum() > 0:
269
- metrics = self._compute_organ_metrics(
270
- mask, original_img, float(pred_probs[0, idx].mean().cpu())
271
- )
272
  if metrics:
273
  results[organ_name] = metrics
274
 
 
41
  area_pixels: int = Field(..., description="Area in pixels")
42
  area_cm2: float = Field(..., description="Approximate area in cm²")
43
  centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
44
+ bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")
 
 
45
 
46
  # Size metrics
47
  width: int = Field(..., description="Width of the organ in pixels")
 
49
  aspect_ratio: float = Field(..., description="Height/width ratio")
50
 
51
  # Position metrics
52
+ relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")
 
 
53
 
54
  # Analysis metrics
55
  mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
 
86
  self.model = self.model.to(self.device)
87
  self.model.eval()
88
 
89
+ self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])
 
 
90
 
91
  self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
92
  self.temp_dir.mkdir(exist_ok=True)
 
109
  "Spine": 13,
110
  }
111
 
112
+ def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
 
 
113
  """
114
  Align a mask from the transformed (cropped/resized) space back to the full original image.
115
  Assumes that the transform does a center crop to a square of side = min(original height, width)
 
162
  bbox=tuple(map(int, props.bbox)),
163
  width=int(props.bbox[3] - props.bbox[1]),
164
  height=int(props.bbox[2] - props.bbox[0]),
165
+ aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
 
 
166
  relative_position=relative_pos,
167
  mean_intensity=float(mean_intensity),
168
  std_intensity=float(std_intensity),
169
  confidence_score=float(confidence),
170
  )
171
 
172
+ def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
 
 
173
  """Save visualization of original image with segmentation masks overlaid."""
174
  plt.figure(figsize=(10, 10))
175
+ plt.imshow(original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
 
176
 
177
  # Generate color palette for organs
178
  colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))
 
188
  # Create a colored overlay with transparency
189
  colored_mask = np.zeros((*original_img.shape, 4))
190
  colored_mask[mask > 0] = (*color[:3], 0.3)
191
+ plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])
 
 
192
 
193
  # Add legend entry for the organ
194
+ organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
 
 
195
  plt.plot([], [], color=color, label=organ_name, linewidth=3)
196
 
197
  plt.title("Segmentation Overlay")
 
248
  for idx, organ_name in zip(organ_indices, organs):
249
  mask = pred_masks[0, idx].cpu().numpy()
250
  if mask.sum() > 0:
251
+ metrics = self._compute_organ_metrics(mask, original_img, float(pred_probs[0, idx].mean().cpu()))
 
 
252
  if metrics:
253
  results[organ_name] = metrics
254
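
A small sketch of the organ-name reverse lookup condensed above. "Spine": 13 comes from the hunk; the other entries and indices are illustrative.

organ_map = {"Left Lung": 4, "Right Lung": 5, "Heart": 8, "Spine": 13}  # indices besides Spine are made up
organ_idx = 13
organ_name = list(organ_map.keys())[list(organ_map.values()).index(organ_idx)]
print(organ_name)  # Spine

An inverted dict, {v: k for k, v in organ_map.items()}, would avoid the double list() scan, but the commit keeps the original lookup and only reflows it.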
 
medrax/tools/utils.py CHANGED
@@ -16,18 +16,10 @@ class ImageVisualizerInput(BaseModel):
16
 
17
  image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
18
  title: Optional[str] = Field(None, description="Optional title to display above the image")
19
- description: Optional[str] = Field(
20
- None, description="Optional description to display below the image"
21
- )
22
- width: Optional[int] = Field(
23
- 10, description="Optional figure width in inches"
24
- )
25
- height: Optional[int] = Field(
26
- 10, description="Optional figure height in inches"
27
- )
28
- cmap: Optional[str] = Field(
29
- "rgb", description="Optional colormap to use for displaying the image"
30
- )
31
 
32
 
33
  class ImageVisualizerTool(BaseTool):
@@ -65,9 +57,7 @@ class ImageVisualizerTool(BaseTool):
65
 
66
  # Add description if provided
67
  if description:
68
- plt.figtext(
69
- 0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
70
- )
71
 
72
  # Adjust margins to minimize whitespace while preventing overlap
73
  plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
 
16
 
17
  image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
18
  title: Optional[str] = Field(None, description="Optional title to display above the image")
19
+ description: Optional[str] = Field(None, description="Optional description to display below the image")
20
+ width: Optional[int] = Field(10, description="Optional figure width in inches")
21
+ height: Optional[int] = Field(10, description="Optional figure height in inches")
22
+ cmap: Optional[str] = Field("rgb", description="Optional colormap to use for displaying the image")
 
 
 
 
 
 
 
 
23
 
24
 
25
  class ImageVisualizerTool(BaseTool):
 
57
 
58
  # Add description if provided
59
  if description:
60
+ plt.figtext(0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10)
 
 
61
 
62
  # Adjust margins to minimize whitespace while preventing overlap
63
  plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
medrax/tools/vqa/__init__.py CHANGED
@@ -1,16 +1,16 @@
1
  """Visual Question Answering tools for medical images."""
2
 
3
  from .llava_med import LlavaMedTool, LlavaMedInput
4
- from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
5
  from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
6
  from .medgemma.medgemma_setup import setup_medgemma_env
7
 
8
  __all__ = [
9
  "LlavaMedTool",
10
  "LlavaMedInput",
11
- "CheXagentXRayVQATool",
12
  "XRayVQAToolInput",
13
  "MedGemmaAPIClientTool",
14
  "MedGemmaVQAInput",
15
- "setup_medgemma_env"
16
- ]
 
1
  """Visual Question Answering tools for medical images."""
2
 
3
  from .llava_med import LlavaMedTool, LlavaMedInput
4
+ from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
5
  from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
6
  from .medgemma.medgemma_setup import setup_medgemma_env
7
 
8
  __all__ = [
9
  "LlavaMedTool",
10
  "LlavaMedInput",
11
+ "CheXagentXRayVQATool",
12
  "XRayVQAToolInput",
13
  "MedGemmaAPIClientTool",
14
  "MedGemmaVQAInput",
15
+ "setup_medgemma_env",
16
+ ]
medrax/tools/vqa/llava_med.py CHANGED
@@ -83,13 +83,7 @@ class LlavaMedTool(BaseTool):
83
  self, question: str, image_path: Optional[str] = None
84
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
85
  if self.model.config.mm_use_im_start_end:
86
- question = (
87
- DEFAULT_IM_START_TOKEN
88
- + DEFAULT_IMAGE_TOKEN
89
- + DEFAULT_IM_END_TOKEN
90
- + "\n"
91
- + question
92
- )
93
  else:
94
  question = DEFAULT_IMAGE_TOKEN + "\n" + question
95
 
@@ -99,9 +93,7 @@ class LlavaMedTool(BaseTool):
99
  prompt = conv.get_prompt()
100
 
101
  input_ids = (
102
- tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
103
- .unsqueeze(0)
104
- .cuda()
105
  )
106
 
107
  image_tensor = None
@@ -147,11 +139,11 @@ class LlavaMedTool(BaseTool):
147
  )
148
 
149
  answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150
-
151
  output = {
152
  "answer": answer,
153
  }
154
-
155
  metadata = {
156
  "question": question,
157
  "image_path": image_path,
 
83
  self, question: str, image_path: Optional[str] = None
84
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
85
  if self.model.config.mm_use_im_start_end:
86
+ question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + question
 
 
 
 
 
 
87
  else:
88
  question = DEFAULT_IMAGE_TOKEN + "\n" + question
89
 
 
93
  prompt = conv.get_prompt()
94
 
95
  input_ids = (
96
+ tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
 
 
97
  )
98
 
99
  image_tensor = None
 
139
  )
140
 
141
  answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
142
+
143
  output = {
144
  "answer": answer,
145
  }
146
+
147
  metadata = {
148
  "question": question,
149
  "image_path": image_path,
medrax/tools/vqa/xray_vqa.py CHANGED
@@ -15,13 +15,9 @@ from langchain_core.tools import BaseTool
15
  class XRayVQAToolInput(BaseModel):
16
  """Input schema for the CheXagent Tool."""
17
 
18
- image_paths: List[str] = Field(
19
- ..., description="List of paths to chest X-ray images to analyze"
20
- )
21
  prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
22
- max_new_tokens: int = Field(
23
- 512, description="Maximum number of tokens to generate in the response"
24
- )
25
 
26
 
27
  class CheXagentXRayVQATool(BaseTool):
@@ -99,16 +95,14 @@ class CheXagentXRayVQATool(BaseTool):
99
  Returns:
100
  str: Model's response
101
  """
102
- query = self.tokenizer.from_list_format(
103
- [*[{"image": path} for path in image_paths], {"text": prompt}]
104
- )
105
  conv = [
106
  {"from": "system", "value": "You are a helpful assistant."},
107
  {"from": "human", "value": query},
108
  ]
109
- input_ids = self.tokenizer.apply_chat_template(
110
- conv, add_generation_prompt=True, return_tensors="pt"
111
- ).to(device=self.device)
112
 
113
  # Run inference
114
  with torch.inference_mode():
 
15
  class XRayVQAToolInput(BaseModel):
16
  """Input schema for the CheXagent Tool."""
17
 
18
+ image_paths: List[str] = Field(..., description="List of paths to chest X-ray images to analyze")
 
 
19
  prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
20
+ max_new_tokens: int = Field(512, description="Maximum number of tokens to generate in the response")
 
 
21
 
22
 
23
  class CheXagentXRayVQATool(BaseTool):
 
95
  Returns:
96
  str: Model's response
97
  """
98
+ query = self.tokenizer.from_list_format([*[{"image": path} for path in image_paths], {"text": prompt}])
 
 
99
  conv = [
100
  {"from": "system", "value": "You are a helpful assistant."},
101
  {"from": "human", "value": query},
102
  ]
103
+ input_ids = self.tokenizer.apply_chat_template(conv, add_generation_prompt=True, return_tensors="pt").to(
104
+ device=self.device
105
+ )
106
 
107
  # Run inference
108
  with torch.inference_mode():
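
A data-only sketch of the query list that the reflowed from_list_format call receives; no model or tokenizer is needed, and the file names are invented:

image_paths = ["frontal.png", "lateral.png"]  # placeholder paths
prompt = "Is there cardiomegaly?"
query_items = [*[{"image": path} for path in image_paths], {"text": prompt}]
print(query_items)
# [{'image': 'frontal.png'}, {'image': 'lateral.png'}, {'text': 'Is there cardiomegaly?'}]
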
medrax/tools/xray_generation.py CHANGED
@@ -11,26 +11,15 @@ from langchain_core.tools import BaseTool
11
 
12
  class ChestXRayGeneratorInput(BaseModel):
13
  """Input schema for the Chest X-Ray Generator Tool."""
14
-
15
  prompt: str = Field(
16
- ...,
17
- description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
18
- )
19
- height: int = Field(
20
- 512,
21
- description="Height of generated image in pixels"
22
- )
23
- width: int = Field(
24
- 512,
25
- description="Width of generated image in pixels"
26
- )
27
- num_inference_steps: int = Field(
28
- 75,
29
- description="Number of denoising steps (higher = better quality but slower)"
30
  )
 
 
 
31
  guidance_scale: float = Field(
32
- 4.0,
33
- description="How closely to follow the prompt (higher = more faithful but less diverse)"
34
  )
35
 
36
 
@@ -60,11 +49,11 @@ class ChestXRayGeneratorTool(BaseTool):
60
  ):
61
  """Initialize the chest X-ray generator tool."""
62
  super().__init__()
63
-
64
  self.device = torch.device(device) if device else "cuda"
65
  self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
66
  self.model = self.model.to(torch.float32).to(self.device)
67
-
68
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
69
  self.temp_dir.mkdir(exist_ok=True)
70
 
@@ -97,7 +86,7 @@ class ChestXRayGeneratorTool(BaseTool):
97
  num_inference_steps=num_inference_steps,
98
  height=height,
99
  width=width,
100
- guidance_scale=guidance_scale
101
  )
102
 
103
  # Save generated image
@@ -107,7 +96,7 @@ class ChestXRayGeneratorTool(BaseTool):
107
  output = {
108
  "image_path": str(image_path),
109
  }
110
-
111
  metadata = {
112
  "prompt": prompt,
113
  "num_inference_steps": num_inference_steps,
@@ -126,7 +115,7 @@ class ChestXRayGeneratorTool(BaseTool):
126
  "prompt": prompt,
127
  "analysis_status": "failed",
128
  "error_details": str(e),
129
- }
130
  )
131
 
132
  async def _arun(
@@ -139,4 +128,4 @@ class ChestXRayGeneratorTool(BaseTool):
139
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
140
  ) -> Tuple[Dict[str, str], Dict]:
141
  """Async version of _run."""
142
- return self._run(prompt, num_inference_steps, guidance_scale, height, width)
 
11
 
12
  class ChestXRayGeneratorInput(BaseModel):
13
  """Input schema for the Chest X-Ray Generator Tool."""
14
+
15
  prompt: str = Field(
16
+ ..., description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  )
18
+ height: int = Field(512, description="Height of generated image in pixels")
19
+ width: int = Field(512, description="Width of generated image in pixels")
20
+ num_inference_steps: int = Field(75, description="Number of denoising steps (higher = better quality but slower)")
21
  guidance_scale: float = Field(
22
+ 4.0, description="How closely to follow the prompt (higher = more faithful but less diverse)"
 
23
  )
24
 
25
 
 
49
  ):
50
  """Initialize the chest X-ray generator tool."""
51
  super().__init__()
52
+
53
  self.device = torch.device(device) if device else "cuda"
54
  self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
55
  self.model = self.model.to(torch.float32).to(self.device)
56
+
57
  self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
58
  self.temp_dir.mkdir(exist_ok=True)
59
 
 
86
  num_inference_steps=num_inference_steps,
87
  height=height,
88
  width=width,
89
+ guidance_scale=guidance_scale,
90
  )
91
 
92
  # Save generated image
 
96
  output = {
97
  "image_path": str(image_path),
98
  }
99
+
100
  metadata = {
101
  "prompt": prompt,
102
  "num_inference_steps": num_inference_steps,
 
115
  "prompt": prompt,
116
  "analysis_status": "failed",
117
  "error_details": str(e),
118
+ },
119
  )
120
 
121
  async def _arun(
 
128
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
129
  ) -> Tuple[Dict[str, str], Dict]:
130
  """Async version of _run."""
131
+ return self._run(prompt, num_inference_steps, guidance_scale, height, width)
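
Finally, a hedged sketch of how the generation call whose trailing comma was fixed above is typically driven through a diffusers StableDiffusionPipeline. The model id is a placeholder and the weights would need to be available locally or downloadable, so treat this as illustrative rather than the tool's exact setup.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some-org/chest-xray-diffusion")  # hypothetical model id
pipe = pipe.to(torch.float32).to("cuda")

result = pipe(
    "big left-sided pleural effusion",  # example prompt from the input schema's description
    num_inference_steps=75,
    height=512,
    width=512,
    guidance_scale=4.0,
)
result.images[0].save("generated_cxr.png")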