VictorLJZ committed
Commit f60c51c · 1 Parent(s): e040fb2

added openrouter provider

benchmarking/cli.py CHANGED
@@ -22,6 +22,7 @@ def create_llm_provider(model_name: str, provider_type: str, **kwargs) -> LLMProvider:
     provider_map = {
         "openai": OpenAIProvider,
         "google": GoogleProvider,
+        "openrouter": OpenRouterProvider,
         "medrax": MedRAXProvider,
     }
 
@@ -112,14 +113,14 @@ def main():
     # Run benchmark command
     run_parser = subparsers.add_parser("run", help="Run a benchmark")
     run_parser.add_argument("--model", required=True, help="Model name (e.g., gpt-4o, gemini-2.5-pro)")
-    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "medrax"], help="LLM provider")
+    run_parser.add_argument("--provider", required=True, choices=["openai", "google", "openrouter", "medrax"], help="LLM provider")
     run_parser.add_argument("--benchmark", required=True, choices=["rexvqa", "chestagentbench"], help="Benchmark to run")
     run_parser.add_argument("--data-dir", required=True, help="Directory containing benchmark data")
     run_parser.add_argument("--output-dir", default="benchmark_results", help="Output directory for results")
     run_parser.add_argument("--max-questions", type=int, help="Maximum number of questions to process")
     run_parser.add_argument("--temperature", type=float, default=0.7, help="Model temperature")
     run_parser.add_argument("--top-p", type=float, default=0.95, help="Top-p value")
-    run_parser.add_argument("--max-tokens", type=int, default=5000, help="Maximum tokens per response")
+    run_parser.add_argument("--max-tokens", type=int, default=1000, help="Maximum tokens per response")
 
     run_parser.set_defaults(func=run_benchmark_command)
 
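With this change, the run subcommand accepts openrouter as a --provider choice and create_llm_provider maps it to the new OpenRouterProvider. As a rough illustration (not part of the commit), the factory could be called directly like this; the import path and the model slug are assumptions:

# Hypothetical direct call to the factory updated above (sketch only)
from benchmarking.cli import create_llm_provider  # assumed import path

provider = create_llm_provider(
    model_name="x-ai/grok-4",    # placeholder OpenRouter model slug
    provider_type="openrouter",  # choice added in this commit
)
print(type(provider).__name__)   # expected: OpenRouterProvider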
benchmarking/llm_providers/__init__.py CHANGED
@@ -4,6 +4,7 @@ from .base import LLMProvider, LLMRequest, LLMResponse
 from .openai_provider import OpenAIProvider
 from .google_provider import GoogleProvider
 from .medrax_provider import MedRAXProvider
+from .openrouter_provider import OpenRouterProvider
 
 __all__ = [
     "LLMProvider",
@@ -12,4 +13,5 @@ __all__ = [
     "OpenAIProvider",
     "GoogleProvider",
    "MedRAXProvider",
+    "OpenRouterProvider",
 ]
benchmarking/llm_providers/openrouter_provider.py ADDED
@@ -0,0 +1,90 @@
+"""OpenRouter LLM provider implementation using the OpenAI SDK."""
+
+import os
+import time
+from tenacity import retry, wait_exponential, stop_after_attempt
+import base64
+from openai import OpenAI
+
+from .base import LLMProvider, LLMRequest, LLMResponse
+
+
+class OpenRouterProvider(LLMProvider):
+    """LLM provider using the OpenRouter API via the OpenAI SDK."""
+
+    def _setup(self) -> None:
+        """Set up the OpenRouter client."""
+        api_key = os.getenv("OPENROUTER_API_KEY")
+        if not api_key:
+            raise ValueError("OPENROUTER_API_KEY environment variable is required for the OpenRouter provider.")
+        base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
+        # Use the OpenAI SDK against the OpenRouter endpoint
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
+    def generate_response(self, request: LLMRequest) -> LLMResponse:
+        """Generate a response from an OpenRouter-hosted model via the OpenAI SDK.
+
+        Args:
+            request (LLMRequest): The request containing text, images, and parameters
+        Returns:
+            LLMResponse: The response from the OpenRouter model
+        """
+        start_time = time.time()
+
+        # Build messages
+        messages = []
+        if self.system_prompt:
+            messages.append({"role": "system", "content": self.system_prompt})
+
+        user_content = []
+        user_content.append({"type": "text", "text": request.text})
+
+        # Add images if provided
+        if request.images:
+            valid_images = self._validate_image_paths(request.images)
+            for image_path in valid_images:
+                try:
+                    image_b64 = self._encode_image(image_path)
+                    user_content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{image_b64}",
+                            "detail": "high"
+                        }
+                    })
+                except Exception as e:
+                    print(f"Error reading image {image_path}: {e}")
+
+        messages.append({"role": "user", "content": user_content})
+
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=messages,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                max_tokens=request.max_tokens,
+                **(request.additional_params or {})
+            )
+            duration = time.time() - start_time
+            content = response.choices[0].message.content if response.choices else ""
+            usage = {}
+            if hasattr(response, 'usage') and response.usage:
+                usage = {
+                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+                    "completion_tokens": getattr(response.usage, "completion_tokens", 0),
+                    "total_tokens": getattr(response.usage, "total_tokens", 0)
+                }
+            return LLMResponse(
+                content=content,
+                usage=usage,
+                duration=duration,
+                raw_response=response
+            )
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error: {str(e)}",
+                duration=time.time() - start_time,
+                raw_response=None
+            )
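For reference, a rough sketch of exercising the new provider directly. The base LLMProvider constructor and the exact LLMRequest signature are not shown in this diff, so the constructor arguments, model slug, and image path below are assumptions; only the field names used inside generate_response above come from the code:

import os
from benchmarking.llm_providers import OpenRouterProvider
from benchmarking.llm_providers.base import LLMRequest

os.environ.setdefault("OPENROUTER_API_KEY", "sk-or-...")  # placeholder key

# Assumed constructor, based on self.model_name / self.system_prompt used above
provider = OpenRouterProvider(model_name="openai/gpt-4o", system_prompt="You are a radiology assistant.")

request = LLMRequest(
    text="Is there evidence of cardiomegaly on this chest X-ray?",
    images=["/path/to/example_cxr.jpg"],  # hypothetical local image path
    temperature=0.7,
    top_p=0.95,
    max_tokens=1000,
)

response = provider.generate_response(request)
print(response.content)
print(response.usage, f"{response.duration:.1f}s")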
medrax/models/model_factory.py CHANGED
@@ -28,7 +28,11 @@ class ModelFactory:
             "env_key": "OPENAI_API_KEY",
             "base_url_key": "OPENAI_BASE_URL",
         },
-        "gemini": {"class": ChatGoogleGenerativeAI, "env_key": "GOOGLE_API_KEY"},
+        "gemini": {
+            "class": ChatGoogleGenerativeAI,
+            "env_key": "GOOGLE_API_KEY",
+            "base_url_key": "GOOGLE_BASE_URL",
+        },
         "openrouter": {
             "class": ChatOpenAI,  # OpenRouter uses OpenAI-compatible interface
             "env_key": "OPENROUTER_API_KEY",