Adibvafa committed
Commit: 7393de8
1 Parent(s): 9a2c640
Improve style
Files changed:
- medrax/agent/agent.py +1 -4
- medrax/llava/conversation.py +1 -3
- medrax/llava/eval/eval_multimodal_chat_gpt_score.py +3 -6
- medrax/llava/eval/llm.py +8 -23
- medrax/llava/eval/model_vqa.py +2 -8
- medrax/llava/eval/summarize_gpt_review.py +3 -7
- medrax/llava/mm_utils.py +4 -14
- medrax/llava/model/builder.py +4 -12
- medrax/llava/model/language_model/llava_mistral.py +1 -3
- medrax/llava/model/llava_arch.py +13 -39
- medrax/llava/model/multimodal_encoder/builder.py +2 -8
- medrax/llava/model/multimodal_projector/builder.py +1 -3
- medrax/llava/serve/cli.py +1 -3
- medrax/llava/serve/controller.py +3 -6
- medrax/llava/serve/gradio_web_server.py +4 -12
- medrax/llava/serve/model_worker.py +6 -14
- medrax/llava/serve/test_message.py +2 -6
- medrax/llava/utils.py +1 -3
- medrax/models/model_factory.py +6 -15
- medrax/rag/rag.py +3 -9
- medrax/tools/browsing/__init__.py +3 -3
- medrax/tools/browsing/duckduckgo.py +12 -33
- medrax/tools/browsing/web_browser.py +3 -9
- medrax/tools/classification/__init__.py +1 -6
- medrax/tools/classification/arcplus.py +5 -17
- medrax/tools/classification/torchxrayvision.py +1 -3
- medrax/tools/dicom.py +1 -3
- medrax/tools/grounding.py +4 -13
- medrax/tools/rag.py +1 -1
- medrax/tools/report_generation.py +4 -14
- medrax/tools/segmentation/__init__.py +1 -7
- medrax/tools/segmentation/medsam2.py +70 -78
- medrax/tools/segmentation/segmentation.py +10 -30
- medrax/tools/utils.py +5 -15
- medrax/tools/vqa/__init__.py +4 -4
- medrax/tools/vqa/llava_med.py +4 -12
- medrax/tools/vqa/xray_vqa.py +6 -12
- medrax/tools/xray_generation.py +12 -23
medrax/agent/agent.py
CHANGED
@@ -62,9 +62,7 @@ class Agent:
         workflow = StateGraph(AgentState)
         workflow.add_node("agent", self.process_request)
         workflow.add_node("tools", self.tool_node)
-        workflow.add_conditional_edges(
-            "agent", self.has_tool_calls, {True: "tools", False: END}
-        )
+        workflow.add_conditional_edges("agent", self.has_tool_calls, {True: "tools", False: END})
         workflow.add_edge("tools", "agent")
         workflow.set_entry_point("agent")

@@ -99,4 +97,3 @@ class Agent:
         """
         response = state["messages"][-1]
         return len(response.tool_calls) > 0
-
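For orientation, a minimal self-contained sketch of the conditional-edge wiring this hunk reformats, assuming LangGraph's StateGraph API; AgentState and the node functions below are simplified stand-ins, not the MedRAX Agent implementation.

# Minimal sketch (not the MedRAX code): conditional routing between an agent
# node and a tools node, assuming LangGraph's StateGraph API.
from typing import TypedDict

from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    messages: list


def process_request(state: AgentState) -> AgentState:
    # Placeholder "agent" node: would normally call the LLM.
    return state


def tool_node(state: AgentState) -> AgentState:
    # Placeholder "tools" node: would normally execute requested tool calls.
    return state


def has_tool_calls(state: AgentState) -> bool:
    # Route to the tools node only when the last message requested tools.
    last = state["messages"][-1]
    return bool(getattr(last, "tool_calls", []))


workflow = StateGraph(AgentState)
workflow.add_node("agent", process_request)
workflow.add_node("tools", tool_node)
workflow.add_conditional_edges("agent", has_tool_calls, {True: "tools", False: END})
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")
graph = workflow.compile()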
medrax/llava/conversation.py
CHANGED
@@ -230,9 +230,7 @@ class Conversation:
                     buffered = BytesIO()
                     image.save(buffered, format="JPEG")
                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                    img_str = (
-                        f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
-                    )
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                     msg = img_str + msg.replace("<image>", "").strip()
                     ret.append([msg, None])
                 else:
medrax/llava/eval/eval_multimodal_chat_gpt_score.py
CHANGED
@@ -14,6 +14,7 @@ INSTRUCT_PROMPT = """We would like to request your feedback on the performance o
 Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
 ROLE = "Assistant"
 
+
 # Generate instruction for GPT-4 to score the two answers.
 def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
     return (

@@ -127,17 +128,13 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
-    parser.add_argument(
-        "--answers-file", default="", metavar="FILE", help="path to model answer file"
-    )
+    parser.add_argument("--answers-file", default="", metavar="FILE", help="path to model answer file")
     parser.add_argument(
         "--question-file",
         default="data/questions/llava_med_eval_qa50_qa.jsonl",
         metavar="FILE",
         help="path to multichat questions file",
     )
-    parser.add_argument(
-        "--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file"
-    )
+    parser.add_argument("--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file")
     args = parser.parse_args()
     main(args)
medrax/llava/eval/llm.py
CHANGED
@@ -21,9 +21,7 @@ class LLM(abc.ABC):
         raise NotImplementedError("Subclasses should implement this!")
 
     @abstractmethod
-    def split_input(
-        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
-    ):
+    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
         raise NotImplementedError("Subclasses should implement this!")
 
 
@@ -49,9 +47,7 @@ class GPT(LLM):
     def __init__(self, model_id):
         self.temperature = 0.0
         self.top_k = 1
-        self.encoding = tiktoken.encoding_for_model(
-            "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
-        )
+        self.encoding = tiktoken.encoding_for_model("-".join(model_id.split("-", 2)[:2]).replace("5", ".5"))
         self.openai_api = "default"
         self.model_id = model_id
         self.max_length = self.deployment_max_length_dict[model_id]

@@ -61,9 +57,7 @@ class GPT(LLM):
             azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
         )
 
-    def gen_messages(
-        self, fixed_instruction, few_shot_examples, input, input_header, output_header
-    ):
+    def gen_messages(self, fixed_instruction, few_shot_examples, input, input_header, output_header):
         messages = [
             {
                 "role": "system",

@@ -120,18 +114,13 @@ class GPT(LLM):
     ):
         return asyncio.run(self.dispatch_openai_requests(messages_list))
 
-    def split_input(
-        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
-    ):
+    def split_input(self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header):
         # Tokenize fixed_prompt
         fixed_token_ids = self.encoding.encode(
-            fixed_instruction
-            + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
+            fixed_instruction + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
         )
         # Calculate remaining token length
-        remaining_token_len = math.ceil(
-            (self.prompt_percent * self.max_length) - len(fixed_token_ids)
-        )
+        remaining_token_len = math.ceil((self.prompt_percent * self.max_length) - len(fixed_token_ids))
 
         # Tokenize splittable_input
         split_token_ids = self.encoding.encode(splittable_input)

@@ -141,14 +130,10 @@ class GPT(LLM):
             split_token_ids[i : i + remaining_token_len + 10]
             for i in range(0, len(split_token_ids), remaining_token_len)
         ]
-        split_input_list = [
-            self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list
-        ]
+        split_input_list = [self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list]
 
         # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
         return [
-            self.gen_messages(
-                fixed_instruction, few_shot_examples, split_input, input_header, output_header
-            )
+            self.gen_messages(fixed_instruction, few_shot_examples, split_input, input_header, output_header)
             for split_input in split_input_list
         ]
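For context on the split_input logic reflowed above, a minimal sketch of the same token-budget chunking idea, assuming the tiktoken package; the budget values and function name here are illustrative, not the ones used in llm.py.

# Illustrative sketch: keep a fixed prompt, then split the remaining input
# into chunks that fit a token budget, mirroring GPT.split_input.
import math

import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4")


def split_by_token_budget(
    fixed_prompt: str, splittable_input: str, max_length: int = 8192, prompt_percent: float = 0.8
) -> list[str]:
    fixed_token_ids = encoding.encode(fixed_prompt)
    # Tokens left for the variable part of each chunk.
    remaining = math.ceil(prompt_percent * max_length - len(fixed_token_ids))
    split_ids = encoding.encode(splittable_input)
    # Step by `remaining`, with a small overlap margin like the original (+10).
    chunks = [split_ids[i : i + remaining + 10] for i in range(0, len(split_ids), remaining)]
    return [encoding.decode(chunk) for chunk in chunks]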
medrax/llava/eval/model_vqa.py
CHANGED
@@ -45,9 +45,7 @@ def eval_model(args):
     disable_torch_init()
     model_path = os.path.expanduser(args.model_path)
     model_name = get_model_name_from_path(model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(
-        model_path, args.model_base, model_name
-    )
+    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
 
     questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
     questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

@@ -69,11 +67,7 @@ def eval_model(args):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
 
-        input_ids = (
-            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0)
-            .cuda()
-        )
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
 
         image = Image.open(os.path.join(args.image_folder, image_file))
         image_tensor = process_images([image], image_processor, model.config)[0]
medrax/llava/eval/summarize_gpt_review.py
CHANGED
@@ -14,8 +14,7 @@ def get_domain(x):
 def main(args):
     scores_data = util.load_file_jsonl(args.scores_file)
     predictions = [
-        (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" "))
-        for x in scores_data
+        (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" ")) for x in scores_data
     ]
 
     score_type_dict = defaultdict(lambda: defaultdict(list))

@@ -33,8 +32,7 @@ def main(args):
         result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
         result[q_type]["pred_score"] = util.get_avg(score_dict[2])
         result[q_type]["pred_relative_score"] = (
-            util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])
-            * 100
+            util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])]) * 100
         )
         result[q_type]["data_size"] = len(score_dict[1])

@@ -55,8 +53,6 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
-    parser.add_argument(
-        "--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file"
-    )
+    parser.add_argument("--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file")
     args = parser.parse_args()
     main(args)
medrax/llava/mm_utils.py
CHANGED
@@ -35,9 +35,7 @@ def process_images(images, image_processor, model_cfg):
     for image in images:
         if image_aspect_ratio == "pad":
             if image.mode == "L":
-                background_color = int(
-                    255 * sum(image_processor.image_mean) / len(image_processor.image_mean)
-                )
+                background_color = int(255 * sum(image_processor.image_mean) / len(image_processor.image_mean))
             else:
                 background_color = tuple(int(x * 255) for x in image_processor.image_mean)
             image = expand2square(image, background_color)

@@ -48,9 +46,7 @@ def process_images(images, image_processor, model_cfg):
     return new_images
 
 
-def tokenizer_image_token(
-    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
-):
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
 
     def insert_separator(X, sep):

@@ -58,11 +54,7 @@ def tokenizer_image_token(
 
     input_ids = []
     offset = 0
-    if (
-        len(prompt_chunks) > 0
-        and len(prompt_chunks[0]) > 0
-        and prompt_chunks[0][0] == tokenizer.bos_token_id
-    ):
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
         offset = 1
         input_ids.append(prompt_chunks[0][0])

@@ -100,9 +92,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
         self.tokenizer = tokenizer
         self.start_len = input_ids.shape[1]
 
-    def call_for_batch(
-        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
-    ) -> bool:
+    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
         for keyword_id in self.keyword_ids:
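The tokenizer_image_token reflow above splits a prompt on the <image> placeholder and splices a sentinel index between the tokenized chunks; a small dependency-free sketch of that interleaving follows (the toy word-level tokenizer and the -200 sentinel are illustrative stand-ins, not the library's tokenizer).

# Toy illustration of the <image>-token splicing pattern from mm_utils.py.
IMAGE_TOKEN_INDEX = -200


def toy_tokenize(text: str) -> list[int]:
    # Stand-in for a real tokenizer: one id per whitespace-separated word.
    return [hash(w) % 1000 for w in text.split()]


def splice_image_tokens(prompt: str, image_token_index: int = IMAGE_TOKEN_INDEX) -> list[int]:
    chunks = [toy_tokenize(chunk) for chunk in prompt.split("<image>")]
    input_ids: list[int] = []
    for i, chunk in enumerate(chunks):
        input_ids.extend(chunk)
        if i < len(chunks) - 1:
            # Insert the sentinel where each <image> placeholder appeared.
            input_ids.append(image_token_index)
    return input_ids


print(splice_image_tokens("Describe <image> in detail"))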
medrax/llava/model/builder.py
CHANGED
@@ -59,9 +59,7 @@ def load_pretrained_model(
         # PEFT model
         from peft import PeftModel
 
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_base, use_fast=False, cache_dir=cache_dir
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, cache_dir=cache_dir)
         model = AutoModelForCausalLM.from_pretrained(
             model_base,
             low_cpu_mem_usage=True,

@@ -78,9 +76,7 @@ def load_pretrained_model(
     else:
         use_fast = False
         if "mpt" in model_name.lower():
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_path, use_fast=True, cache_dir=cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, cache_dir=cache_dir)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 low_cpu_mem_usage=True,

@@ -90,9 +86,7 @@ def load_pretrained_model(
                 **kwargs,
             )
         else:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_path, use_fast=False, cache_dir=cache_dir
-            )
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, cache_dir=cache_dir)
             model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 low_cpu_mem_usage=True,

@@ -109,9 +103,7 @@ def load_pretrained_model(
     if mm_use_im_patch_token:
         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
     if mm_use_im_start_end:
-        tokenizer.add_tokens(
-            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-        )
+        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
     model.resize_token_embeddings(len(tokenizer))
 
     vision_tower = model.get_vision_tower()
medrax/llava/model/language_model/llava_mistral.py
CHANGED
@@ -125,9 +125,7 @@ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
             **kwargs,
         )
 
-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
-    ):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
         image_sizes = kwargs.pop("image_sizes", None)
         inputs = super().prepare_inputs_for_generation(
medrax/llava/model/llava_arch.py
CHANGED
@@ -104,9 +104,7 @@ class LlavaMetaModel:
             checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
             ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
             if len(ckpts) > 0:
-                vision_module_weights = torch.load(
-                    f"{ckpts[-1]}/mm_projector.bin", map_location="cpu"
-                )
+                vision_module_weights = torch.load(f"{ckpts[-1]}/mm_projector.bin", map_location="cpu")
                 model_dict = get_w(vision_module_weights, "vision_tower")
                 print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
                 # print keys in model_dict

@@ -170,9 +168,7 @@ class LlavaMetaForCausalLM(ABC):
             image_features = self.encode_images(images).to(self.device)
 
         # TODO: image start / end is not implemented here to support pretraining.
-        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
-            self.config, "mm_use_im_start_end", False
-        ):
+        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError
 
        # Let's just add dummy tensors if they do not exist,

@@ -188,21 +184,15 @@ class LlavaMetaForCausalLM(ABC):
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
-            position_ids = torch.arange(
-                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
-            )
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
 
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)
 
        input_ids = [
-            cur_input_ids[cur_attention_mask]
-            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
-        ]
-        labels = [
-            cur_labels[cur_attention_mask]
-            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
+            cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
 
        new_input_embeds = []
        new_labels = []

@@ -219,20 +209,14 @@ class LlavaMetaForCausalLM(ABC):
                continue
 
            image_token_indices = (
-                [-1]
-                + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
-                + [cur_input_ids.shape[0]]
+                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
-                cur_input_ids_noim.append(
-                    cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
-                )
-                cur_labels_noim.append(
-                    cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
-                )
+                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
 
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))

@@ -279,12 +263,8 @@ class LlavaMetaForCausalLM(ABC):
                dtype=new_labels[0].dtype,
                device=new_labels[0].device,
            )
-            attention_mask = torch.zeros(
-                (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
-            )
-            position_ids = torch.zeros(
-                (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
-            )
+            attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+            position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
 
            for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
                cur_len = cur_new_embed.shape[0]

@@ -351,9 +331,7 @@ class LlavaMetaForCausalLM(ABC):
            self.resize_token_embeddings(len(tokenizer))
 
        if model_args.mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens(
-                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-            )
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
 
            if num_new_tokens > 0:

@@ -361,9 +339,7 @@ class LlavaMetaForCausalLM(ABC):
                output_embeddings = self.get_output_embeddings().weight.data
 
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True
-                )
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
 
                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

@@ -375,9 +351,7 @@ class LlavaMetaForCausalLM(ABC):
                    p.requires_grad = False
 
            if model_args.pretrain_mm_mlp_adapter:
-                mm_projector_weights = torch.load(
-                    model_args.pretrain_mm_mlp_adapter, map_location="cpu"
-                )
+                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
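The image_token_indices expression reflowed above relies on a split-at-sentinel pattern; a small standalone PyTorch sketch of that pattern follows (the tensor values are made up for illustration).

# Standalone sketch of splitting a token sequence at image-token sentinels,
# mirroring the pattern reformatted in llava_arch.py.
import torch

IMAGE_TOKEN_INDEX = -200
cur_input_ids = torch.tensor([1, 15, 22, IMAGE_TOKEN_INDEX, 87, 99, IMAGE_TOKEN_INDEX, 5])

# Sentinel positions, padded with -1 and the sequence length so every text
# span is bounded on both sides.
image_token_indices = (
    [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
)

text_spans = [
    cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
    for i in range(len(image_token_indices) - 1)
]
print([span.tolist() for span in text_spans])  # [[1, 15, 22], [87, 99], [5]]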
medrax/llava/model/multimodal_encoder/builder.py
CHANGED
@@ -3,13 +3,7 @@ from .clip_encoder import CLIPVisionTower
 
 
 def build_vision_tower(vision_tower_cfg, **kwargs):
-    vision_tower = getattr(
-        vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
-    )
+    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
     is_absolute_path_exists = os.path.exists(vision_tower)
-    if (
-        is_absolute_path_exists
-        or vision_tower.startswith("openai")
-        or vision_tower.startswith("laion")
-    ):
+    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
         return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
medrax/llava/model/multimodal_projector/builder.py
CHANGED
@@ -19,9 +19,7 @@ class SimpleResBlock(nn.Module):
         super().__init__()
         self.pre_norm = nn.LayerNorm(channels)
 
-        self.proj = nn.Sequential(
-            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
-        )
+        self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
 
     def forward(self, x):
         x = self.pre_norm(x)
medrax/llava/serve/cli.py
CHANGED
@@ -94,9 +94,7 @@ def main(args):
         if image is not None:
             # first message
             if model.config.mm_use_im_start_end:
-                inp = (
-                    DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
-                )
+                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
             else:
                 inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
             conv.append_message(conv.roles[0], inp)
medrax/llava/serve/controller.py
CHANGED
@@ -2,6 +2,7 @@
 A controller manages distributed workers.
 It sends worker addresses to clients.
 """
+
 import argparse
 import dataclasses
 from enum import Enum, auto

@@ -199,9 +200,7 @@ class Controller:
             yield json.dumps(ret).encode() + b"\0"
 
         try:
-            response = requests.post(
-                worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5
-            )
+            response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5)
             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                 if chunk:
                     yield chunk + b"\0"

@@ -240,9 +239,7 @@ app = FastAPI()
 @app.post("/register_worker")
 async def register_worker(request: Request):
     data = await request.json()
-    controller.register_worker(
-        data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
-    )
+    controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None))
 
 
 @app.post("/refresh_all_workers")
medrax/llava/serve/gradio_web_server.py
CHANGED
@@ -216,9 +216,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
         all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
         for image, hash in zip(all_images, all_image_hash):
             t = datetime.datetime.now()
-            filename = os.path.join(
-                LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
-            )
+            filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
             if not os.path.isfile(filename):
                 os.makedirs(os.path.dirname(filename), exist_ok=True)
                 image.save(filename)

@@ -230,9 +228,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
         "temperature": float(temperature),
         "top_p": float(top_p),
         "max_new_tokens": min(int(max_new_tokens), 1536),
-        "stop": state.sep
-        if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
-        else state.sep2,
+        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
         "images": f"List of {len(state.get_images())} images: {all_image_hash}",
     }
     logger.info(f"==== request ====\n{pload}")

@@ -330,9 +326,7 @@ block_css = """
 
 
 def build_demo(embed_mode):
-    textbox = gr.Textbox(
-        show_label=False, placeholder="Enter text and press ENTER", container=False
-    )
+    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 

@@ -468,9 +462,7 @@ def build_demo(embed_mode):
             [state, chatbot] + btn_list,
         )
 
-        clear_btn.click(
-            clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False
-        )
+        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
 
         textbox.submit(
             add_text,
medrax/llava/serve/model_worker.py
CHANGED
@@ -1,6 +1,7 @@
 """
 A model worker executes the model.
 """
+
 import argparse
 import asyncio
 import json

@@ -155,9 +156,7 @@ class ModelWorker:
         if images is not None and len(images) > 0 and self.is_multimodal:
             if len(images) > 0:
                 if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                    raise ValueError(
-                        "Number of images does not match number of <image> tokens in prompt"
-                    )
+                    raise ValueError("Number of images does not match number of <image> tokens in prompt")
 
                 images = [load_image_from_base64(image) for image in images]
                 images = process_images(images, image_processor, model.config)

@@ -172,9 +171,7 @@ class ModelWorker:
                     replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                 prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
-                num_image_tokens = (
-                    prompt.count(replace_token) * model.get_vision_tower().num_patches
-                )
+                num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
         else:
             images = None
         image_args = {"images": images}

@@ -196,19 +193,14 @@ class ModelWorker:
         )
         keywords = [stop_str]
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
-        )
+        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
 
-        max_new_tokens = min(
-            max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
-        )
+        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
 
         if max_new_tokens < 1:
             yield json.dumps(
                 {
-                    "text": ori_prompt
-                    + "Exceeds max token length. Please start a new conversation, thanks.",
+                    "text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
                     "error_code": 0,
                 }
             ).encode() + b"\0"
medrax/llava/serve/test_message.py
CHANGED
@@ -17,9 +17,7 @@ def main():
     models.sort()
     print(f"Models: {models}")
 
-    ret = requests.post(
-        controller_addr + "/get_worker_address", json={"model": args.model_name}
-    )
+    ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name})
     worker_addr = ret.json()["address"]
     print(f"worker_addr: {worker_addr}")
 

@@ -38,9 +36,7 @@ def main():
         "temperature": 0.7,
         "stop": conv.sep2,
     }
-    response = requests.post(
-        worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True
-    )
+    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True)
 
     print(prompt, end="")
     for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
medrax/llava/utils.py
CHANGED
@@ -45,9 +45,7 @@ def build_logger(logger_name, logger_filename):
     if handler is None:
         os.makedirs(LOGDIR, exist_ok=True)
         filename = os.path.join(LOGDIR, logger_filename)
-        handler = logging.handlers.TimedRotatingFileHandler(
-            filename, when="D", utc=True, encoding="UTF-8"
-        )
+        handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True, encoding="UTF-8")
         handler.setFormatter(formatter)
 
         for name, item in logging.root.manager.loggerDict.items():
medrax/models/model_factory.py
CHANGED
@@ -29,7 +29,7 @@ class ModelFactory:
             "base_url_key": "OPENAI_BASE_URL",
         },
         "gemini": {
-            "class": ChatGoogleGenerativeAI,
+            "class": ChatGoogleGenerativeAI,
             "env_key": "GOOGLE_API_KEY",
             "base_url_key": "GOOGLE_BASE_URL",
         },

@@ -42,14 +42,12 @@ class ModelFactory:
         "grok": {
             "class": ChatXAI,
             "env_key": "XAI_API_KEY",
-        }
+        },
         # Add more providers with default configurations here
     }
 
     @classmethod
-    def register_provider(
-        cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs
-    ) -> None:
+    def register_provider(cls, prefix: str, model_class: Type[BaseLanguageModel], env_key: str, **kwargs) -> None:
         """Register a new model provider.
 
         Args:

@@ -61,9 +59,7 @@ class ModelFactory:
         cls._model_providers[prefix] = {"class": model_class, "env_key": env_key, **kwargs}
 
     @classmethod
-    def create_model(
-        cls, model_name: str, temperature: float = 0.7, **kwargs
-    ) -> BaseLanguageModel:
+    def create_model(cls, model_name: str, temperature: float = 0.7, **kwargs) -> BaseLanguageModel:
         """Create and return an instance of the appropriate language model.
 
         Args:

@@ -79,9 +75,7 @@ class ModelFactory:
             ValueError: If the required API key is missing
         """
         # Find the matching provider based on model name prefix
-        provider_prefix = next(
-            (prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None
-        )
+        provider_prefix = next((prefix for prefix in cls._model_providers if model_name.startswith(prefix)), None)
 
         if not provider_prefix:
             raise ValueError(

@@ -138,7 +132,4 @@ class ModelFactory:
             Dict[str, Dict[str, Any]]: Dictionary of registered providers and their configurations
         """
         # Return a copy to prevent accidental modification
-        return {
-            k: {kk: vv for kk, vv in v.items() if kk != "class"}
-            for k, v in cls._model_providers.items()
-        }
+        return {k: {kk: vv for kk, vv in v.items() if kk != "class"} for k, v in cls._model_providers.items()}
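A quick standalone illustration of the prefix-matching lookup that create_model collapses to one line above; the provider table here is a toy mapping, not the real registry.

# Toy version of the provider-prefix lookup used in ModelFactory.create_model.
_model_providers = {
    "gpt": {"env_key": "OPENAI_API_KEY"},
    "gemini": {"env_key": "GOOGLE_API_KEY"},
    "grok": {"env_key": "XAI_API_KEY"},
}


def find_provider(model_name: str) -> str | None:
    # First registered prefix that the model name starts with, else None.
    return next((prefix for prefix in _model_providers if model_name.startswith(prefix)), None)


print(find_provider("gpt-4o"))    # gpt
print(find_provider("claude-3"))  # None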
medrax/rag/rag.py
CHANGED
@@ -107,9 +107,7 @@ class CohereRAG:
         # Initialize Pinecone
         self.pinecone_api_key = os.getenv("PINECONE_API_KEY")
         if not self.pinecone_api_key:
-            raise ValueError(
-                "PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io"
-            )
+            raise ValueError("PINECONE_API_KEY environment variable not set. Please get a key from app.pinecone.io")
         self.pinecone = Pinecone(api_key=self.pinecone_api_key)
         self.index_name = self.config.pinecone_index_name
 

@@ -161,9 +159,7 @@ class CohereRAG:
         )
 
         print(f"Connecting to existing Pinecone index: {self.index_name}")
-        vectorstore = PineconeVectorStore.from_existing_index(
-            index_name=self.index_name, embedding=self.embeddings
-        )
+        vectorstore = PineconeVectorStore.from_existing_index(index_name=self.index_name, embedding=self.embeddings)
 
         # Check if the index is empty and needs to be populated
         try:

@@ -329,9 +325,7 @@ class CohereRAG:
                 )
                 documents.append(doc)
 
-            print(
-                f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}"
-            )
+            print(f"Loaded {len(documents)} document chunks from HuggingFace dataset: {dataset_name}")
             return documents
 
         except Exception as e:
medrax/tools/browsing/__init__.py
CHANGED
@@ -6,8 +6,8 @@ from .web_browser import WebBrowserTool, WebBrowserSchema, SearchQuerySchema, Vi
 __all__ = [
     "DuckDuckGoSearchTool",
     "WebSearchInput",
-    "WebBrowserTool",
+    "WebBrowserTool",
     "WebBrowserSchema",
     "SearchQuerySchema",
-    "VisitUrlSchema"
-]
+    "VisitUrlSchema",
+]
medrax/tools/browsing/duckduckgo.py
CHANGED
|
@@ -95,18 +95,12 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 95 |
super().__init__(**kwargs)
|
| 96 |
|
| 97 |
if DDGS is None:
|
| 98 |
-
logger.error(
|
| 99 |
-
|
| 100 |
-
)
|
| 101 |
-
raise ImportError(
|
| 102 |
-
"duckduckgo-search package is required for web search functionality"
|
| 103 |
-
)
|
| 104 |
|
| 105 |
logger.info("DuckDuckGo search tool initialized successfully")
|
| 106 |
|
| 107 |
-
def _perform_search_sync(
|
| 108 |
-
self, query: str, max_results: int = 5, region: str = "us-en"
|
| 109 |
-
) -> Dict[str, Any]:
|
| 110 |
"""
|
| 111 |
Perform the actual web search using DuckDuckGo synchronously.
|
| 112 |
|
|
@@ -118,9 +112,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 118 |
Returns:
|
| 119 |
Dict[str, Any]: Structured search results.
|
| 120 |
"""
|
| 121 |
-
logger.info(
|
| 122 |
-
f"Performing web search: '{query}' (max_results={max_results}, region={region})"
|
| 123 |
-
)
|
| 124 |
|
| 125 |
try:
|
| 126 |
# Initialize DDGS with error handling
|
|
@@ -158,9 +150,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 158 |
summary = f"No results found for '{query}'"
|
| 159 |
|
| 160 |
# Log successful completion
|
| 161 |
-
logger.info(
|
| 162 |
-
f"Web search completed successfully: {len(formatted_results)} results"
|
| 163 |
-
)
|
| 164 |
|
| 165 |
return {
|
| 166 |
"query": query,
|
|
@@ -217,7 +207,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 217 |
|
| 218 |
try:
|
| 219 |
result = self._perform_search_sync(query, max_results, region)
|
| 220 |
-
|
| 221 |
# Check if search was successful
|
| 222 |
if "error" in result:
|
| 223 |
metadata["analysis_status"] = "failed"
|
|
@@ -239,7 +229,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 239 |
}
|
| 240 |
metadata["analysis_status"] = "failed"
|
| 241 |
metadata["error_details"] = str(e)
|
| 242 |
-
|
| 243 |
return error_result, metadata
|
| 244 |
|
| 245 |
async def _arun(
|
|
@@ -296,9 +286,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 296 |
|
| 297 |
# Use asyncio to run sync search in executor
|
| 298 |
loop = asyncio.get_event_loop()
|
| 299 |
-
result, metadata = await loop.run_in_executor(
|
| 300 |
-
None, self._run, query, max_results, region
|
| 301 |
-
)
|
| 302 |
|
| 303 |
if writer:
|
| 304 |
# Parse result to get count for progress update
|
|
@@ -333,7 +321,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 333 |
"search_engine": "DuckDuckGo",
|
| 334 |
"timestamp": datetime.now().isoformat(),
|
| 335 |
}
|
| 336 |
-
|
| 337 |
metadata = {
|
| 338 |
"query": query,
|
| 339 |
"max_results": max_results,
|
|
@@ -344,12 +332,10 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 344 |
"analysis_status": "failed",
|
| 345 |
"error_details": str(e),
|
| 346 |
}
|
| 347 |
-
|
| 348 |
return error_result, metadata
|
| 349 |
|
| 350 |
-
def get_search_summary(
|
| 351 |
-
self, query: str, max_results: int = 3
|
| 352 |
-
) -> dict[str, str | list[str]]:
|
| 353 |
"""
|
| 354 |
Get a quick summary of search results for a given query.
|
| 355 |
|
|
@@ -375,14 +361,7 @@ class DuckDuckGoSearchTool(BaseTool):
|
|
| 375 |
results = result.get("results", [])
|
| 376 |
titles = [r["title"] for r in results]
|
| 377 |
urls = [r["url"] for r in results]
|
| 378 |
-
snippets = [
|
| 379 |
-
(
|
| 380 |
-
r["snippet"][:100] + "..."
|
| 381 |
-
if len(r["snippet"]) > 100
|
| 382 |
-
else r["snippet"]
|
| 383 |
-
)
|
| 384 |
-
for r in results
|
| 385 |
-
]
|
| 386 |
|
| 387 |
return {
|
| 388 |
"query": query,
|
|
|
|
| 95 |
super().__init__(**kwargs)
|
| 96 |
|
| 97 |
if DDGS is None:
|
| 98 |
+
logger.error("duckduckgo-search package not installed. Install with: pip install duckduckgo-search")
|
| 99 |
+
raise ImportError("duckduckgo-search package is required for web search functionality")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
logger.info("DuckDuckGo search tool initialized successfully")
|
| 102 |
|
| 103 |
+
def _perform_search_sync(self, query: str, max_results: int = 5, region: str = "us-en") -> Dict[str, Any]:
|
|
|
|
|
|
|
| 104 |
"""
|
| 105 |
Perform the actual web search using DuckDuckGo synchronously.
|
| 106 |
|
|
|
|
| 112 |
Returns:
|
| 113 |
Dict[str, Any]: Structured search results.
|
| 114 |
"""
|
| 115 |
+
logger.info(f"Performing web search: '{query}' (max_results={max_results}, region={region})")
|
|
|
|
|
|
|
| 116 |
|
| 117 |
try:
|
| 118 |
# Initialize DDGS with error handling
|
|
|
|
| 150 |
summary = f"No results found for '{query}'"
|
| 151 |
|
| 152 |
# Log successful completion
|
| 153 |
+
logger.info(f"Web search completed successfully: {len(formatted_results)} results")
|
|
|
|
|
|
|
| 154 |
|
| 155 |
return {
|
| 156 |
"query": query,
|
|
|
|
| 207 |
|
| 208 |
try:
|
| 209 |
result = self._perform_search_sync(query, max_results, region)
|
| 210 |
+
|
| 211 |
# Check if search was successful
|
| 212 |
if "error" in result:
|
| 213 |
metadata["analysis_status"] = "failed"
|
|
|
|
| 229 |
}
|
| 230 |
metadata["analysis_status"] = "failed"
|
| 231 |
metadata["error_details"] = str(e)
|
| 232 |
+
|
| 233 |
return error_result, metadata
|
| 234 |
|
| 235 |
async def _arun(
|
|
|
|
| 286 |
|
| 287 |
# Use asyncio to run sync search in executor
|
| 288 |
loop = asyncio.get_event_loop()
|
| 289 |
+
result, metadata = await loop.run_in_executor(None, self._run, query, max_results, region)
|
|
|
|
|
|
|
| 290 |
|
| 291 |
if writer:
|
| 292 |
# Parse result to get count for progress update
|
|
|
|
| 321 |
"search_engine": "DuckDuckGo",
|
| 322 |
"timestamp": datetime.now().isoformat(),
|
| 323 |
}
|
| 324 |
+
|
| 325 |
metadata = {
|
| 326 |
"query": query,
|
| 327 |
"max_results": max_results,
|
|
|
|
| 332 |
"analysis_status": "failed",
|
| 333 |
"error_details": str(e),
|
| 334 |
}
|
| 335 |
+
|
| 336 |
return error_result, metadata
|
| 337 |
|
| 338 |
+
def get_search_summary(self, query: str, max_results: int = 3) -> dict[str, str | list[str]]:
|
|
|
|
|
|
|
| 339 |
"""
|
| 340 |
Get a quick summary of search results for a given query.
|
| 341 |
|
|
|
|
| 361 |
results = result.get("results", [])
|
| 362 |
titles = [r["title"] for r in results]
|
| 363 |
urls = [r["url"] for r in results]
|
| 364 |
+
snippets = [(r["snippet"][:100] + "..." if len(r["snippet"]) > 100 else r["snippet"]) for r in results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
return {
|
| 367 |
"query": query,
|
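The snippet handling above only changes layout, not behaviour: summaries still cap each snippet at 100 characters. A minimal standalone sketch of that rule (the `truncate_snippet` helper and the sample data are ours, not part of the repository):

```python
def truncate_snippet(snippet: str, limit: int = 100) -> str:
    # Same rule as the one-line comprehension in get_search_summary.
    return snippet[:limit] + "..." if len(snippet) > limit else snippet


results = [{"snippet": "x" * 150}, {"snippet": "short"}]
snippets = [truncate_snippet(r["snippet"]) for r in results]
assert snippets[0].endswith("...") and snippets[1] == "short"
```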
medrax/tools/browsing/web_browser.py
CHANGED
@@ -78,9 +78,7 @@ class WebBrowserTool(BaseTool):
    max_results: int = 5
    args_schema: Type[BaseModel] = WebBrowserSchema

-    def __init__(
-        self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs
-    ):
+    def __init__(self, search_api_key: Optional[str] = None, search_engine_id: Optional[str] = None, **kwargs):
        """Initialize the web browser tool with optional search API credentials.

        Args:

@@ -145,9 +143,7 @@ class WebBrowserTool(BaseTool):
        except Exception as e:
            return {"error": f"Search failed: {str(e)}"}

-    def visit_url(
-        self, url: str, max_content_length: int = 5000, max_links: int = 5
-    ) -> Dict[str, Any]:
+    def visit_url(self, url: str, max_content_length: int = 5000, max_links: int = 5) -> Dict[str, Any]:
        """Visit a URL and extract its content with comprehensive parsing.

        Args:

@@ -218,9 +214,7 @@ class WebBrowserTool(BaseTool):
        return {
            "title": title,
            "content": (
-                text_content[:max_content_length]
-                if len(text_content) > max_content_length
-                else text_content
+                text_content[:max_content_length] if len(text_content) > max_content_length else text_content
            ),
            "url": url,
            "links": links[:max_links],  # Limit to max_links
medrax/tools/classification/__init__.py
CHANGED
@@ -3,9 +3,4 @@
 from .torchxrayvision import TorchXRayVisionClassifierTool, TorchXRayVisionInput
 from .arcplus import ArcPlusClassifierTool, ArcPlusInput

-__all__ = [
-    "TorchXRayVisionClassifierTool",
-    "TorchXRayVisionInput",
-    "ArcPlusClassifierTool",
-    "ArcPlusInput"
-]
+__all__ = ["TorchXRayVisionClassifierTool", "TorchXRayVisionInput", "ArcPlusClassifierTool", "ArcPlusInput"]
medrax/tools/classification/arcplus.py
CHANGED
@@ -38,9 +38,7 @@ class OmniSwinTransformer(SwinTransformer):

        self.omni_heads = []
        for num_classes in num_classes_list:
-            self.omni_heads.append(
-                nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-            )
+            self.omni_heads.append(nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        self.omni_heads = nn.ModuleList(self.omni_heads)

    def forward(self, x, head_n=None):

@@ -62,9 +60,7 @@
class ArcPlusInput(BaseModel):
    """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""

-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")


class ArcPlusClassifierTool(BaseTool):

@@ -249,11 +245,7 @@ class ArcPlusClassifierTool(BaseTool):

        # Remove "module." prefix if present (improved logic from example)
        if any([True if "module." in k else False for k in state_dict.keys()]):
-            state_dict = {
-                k.replace("module.", ""): v
-                for k, v in state_dict.items()
-                if k.startswith("module.")
-            }
+            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items() if k.startswith("module.")}

        # Load the model weights
        msg = self.model.load_state_dict(state_dict, strict=False)

@@ -333,14 +325,10 @@ class ArcPlusClassifierTool(BaseTool):

        # Map predictions to disease names
        if len(predictions) != len(self.disease_list):
-            print(
-                f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}"
-            )
+            print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
            # Pad or truncate as needed
            if len(predictions) < len(self.disease_list):
-                predictions = np.pad(
-                    predictions, (0, len(self.disease_list) - len(predictions))
-                )
+                predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
            else:
                predictions = predictions[: len(self.disease_list)]
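The checkpoint-loading hunk above keeps the same "module."-stripping logic, only collapsed onto one line. A hedged sketch of what that dict comprehension does to a DataParallel-style state dict (the sample keys below are invented for illustration):

```python
def strip_module_prefix(state_dict: dict) -> dict:
    # Mirrors the reformatted comprehension: if any key carries the DataParallel
    # "module." prefix, keep only those keys and drop the prefix before loading.
    if any("module." in k for k in state_dict):
        return {k.replace("module.", ""): v for k, v in state_dict.items() if k.startswith("module.")}
    return state_dict


print(strip_module_prefix({"module.head.weight": 0, "module.head.bias": 1}))
# -> {'head.weight': 0, 'head.bias': 1}
```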
medrax/tools/classification/torchxrayvision.py
CHANGED
@@ -16,9 +16,7 @@ from langchain_core.tools import BaseTool
class TorchXRayVisionInput(BaseModel):
    """Input for TorchXRayVision chest X-ray analysis tools. Only supports JPG or PNG images."""

-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")


class TorchXRayVisionClassifierTool(BaseTool):
medrax/tools/dicom.py
CHANGED
@@ -14,9 +14,7 @@ class DicomProcessorInput(BaseModel):
    """Input schema for the DICOM Processor Tool."""

    dicom_path: str = Field(..., description="Path to the DICOM file")
-    window_center: Optional[float] = Field(
-        None, description="Window center for contrast adjustment"
-    )
+    window_center: Optional[float] = Field(None, description="Window center for contrast adjustment")
    window_width: Optional[float] = Field(None, description="Window width for contrast adjustment")
medrax/tools/grounding.py
CHANGED
@@ -89,11 +89,8 @@ class XRayPhraseGroundingTool(BaseTool):
            trust_remote_code=True,
            quantization_config=quantization_config,
        )
-        self.processor = AutoProcessor.from_pretrained(
-            model_path, cache_dir=cache_dir, trust_remote_code=True
-        )
+        self.processor = AutoProcessor.from_pretrained(model_path, cache_dir=cache_dir, trust_remote_code=True)

-
        self.model = self.model.eval()

        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())

@@ -167,12 +164,8 @@ class XRayPhraseGroundingTool(BaseTool):
            )

            prompt_length = inputs["input_ids"].shape[-1]
-            decoded_text = self.processor.decode(
-                output[0][prompt_length:], skip_special_tokens=True
-            )
-            predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(
-                decoded_text
-            )
+            decoded_text = self.processor.decode(output[0][prompt_length:], skip_special_tokens=True)
+            predictions = self.processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)

            metadata = {
                "image_path": image_path,

@@ -199,9 +192,7 @@ class XRayPhraseGroundingTool(BaseTool):
            # Convert model bboxes to list format and get original image bboxes
            model_bboxes = [list(bbox) for bbox in pred_bboxes]
            original_bboxes = [
-                self.processor.adjust_box_for_original_image_size(
-                    bbox, width=image.size[0], height=image.size[1]
-                )
+                self.processor.adjust_box_for_original_image_size(bbox, width=image.size[0], height=image.size[1])
                for bbox in model_bboxes
            ]
medrax/tools/rag.py
CHANGED
@@ -14,7 +14,7 @@ class RAGTool(BaseTool):

    The knowledge base includes:
    - Medical textbooks and reference materials
-    - Research papers and clinical studies
+    - Research papers and clinical studies
    - Medical manuals and guidelines
    - Specialized medical literature
medrax/tools/report_generation.py
CHANGED
@@ -22,9 +22,7 @@ from transformers import (
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

-    image_path: str = Field(
-        ..., description="Path to the radiology image file, only supports JPG or PNG images"
-    )
+    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")


class ChestXRayReportGeneratorTool(BaseTool):

@@ -170,12 +168,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
        """
        try:
            # Process image for both models
-            findings_pixels = self._process_image(
-                image_path, self.findings_processor, self.findings_model
-            )
-            impression_pixels = self._process_image(
-                image_path, self.impression_processor, self.impression_model
-            )
+            findings_pixels = self._process_image(image_path, self.findings_processor, self.findings_model)
+            impression_pixels = self._process_image(image_path, self.impression_processor, self.impression_model)

            # Generate both sections
            with torch.inference_mode():

@@ -187,11 +181,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
            )

            # Combine into formatted report
-            report = (
-                "CHEST X-RAY REPORT\n\n"
-                f"FINDINGS:\n{findings_text}\n\n"
-                f"IMPRESSION:\n{impression_text}"
-            )
+            report = "CHEST X-RAY REPORT\n\n" f"FINDINGS:\n{findings_text}\n\n" f"IMPRESSION:\n{impression_text}"

            output = {
                "report": report,
medrax/tools/segmentation/__init__.py
CHANGED
@@ -3,10 +3,4 @@
 from .segmentation import ChestXRaySegmentationTool, ChestXRaySegmentationInput, OrganMetrics
 from .medsam2 import MedSAM2Tool, MedSAM2Input

-__all__ = [
-    "ChestXRaySegmentationTool",
-    "ChestXRaySegmentationInput",
-    "OrganMetrics",
-    "MedSAM2Tool",
-    "MedSAM2Input"
-]
+__all__ = ["ChestXRaySegmentationTool", "ChestXRaySegmentationInput", "OrganMetrics", "MedSAM2Tool", "MedSAM2Input"]
medrax/tools/segmentation/medsam2.py
CHANGED
@@ -26,7 +26,6 @@ from hydra import initialize_config_dir
from hydra.core.global_hydra import GlobalHydra


-
class MedSAM2Input(BaseModel):
    """Input schema for the MedSAM2 Tool."""

@@ -47,7 +46,7 @@

class MedSAM2Tool(BaseTool):
    """Advanced medical image segmentation tool using MedSAM2.
-
+
    This tool provides state-of-the-art medical image segmentation capabilities using
    the MedSAM2 model, which is specifically adapted for medical imaging from Meta's SAM2.
    Supports interactive prompting with boxes, points, or automatic segmentation.

@@ -92,18 +91,15 @@ class MedSAM2Tool(BaseTool):
        # This works around the issue with initialize_config_module in sam2
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()

        config_dir = Path(__file__).parent.parent.parent.parent / "MedSAM2" / "sam2" / "configs"
        initialize_config_dir(config_dir=str(config_dir), version_base="1.2")

        hf_hub_download(
-            repo_id=model_path,
-            filename=model_file,
-            local_dir=self.cache_dir,
-            local_dir_use_symlinks=False
+            repo_id=model_path, filename=model_file, local_dir=self.cache_dir, local_dir_use_symlinks=False
        )

-        config_path = model_cfg.replace(
+        config_path = model_cfg.replace(".yaml", "")
        sam2_model = build_sam2(config_path, str(self.cache_dir / model_file), device=device)
        self.predictor = SAM2ImagePredictor(sam2_model)

@@ -114,37 +110,37 @@ class MedSAM2Tool(BaseTool):
        """Load and preprocess image for medical analysis."""
        try:
            # Handle different image formats
-            if image_path.lower().endswith(
+            if image_path.lower().endswith(".dcm"):
                # DICOM files - would need DICOM processor
                raise ValueError("DICOM files not directly supported. Please convert to standard image format first.")

            # Load standard image formats
            image = Image.open(image_path)

            # For medical images, convert to grayscale first if needed, then to RGB
-            if image.mode ==
+            if image.mode == "L":  # Grayscale
                # Convert grayscale to RGB for SAM2
-                image = image.convert(
-            elif image.mode !=
-                if image.mode ==
+                image = image.convert("RGB")
+            elif image.mode != "RGB":
+                if image.mode == "RGBA":
                    # Create white background for RGBA
-                    background = Image.new(
+                    background = Image.new("RGB", image.size, (255, 255, 255))
                    background.paste(image, mask=image.split()[-1])
                    image = background
                else:
-                    image = image.convert(
+                    image = image.convert("RGB")

            # Convert to numpy array
            image_np = np.array(image)

            # Ensure image is in proper range [0, 255]
            if image_np.max() <= 1.0:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)

            return image_np

        except Exception as e:
            raise ValueError(f"Failed to load image {image_path}: {str(e)}")

@@ -152,55 +148,53 @@ class MedSAM2Tool(BaseTool):
        """Process and validate prompts."""
        if prompt_type == "auto":
            return None, None, None

        if prompt_coords is None:
            if prompt_type != "auto":
                raise ValueError(f"Prompt coordinates required for prompt type '{prompt_type}'")
            return None, None, None

        if prompt_type == "box":
            if len(prompt_coords) != 4:
                raise ValueError("Box prompt requires 4 coordinates: [x1,y1,x2,y2]")

            x1, y1, x2, y2 = prompt_coords
            # Validate coordinates
            if x1 >= x2 or y1 >= y2:
                raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")

            input_box = np.array([[x1, y1, x2, y2]])
            return input_box, None, None

        elif prompt_type == "point":
            if len(prompt_coords) != 2:
                raise ValueError("Point prompt requires 2 coordinates: [x,y]")

            x, y = prompt_coords
            input_point = np.array([[x, y]])
            input_label = np.array([1])  # Positive point
            return None, input_point, input_label

        else:
            raise ValueError(f"Unknown prompt type: {prompt_type}")

    def _create_visualization(self, image: np.ndarray, masks: np.ndarray, prompt_info: Dict) -> str:
        """Create visualization of segmentation results."""
        plt.figure(figsize=(10, 10))

        # Convert RGB image to grayscale for background display
        if len(image.shape) == 3:
            # Convert RGB to grayscale using standard luminance formula
-            gray_image = 0.299 * image[
+            gray_image = 0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] + 0.114 * image[:, :, 2]
        else:
            gray_image = image

        # Display grayscale background
-        plt.imshow(
-        )
+        plt.imshow(gray_image, cmap="gray", extent=[0, image.shape[1], image.shape[0], 0])

        # Generate color palette for multiple masks
        colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))

        # Process and overlay each mask
        for idx, (mask, color) in enumerate(zip(masks, colors)):
            if mask.sum() > 0:

@@ -208,33 +202,31 @@ class MedSAM2Tool(BaseTool):
                mask_bool = mask.astype(bool)
                colored_mask = np.zeros((*mask_bool.shape, 4))
                colored_mask[mask_bool] = (*color[:3], 0.3)  # 30% transparency like segmentation tool
-                plt.imshow(
-                )
+                plt.imshow(colored_mask, extent=[0, image.shape[1], image.shape[0], 0])

            # Add legend entry for each mask
            mask_label = f"Mask {idx + 1} (score: {prompt_info.get('scores', [0])[idx] if idx < len(prompt_info.get('scores', [])) else 0:.3f})"
            plt.plot([], [], color=color, label=mask_label, linewidth=3)

        # Add prompt visualization with consistent styling
-        if prompt_info.get(
-            box = prompt_info[
+        if prompt_info.get("box") is not None:
+            box = prompt_info["box"][0]
            x1, y1, x2, y2 = box
-            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1],
+            plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], "g-", linewidth=2, label="Box Prompt")

-        if prompt_info.get(
-            point = prompt_info[
-            plt.plot(point[0], point[1],
+        if prompt_info.get("point") is not None:
+            point = prompt_info["point"][0]
+            plt.plot(point[0], point[1], "go", markersize=10, label="Point Prompt")

        plt.title("Segmentation Overlay")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
        plt.axis("off")

        # Save visualization with higher DPI like segmentation tool
        viz_path = self.temp_dir / f"medsam2_result_{uuid.uuid4().hex[:8]}.png"
-        plt.savefig(viz_path, bbox_inches=
+        plt.savefig(viz_path, bbox_inches="tight", dpi=300)
        plt.close()

        return str(viz_path)

    def _run(

@@ -249,28 +241,28 @@ class MedSAM2Tool(BaseTool):
        try:
            # Load image
            image = self._load_image(image_path)

            # Set image for predictor
            self.predictor.set_image(image)

            # Process prompts
-            input_box, input_point, input_label = self._process_prompts(
-            )
+            input_box, input_point, input_label = self._process_prompts(prompt_type, prompt_coords, image.shape[:2])

            # Run inference
            if prompt_type == "auto":
                # For auto segmentation, try multiple approaches and select best result
                h, w = image.shape[:2]

                # Try multiple points in key areas for medical images
-                sample_points = np.array(
-                    [
+                sample_points = np.array(
+                    [
+                        [w // 3, h // 3],  # Upper left lung area
+                        [2 * w // 3, h // 3],  # Upper right lung area
+                        [w // 2, 2 * h // 3],  # Lower center area
+                    ]
+                )
                sample_labels = np.array([1, 1, 1])  # All positive points

                masks, scores, logits = self.predictor.predict(
                    point_coords=sample_points,
                    point_labels=sample_labels,

@@ -283,29 +275,29 @@ class MedSAM2Tool(BaseTool):
                    box=input_box,
                    multimask_output=True,
                )

            # Create visualization
            prompt_info = {
+                "box": input_box,
+                "point": input_point,
+                "type": prompt_type,
+                "scores": scores,  # Add scores for legend display
            }
            viz_path = self._create_visualization(image, masks, prompt_info)

            # Create output dictionary (main results)
            output = {
                "segmentation_image_path": viz_path,
-                "confidence_scores": scores.tolist() if hasattr(scores,
+                "confidence_scores": scores.tolist() if hasattr(scores, "tolist") else list(scores),
                "num_masks": len(masks),
                "best_mask_score": float(scores[0]) if len(scores) > 0 else 0.0,
                "mask_summary": {
                    "total_masks": len(masks),
                    "mask_shapes": [list(mask.shape) for mask in masks],
-                    "segmented_area_pixels": [int(mask.sum()) for mask in masks]
+                    "segmented_area_pixels": [int(mask.sum()) for mask in masks],
                },
            }

            # Create metadata dictionary
            metadata = {
                "image_path": image_path,

@@ -317,9 +309,9 @@ class MedSAM2Tool(BaseTool):
                "num_masks_generated": len(masks),
                "analysis_status": "completed",
            }

            return output, metadata

        except Exception as e:
            error_output = {"error": str(e)}
            error_metadata = {

@@ -338,4 +330,4 @@ class MedSAM2Tool(BaseTool):
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run."""
-        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
+        return self._run(image_path, prompt_type, prompt_coords, slice_index, run_manager)
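The `_process_prompts` hunks above only gain quoting and blank-line fixes; the prompt contract itself is unchanged. A small standalone illustration of that contract as we read it from the diff (this helper is ours, not the tool's API):

```python
import numpy as np


def build_prompt(prompt_type: str, coords):
    # Box prompts need [x1, y1, x2, y2] with x1 < x2 and y1 < y2; point prompts need [x, y].
    if prompt_type == "auto":
        return None, None, None
    if prompt_type == "box":
        x1, y1, x2, y2 = coords
        if x1 >= x2 or y1 >= y2:
            raise ValueError("Invalid box coordinates: x1 < x2 and y1 < y2 required")
        return np.array([[x1, y1, x2, y2]]), None, None
    if prompt_type == "point":
        x, y = coords
        return None, np.array([[x, y]]), np.array([1])  # label 1 = positive point
    raise ValueError(f"Unknown prompt type: {prompt_type}")


box, _, _ = build_prompt("box", [120, 80, 360, 400])
```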
medrax/tools/segmentation/segmentation.py
CHANGED
@@ -41,9 +41,7 @@ class OrganMetrics(BaseModel):
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
-    bbox: Tuple[int, int, int, int] = Field(
-        ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
-    )
+    bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")

    # Size metrics
    width: int = Field(..., description="Width of the organ in pixels")

@@ -51,9 +49,7 @@ class OrganMetrics(BaseModel):
    aspect_ratio: float = Field(..., description="Height/width ratio")

    # Position metrics
-    relative_position: Dict[str, float] = Field(
-        ..., description="Position relative to image boundaries (0-1 scale)"
-    )
+    relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")

    # Analysis metrics
    mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")

@@ -90,9 +86,7 @@ class ChestXRaySegmentationTool(BaseTool):
        self.model = self.model.to(self.device)
        self.model.eval()

-        self.transform = torchvision.transforms.Compose(
-            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
-        )
+        self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)])

        self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
        self.temp_dir.mkdir(exist_ok=True)

@@ -115,9 +109,7 @@ class ChestXRaySegmentationTool(BaseTool):
            "Spine": 13,
        }

-    def _align_mask_to_original(
-        self, mask: np.ndarray, original_shape: Tuple[int, int]
-    ) -> np.ndarray:
+    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """
        Align a mask from the transformed (cropped/resized) space back to the full original image.
        Assumes that the transform does a center crop to a square of side = min(original height, width)

@@ -170,23 +162,17 @@ class ChestXRaySegmentationTool(BaseTool):
            bbox=tuple(map(int, props.bbox)),
            width=int(props.bbox[3] - props.bbox[1]),
            height=int(props.bbox[2] - props.bbox[0]),
-            aspect_ratio=float(
-                (props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])
-            ),
+            aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

-    def _save_visualization(
-        self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]
-    ) -> str:
+    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
        """Save visualization of original image with segmentation masks overlaid."""
        plt.figure(figsize=(10, 10))
-        plt.imshow(
-            original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0]
-        )
+        plt.imshow(original_img, cmap="gray", extent=[0, original_img.shape[1], original_img.shape[0], 0])

        # Generate color palette for organs
        colors = plt.cm.rainbow(np.linspace(0, 1, len(organ_indices)))

@@ -202,14 +188,10 @@ class ChestXRaySegmentationTool(BaseTool):
            # Create a colored overlay with transparency
            colored_mask = np.zeros((*original_img.shape, 4))
            colored_mask[mask > 0] = (*color[:3], 0.3)
-            plt.imshow(
-                colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0]
-            )
+            plt.imshow(colored_mask, extent=[0, original_img.shape[1], original_img.shape[0], 0])

            # Add legend entry for the organ
-            organ_name = list(self.organ_map.keys())[
-                list(self.organ_map.values()).index(organ_idx)
-            ]
+            organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
            plt.plot([], [], color=color, label=organ_name, linewidth=3)

        plt.title("Segmentation Overlay")

@@ -266,9 +248,7 @@ class ChestXRaySegmentationTool(BaseTool):
        for idx, organ_name in zip(organ_indices, organs):
            mask = pred_masks[0, idx].cpu().numpy()
            if mask.sum() > 0:
-                metrics = self._compute_organ_metrics(
-                    mask, original_img, float(pred_probs[0, idx].mean().cpu())
-                )
+                metrics = self._compute_organ_metrics(mask, original_img, float(pred_probs[0, idx].mean().cpu()))
                if metrics:
                    results[organ_name] = metrics
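The legend hunk keeps the index-to-name lookup written as `list(organ_map.keys())[list(organ_map.values()).index(organ_idx)]`. An equivalent, arguably clearer alternative is a reverse dictionary; the sketch below uses an invented subset of the organ map (only `"Spine": 13` is taken from the diff):

```python
organ_map = {"Left Lung": 4, "Right Lung": 5, "Heart": 8, "Spine": 13}  # illustrative values

# Equivalent to: list(organ_map.keys())[list(organ_map.values()).index(organ_idx)]
index_to_organ = {v: k for k, v in organ_map.items()}
assert index_to_organ[13] == "Spine"
```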
medrax/tools/utils.py
CHANGED
@@ -16,18 +16,10 @@ class ImageVisualizerInput(BaseModel):

    image_path: str = Field(..., description="Path to the image file to display, only supports JPG or PNG images")
    title: Optional[str] = Field(None, description="Optional title to display above the image")
-    description: Optional[str] = Field(
-        None, description="Optional description to display below the image"
-    )
-    width: Optional[int] = Field(
-        10, description="Optional figure width in inches"
-    )
-    height: Optional[int] = Field(
-        10, description="Optional figure height in inches"
-    )
-    cmap: Optional[str] = Field(
-        "rgb", description="Optional colormap to use for displaying the image"
-    )
+    description: Optional[str] = Field(None, description="Optional description to display below the image")
+    width: Optional[int] = Field(10, description="Optional figure width in inches")
+    height: Optional[int] = Field(10, description="Optional figure height in inches")
+    cmap: Optional[str] = Field("rgb", description="Optional colormap to use for displaying the image")


class ImageVisualizerTool(BaseTool):

@@ -65,9 +57,7 @@ class ImageVisualizerTool(BaseTool):

        # Add description if provided
        if description:
-            plt.figtext(
-                0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10
-            )
+            plt.figtext(0.5, 0.01, description, wrap=True, horizontalalignment="center", fontsize=10)

        # Adjust margins to minimize whitespace while preventing overlap
        plt.subplots_adjust(top=0.95, bottom=0.05, left=0.05, right=0.95)
medrax/tools/vqa/__init__.py
CHANGED
@@ -1,16 +1,16 @@
 """Visual Question Answering tools for medical images."""

 from .llava_med import LlavaMedTool, LlavaMedInput
-from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
+from .xray_vqa import CheXagentXRayVQATool, XRayVQAToolInput
 from .medgemma.medgemma_client import MedGemmaAPIClientTool, MedGemmaVQAInput
 from .medgemma.medgemma_setup import setup_medgemma_env

 __all__ = [
     "LlavaMedTool",
     "LlavaMedInput",
-    "CheXagentXRayVQATool",
+    "CheXagentXRayVQATool",
     "XRayVQAToolInput",
     "MedGemmaAPIClientTool",
     "MedGemmaVQAInput",
-    "setup_medgemma_env"
-]
+    "setup_medgemma_env",
+]
medrax/tools/vqa/llava_med.py
CHANGED
@@ -83,13 +83,7 @@ class LlavaMedTool(BaseTool):
        self, question: str, image_path: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.model.config.mm_use_im_start_end:
-            question = (
-                DEFAULT_IM_START_TOKEN
-                + DEFAULT_IMAGE_TOKEN
-                + DEFAULT_IM_END_TOKEN
-                + "\n"
-                + question
-            )
+            question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + question
        else:
            question = DEFAULT_IMAGE_TOKEN + "\n" + question

@@ -99,9 +93,7 @@ class LlavaMedTool(BaseTool):
        prompt = conv.get_prompt()

        input_ids = (
-            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-            .unsqueeze(0)
-            .cuda()
+            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
        )

        image_tensor = None

@@ -147,11 +139,11 @@ class LlavaMedTool(BaseTool):
        )

        answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
-
+
        output = {
            "answer": answer,
        }
-
+
        metadata = {
            "question": question,
            "image_path": image_path,
medrax/tools/vqa/xray_vqa.py
CHANGED
@@ -15,13 +15,9 @@ from langchain_core.tools import BaseTool
class XRayVQAToolInput(BaseModel):
    """Input schema for the CheXagent Tool."""

-    image_paths: List[str] = Field(
-        ..., description="List of paths to chest X-ray images to analyze"
-    )
+    image_paths: List[str] = Field(..., description="List of paths to chest X-ray images to analyze")
    prompt: str = Field(..., description="Question or instruction about the chest X-ray images")
-    max_new_tokens: int = Field(
-        512, description="Maximum number of tokens to generate in the response"
-    )
+    max_new_tokens: int = Field(512, description="Maximum number of tokens to generate in the response")


class CheXagentXRayVQATool(BaseTool):

@@ -99,16 +95,14 @@ class CheXagentXRayVQATool(BaseTool):
        Returns:
            str: Model's response
        """
-        query = self.tokenizer.from_list_format(
-            [*[{"image": path} for path in image_paths], {"text": prompt}]
-        )
+        query = self.tokenizer.from_list_format([*[{"image": path} for path in image_paths], {"text": prompt}])
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
-        input_ids = self.tokenizer.apply_chat_template(
-            conv, add_generation_prompt=True, return_tensors="pt"
-        )
+        input_ids = self.tokenizer.apply_chat_template(conv, add_generation_prompt=True, return_tensors="pt").to(
+            device=self.device
+        )

        # Run inference
        with torch.inference_mode():
medrax/tools/xray_generation.py
CHANGED
@@ -11,26 +11,15 @@ from langchain_core.tools import BaseTool

class ChestXRayGeneratorInput(BaseModel):
    """Input schema for the Chest X-Ray Generator Tool."""
-
+
    prompt: str = Field(
-        ...,
-        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
-    )
-    height: int = Field(
-        512,
-        description="Height of generated image in pixels"
-    )
-    width: int = Field(
-        512,
-        description="Width of generated image in pixels"
-    )
-    num_inference_steps: int = Field(
-        75,
-        description="Number of denoising steps (higher = better quality but slower)"
+        ..., description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
    )
+    height: int = Field(512, description="Height of generated image in pixels")
+    width: int = Field(512, description="Width of generated image in pixels")
+    num_inference_steps: int = Field(75, description="Number of denoising steps (higher = better quality but slower)")
    guidance_scale: float = Field(
-        4.0,
-        description="How closely to follow the prompt (higher = more faithful but less diverse)"
+        4.0, description="How closely to follow the prompt (higher = more faithful but less diverse)"
    )

@@ -60,11 +49,11 @@ class ChestXRayGeneratorTool(BaseTool):
    ):
        """Initialize the chest X-ray generator tool."""
        super().__init__()
-
+
        self.device = torch.device(device) if device else "cuda"
        self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        self.model = self.model.to(torch.float32).to(self.device)
-
+
        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
        self.temp_dir.mkdir(exist_ok=True)

@@ -97,7 +86,7 @@ class ChestXRayGeneratorTool(BaseTool):
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
-                guidance_scale=guidance_scale
+                guidance_scale=guidance_scale,
            )

            # Save generated image

@@ -107,7 +96,7 @@ class ChestXRayGeneratorTool(BaseTool):
            output = {
                "image_path": str(image_path),
            }
-
+
            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,

@@ -126,7 +115,7 @@ class ChestXRayGeneratorTool(BaseTool):
                "prompt": prompt,
                "analysis_status": "failed",
                "error_details": str(e),
-            }
+            },
        )

    async def _arun(

@@ -139,4 +128,4 @@ class ChestXRayGeneratorTool(BaseTool):
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run."""
-        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
+        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
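For completeness, a hedged sketch of driving the reformatted generator directly rather than through the agent. It assumes the constructor takes the `model_path`, `cache_dir`, and `device` arguments used in the `__init__` hunk above; the positional `_run` arguments follow the order forwarded by `_arun`, and the paths are placeholders:

```python
# Minimal usage sketch (not the repository's test code).
tool = ChestXRayGeneratorTool(
    model_path="<diffusion-model-id>",  # placeholder, not a real model id
    cache_dir="<cache-dir>",            # placeholder
    device="cuda",
)

# prompt, num_inference_steps, guidance_scale, height, width
output, metadata = tool._run("big left-sided pleural effusion", 75, 4.0, 512, 512)
print(output["image_path"])
```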