from collections.abc import Generator
from typing import Any

import openai
from smolagents import ChatMessage, OpenAIModel, Tool, TokenUsage
from smolagents.models import (
    ChatMessageStreamDelta,
    ChatMessageToolCallStreamDelta,
    remove_content_after_stop_sequences,
)

class SwitchableOpenAIModel(OpenAIModel):
    """This model connects to an OpenAI-compatible API server and falls back
    to the next model in a list whenever the current one is rate-limited.

    Parameters:
        model_list (`list[str]`):
            The model identifiers to try on the server, in order (e.g. `["gpt-5", "gpt-4o"]`).
        api_base (`str`, *optional*):
            The base URL of the OpenAI-compatible API server.
        api_key (`str`, *optional*):
            The API key to use for authentication.
        organization (`str`, *optional*):
            The organization to use for the API request.
        project (`str`, *optional*):
            The project to use for the API request.
        client_kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments to pass to the OpenAI client (like organization, project, max_retries, etc.).
        custom_role_conversions (`dict[str, str]`, *optional*):
            Custom role conversion mapping to convert message roles into others.
            Useful for specific models that do not support specific message roles like "system".
        flatten_messages_as_text (`bool`, default `False`):
            Whether to flatten messages as text.
        **kwargs:
            Additional keyword arguments to forward to the underlying OpenAI API completion call, for instance `temperature`.
    """
    def __init__(
        self,
        model_list: list[str],
        api_base: str | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        client_kwargs: dict[str, Any] | None = None,
        custom_role_conversions: dict[str, str] | None = None,
        flatten_messages_as_text: bool = False,
        **kwargs,
    ):
        self.model_list = model_list
        self.model_index = 0  # index of the model currently in use
        super().__init__(
            model_id=self.model_list[self.model_index],
            api_base=api_base,
            api_key=api_key,
            organization=organization,
            project=project,
            client_kwargs=client_kwargs,
            custom_role_conversions=custom_role_conversions,
            flatten_messages_as_text=flatten_messages_as_text,
            **kwargs,
        )
    def generate_stream(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> Generator[ChatMessageStreamDelta]:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            for event in self.client.chat.completions.create(
                **completion_kwargs,
                stream=True,
                stream_options={"include_usage": True},
            ):
                if event.usage:
                    # Usage arrives as a separate final chunk when
                    # `include_usage` is set.
                    yield ChatMessageStreamDelta(
                        content="",
                        token_usage=TokenUsage(
                            input_tokens=event.usage.prompt_tokens,
                            output_tokens=event.usage.completion_tokens,
                        ),
                    )
                if event.choices:
                    choice = event.choices[0]
                    if choice.delta:
                        yield ChatMessageStreamDelta(
                            content=choice.delta.content,
                            tool_calls=[
                                ChatMessageToolCallStreamDelta(
                                    index=delta.index,
                                    id=delta.id,
                                    type=delta.type,
                                    function=delta.function,
                                )
                                for delta in choice.delta.tool_calls
                            ]
                            if choice.delta.tool_calls
                            else None,
                        )
                    elif not getattr(choice, "finish_reason", None):
                        raise ValueError(f"No content or tool calls in event: {event}")
        except openai.RateLimitError:
            if self.model_index >= len(self.model_list) - 1:
                raise  # no fallback model left
            self.model_index += 1
            print(f"Switching to model {self.model_list[self.model_index]}")
            # Inside a generator, `return value` only sets StopIteration.value
            # and the recursive call's output would be silently dropped;
            # delegate with `yield from` so the fallback's deltas reach the caller.
            yield from self.generate_stream(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )
    def generate(
        self,
        messages: list[ChatMessage | dict],
        stop_sequences: list[str] | None = None,
        response_format: dict[str, str] | None = None,
        tools_to_call_from: list[Tool] | None = None,
        **kwargs,
    ) -> ChatMessage:
        completion_kwargs = self._prepare_completion_kwargs(
            messages=messages,
            stop_sequences=stop_sequences,
            response_format=response_format,
            tools_to_call_from=tools_to_call_from,
            model=self.model_list[self.model_index],
            custom_role_conversions=self.custom_role_conversions,
            convert_images_to_image_urls=True,
            **kwargs,
        )
        self._apply_rate_limit()
        try:
            response = self.client.chat.completions.create(**completion_kwargs)
        except openai.RateLimitError:
            if self.model_index >= len(self.model_list) - 1:
                raise  # no fallback model left
            self.model_index += 1
            print(f"Switching to model {self.model_list[self.model_index]}")
            return self.generate(
                messages=messages,
                stop_sequences=stop_sequences,
                response_format=response_format,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )
        content = response.choices[0].message.content
        if stop_sequences is not None and not self.supports_stop_parameter:
            # The server ignored the `stop` parameter, so truncate client-side.
            content = remove_content_after_stop_sequences(content, stop_sequences)
        return ChatMessage(
            role=response.choices[0].message.role,
            content=content,
            tool_calls=response.choices[0].message.tool_calls,
            raw=response,
            token_usage=TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            ),
        )
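
# Minimal usage sketch. The model identifiers below are placeholder
# assumptions, not values taken from this Space; substitute models your
# endpoint actually serves. `CodeAgent` is the standard smolagents agent class.
if __name__ == "__main__":
    import os

    from smolagents import CodeAgent

    model = SwitchableOpenAIModel(
        model_list=["gpt-4o", "gpt-4o-mini"],  # tried in order on rate limits
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    agent = CodeAgent(tools=[], model=model)
    # Note: after an openai.RateLimitError, model_index advances permanently,
    # so all later calls keep using the fallback model.
    print(agent.run("What is the 10th Fibonacci number?"))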