# agent_course/src/llm/openai_llm.py
from typing import Any
from collections.abc import Generator
from smolagents import (
OpenAIModel,
ChatMessage,
ChatMessageStreamDelta,
Tool,
TokenUsage
)
from smolagents.models import (
ChatMessageToolCallStreamDelta,
ChatMessageStreamDelta,
remove_content_after_stop_sequences
)
import openai
class SwitchableOpenAIModel(OpenAIModel):
    """OpenAI-compatible model that falls back through a list of model ids.

    Behaves like `OpenAIModel`, except that when the current model raises
    `openai.RateLimitError` the request is transparently retried with the
    next model in `model_list`. Once the list is exhausted, the rate-limit
    error propagates to the caller. The switch is sticky: subsequent calls
    keep using the model that last succeeded.

    Parameters:
        model_list (`list[str]`):
            Ordered model identifiers to try on the server
            (e.g. `["gpt-5", "gpt-4o-mini"]`). The first entry is used
            until it is rate-limited.
        api_base (`str`, *optional*):
            The base URL of the OpenAI-compatible API server.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        organization (`str`, *optional*):
            The organization to use for the API request.
        project (`str`, *optional*):
            The project to use for the API request.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the OpenAI client
            (like organization, project, max_retries etc.).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles in others.
            Useful for specific models that do not support specific message
            roles like "system".
        flatten_messages_as_text (`bool`, default `False`):
            Whether to flatten messages as text.
        **kwargs:
            Additional keyword arguments to forward to the underlying OpenAI
            API completion call, for instance `temperature`.
    """

    def __init__(
        self,
        model_list: list[str],
        api_base: str | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = False,
        **kwargs,
    ):
        self.model_list = model_list
        # Index of the model currently in use; advanced by
        # _switch_to_next_model() whenever a rate limit is hit.
        self.model_index = 0
        super().__init__(
            model_id=self.model_list[self.model_index],
            api_base=api_base,
            api_key=api_key,
            organization=organization,
            project=project,
            client_kwargs=client_kwargs,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )

    def _switch_to_next_model(self) -> bool:
        """Advance to the next model in `model_list`.

        Returns:
            `True` if a fallback model was selected, `False` when the list
            is already exhausted (caller should re-raise the error).
        """
        if self.model_index >= len(self.model_list) - 1:
            return False
        self.model_index += 1
        print(
            f"Switching to model {self.model_list[self.model_index]}")
        return True

    def generate_stream(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        """Stream deltas from the current model, falling back on rate limits.

        Yields:
            `ChatMessageStreamDelta` objects carrying content chunks,
            tool-call fragments, and (in a final usage event) token counts.

        Raises:
            openai.RateLimitError: when every model in the list is rate-limited.
            ValueError: when an event carries neither content, tool calls,
                nor a finish reason.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            for event in self.client.chat.completions.create(
                **completion_kwargs,
                stream=True,
                # Ask the server to append a final usage event to the stream.
                stream_options={"include_usage": True},
            ):
                if event.usage:
                    yield ChatMessageStreamDelta(
                        content="",
                        token_usage=TokenUsage(
                            input_tokens=event.usage.prompt_tokens,
                            output_tokens=event.usage.completion_tokens,
                        ),
                    )
                if event.choices:
                    choice = event.choices[0]
                    if choice.delta:
                        yield ChatMessageStreamDelta(
                            content=choice.delta.content,
                            tool_calls=[
                                ChatMessageToolCallStreamDelta(
                                    index=delta.index,
                                    id=delta.id,
                                    type=delta.type,
                                    function=delta.function,
                                )
                                for delta in choice.delta.tool_calls
                            ]
                            if choice.delta.tool_calls
                            else None,
                        )
                    elif not getattr(choice, "finish_reason", None):
                        raise ValueError(
                            f"No content or tool calls in event: {event}")
        except openai.RateLimitError:
            if not self._switch_to_next_model():
                raise
            # BUG FIX: the original used `return self.generate_stream(...)`.
            # Inside a generator, `return x` just ends iteration (stashing x
            # in StopIteration.value), so the caller saw NO deltas after a
            # model switch. `yield from` re-emits the fallback stream.
            yield from self.generate_stream(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )

    def generate(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        """Run one non-streaming completion, falling back on rate limits.

        Returns:
            A `ChatMessage` with the assistant reply, tool calls, the raw
            API response, and token usage.

        Raises:
            openai.RateLimitError: when every model in the list is rate-limited.
        """
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            response = self.client.chat.completions.create(**completion_kwargs)
        except openai.RateLimitError:
            if not self._switch_to_next_model():
                raise
            return self.generate(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )
        content = response.choices[0].message.content
        # Some backends ignore the `stop` parameter; trim client-side.
        if stop_sequences is not None and not self.supports_stop_parameter:
            content = remove_content_after_stop_sequences(
                content, stop_sequences)
        return ChatMessage(
            role=response.choices[0].message.role,
            content=content,
            tool_calls=response.choices[0].message.tool_calls,
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )