amarzullo24 commited on
Commit
a122f1b
·
1 Parent(s): a7f6540

enabling local llm

Browse files
Files changed (3) hide show
  1. README.md +6 -0
  2. main.py +13 -1
  3. quickstart.py +8 -1
README.md CHANGED
@@ -214,6 +214,12 @@ ChestXRayGeneratorTool(
214
  - Some tools (LLaVA-Med, Grounding) are more resource-intensive
215
  <br>
216
 
 
 
 
 
 
 
217
 
218
  ## Star History
219
  <div align="center">
 
214
  - Some tools (LLaVA-Med, Grounding) are more resource-intensive
215
  <br>
216
 
217
+ ### Local LLMs
218
+ If you are running a local LLM using frameworks like [Ollama](https://ollama.com/) or [LM Studio](https://lmstudio.ai/), you need to configure your environment variables accordingly. For example:
219
+ ```
220
+ export OPENAI_BASE_URL="http://localhost:11434/v1"
221
+ export OPENAI_API_KEY="ollama"
222
+ ```
223
 
224
  ## Star History
225
  <div align="center">
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import warnings
2
  from typing import *
3
  from dotenv import load_dotenv
@@ -27,6 +28,7 @@ def initialize_agent(
27
  model="chatgpt-4o-latest",
28
  temperature=0.7,
29
  top_p=0.95,
 
30
  ):
31
  """Initialize the MedRAX agent with specified tools and configuration.
32
 
@@ -39,6 +41,7 @@ def initialize_agent(
39
  model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
40
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
41
  top_p (float, optional): Top P for the model. Defaults to 0.95.
 
42
 
43
  Returns:
44
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
@@ -72,7 +75,7 @@ def initialize_agent(
72
  tools_dict[tool_name] = all_tools[tool_name]()
73
 
74
  checkpointer = MemorySaver()
75
- model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p)
76
  agent = Agent(
77
  model,
78
  tools=list(tools_dict.values()),
@@ -107,6 +110,14 @@ if __name__ == "__main__":
107
  # "ChestXRayGeneratorTool",
108
  ]
109
 
 
 
 
 
 
 
 
 
110
  agent, tools_dict = initialize_agent(
111
  "medrax/docs/system_prompts.txt",
112
  tools_to_use=selected_tools,
@@ -116,6 +127,7 @@ if __name__ == "__main__":
116
  model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini
117
  temperature=0.7,
118
  top_p=0.95,
 
119
  )
120
  demo = create_demo(agent, tools_dict)
121
 
 
1
+ import os
2
  import warnings
3
  from typing import *
4
  from dotenv import load_dotenv
 
28
  model="chatgpt-4o-latest",
29
  temperature=0.7,
30
  top_p=0.95,
31
+ openai_kwargs={}
32
  ):
33
  """Initialize the MedRAX agent with specified tools and configuration.
34
 
 
41
  model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
42
  temperature (float, optional): Temperature for the model. Defaults to 0.7.
43
  top_p (float, optional): Top P for the model. Defaults to 0.95.
44
+ openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
45
 
46
  Returns:
47
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
 
75
  tools_dict[tool_name] = all_tools[tool_name]()
76
 
77
  checkpointer = MemorySaver()
78
+ model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
79
  agent = Agent(
80
  model,
81
  tools=list(tools_dict.values()),
 
110
  # "ChestXRayGeneratorTool",
111
  ]
112
 
113
+ # Collect the ENV variables
114
+ openai_kwargs = {}
115
+ if api_key := os.getenv("OPENAI_API_KEY"):
116
+ openai_kwargs["api_key"] = api_key
117
+
118
+ if base_url := os.getenv("OPENAI_BASE_URL"):
119
+ openai_kwargs["base_url"] = base_url
120
+
121
  agent, tools_dict = initialize_agent(
122
  "medrax/docs/system_prompts.txt",
123
  tools_to_use=selected_tools,
 
127
  model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini
128
  temperature=0.7,
129
  top_p=0.95,
130
+ openai_kwargs=openai_kwargs
131
  )
132
  demo = create_demo(agent, tools_dict)
133
 
quickstart.py CHANGED
@@ -230,10 +230,17 @@ def main():
230
  dataset = load_dataset("json", data_files="chestagentbench/metadata.jsonl")
231
  train_dataset = dataset["train"]
232
 
 
233
  api_key = os.getenv("OPENAI_API_KEY")
234
  if not api_key:
235
  raise ValueError("OPENAI_API_KEY environment variable is not set.")
236
- client = openai.OpenAI(api_key=api_key)
 
 
 
 
 
 
237
 
238
  total_examples = len(train_dataset)
239
  processed = 0
 
230
  dataset = load_dataset("json", data_files="chestagentbench/metadata.jsonl")
231
  train_dataset = dataset["train"]
232
 
233
+ # Collecting ENV variables
234
  api_key = os.getenv("OPENAI_API_KEY")
235
  if not api_key:
236
  raise ValueError("OPENAI_API_KEY environment variable is not set.")
237
+
238
+ kwargs = {}
239
+ if base_url := os.getenv("OPENAI_BASE_URL"):
240
+ kwargs["base_url"] = base_url
241
+
242
+ # Initialize the OpenAI Client
243
+ client = openai.OpenAI(api_key=api_key, **kwargs)
244
 
245
  total_examples = len(train_dataset)
246
  processed = 0