import gradio as gr
import os
import sys
import json
import gc
import numpy as np
from vllm import LLM, SamplingParams
from jinja2 import Template
from typing import List
import types
from tooluniverse import ToolUniverse
from gradio import ChatMessage
from .toolrag import ToolRAGModel
import torch
import logging

# Configure logging with a more specific logger name
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("TxAgent")

from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format


class TxAgent:
    def __init__(self, model_name,
                 rag_model_name,
                 tool_files_dict=None,
                 enable_finish=True,
                 enable_rag=False,
                 enable_summary=False,
                 init_rag_num=0,
                 step_rag_num=0,
                 summary_mode='step',
                 summary_skip_last_k=0,
                 summary_context_length=None,
                 force_finish=True,
                 avoid_repeat=True,
                 seed=None,
                 enable_checker=False,
                 enable_chat=False,
                 additional_default_tools=None):
        self.model_name = model_name
        self.tokenizer = None
        self.terminators = None
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict
        self.model = None
        self.rag_model = ToolRAGModel(rag_model_name)
        self.tooluniverse = None
        self.prompt_multi_step = "You are a helpful assistant that solves problems through step-by-step reasoning."
        self.self_prompt = "Strictly follow the instruction."
        self.chat_prompt = "You are a helpful assistant for user chat."
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.enable_summary = enable_summary
        self.summary_mode = summary_mode
        self.summary_skip_last_k = summary_skip_last_k
        self.summary_context_length = summary_context_length
        self.init_rag_num = init_rag_num
        self.step_rag_num = step_rag_num
        self.force_finish = force_finish
        self.avoid_repeat = avoid_repeat
        self.seed = seed
        self.enable_checker = enable_checker
        self.additional_default_tools = additional_default_tools
        logger.info("TxAgent initialized with model: %s, RAG: %s", model_name, rag_model_name)

    def init_model(self):
        self.load_models()
        self.load_tooluniverse()

    def load_models(self, model_name=None):
        if model_name is not None:
            if model_name == self.model_name:
                return f"The model {model_name} is already loaded."
            self.model_name = model_name

        self.model = LLM(
            model=self.model_name,
            dtype="float16",
            max_model_len=131072,
            max_num_batched_tokens=65536,  # Increased for A100 80GB
            max_num_seqs=512,
            gpu_memory_utilization=0.95,  # Higher utilization for better performance
            trust_remote_code=True,
        )
        self.chat_template = Template(self.model.get_tokenizer().chat_template)
        self.tokenizer = self.model.get_tokenizer()
        logger.info(
            "Model %s loaded with max_model_len=%d, max_num_batched_tokens=%d, gpu_memory_utilization=%.2f",
            self.model_name, 131072, 65536, 0.95)
        return f"Model {self.model_name} loaded successfully."
    def load_tooluniverse(self):
        self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
        self.tooluniverse.load_tools()
        special_tools = self.tooluniverse.prepare_tool_prompts(
            self.tooluniverse.tool_category_dicts["special_tools"])
        self.special_tools_name = [tool['name'] for tool in special_tools]
        logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))

    def load_tool_desc_embedding(self):
        cache_path = os.path.join(
            os.path.dirname(self.tool_files_dict["new_tool"]), "tool_embeddings.pkl")
        if os.path.exists(cache_path):
            self.rag_model.load_cached_embeddings(cache_path)
        else:
            self.rag_model.load_tool_desc_embedding(self.tooluniverse)
            self.rag_model.save_embeddings(cache_path)
        logger.debug("Tool description embeddings loaded")

    def rag_infer(self, query, top_k=5):
        return self.rag_model.rag_infer(query, top_k)

    def initialize_tools_prompt(self, call_agent, call_agent_level, message):
        picked_tools_prompt = []
        picked_tools_prompt = self.add_special_tools(
            picked_tools_prompt, call_agent=call_agent)
        if call_agent:
            call_agent_level += 1
            if call_agent_level >= 2:
                call_agent = False
        return picked_tools_prompt, call_agent_level

    def initialize_conversation(self, message, conversation=None, history=None):
        if conversation is None:
            conversation = []
        conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
        if history:
            for i in range(len(history)):
                if history[i]['role'] == 'user':
                    conversation.append({"role": "user", "content": history[i]['content']})
                elif history[i]['role'] == 'assistant':
                    conversation.append({"role": "assistant", "content": history[i]['content']})
        conversation.append({"role": "user", "content": message})
        logger.debug("Conversation initialized with %d messages", len(conversation))
        return conversation

    def tool_RAG(self, message=None,
                 picked_tool_names=None,
                 existing_tools_prompt=[],
                 rag_num=0,
                 return_call_result=False):
        if not self.enable_rag:
            return []
        extra_factor = 10
        if picked_tool_names is None:
            assert picked_tool_names is not None or message is not None
            picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
        picked_tool_names_no_special = [
            tool for tool in picked_tool_names if tool not in self.special_tools_name]
        picked_tool_names = picked_tool_names_no_special[:rag_num]
        picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
        picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
        logger.debug("Retrieved %d tools via RAG", len(picked_tools_prompt))
        if return_call_result:
            return picked_tools_prompt, picked_tool_names
        return picked_tools_prompt

    def add_special_tools(self, tools, call_agent=False):
        if self.enable_finish:
            tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
            logger.debug("Finish tool added")
        if call_agent:
            tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
            logger.debug("CallAgent tool added")
        return tools

    def add_finish_tools(self, tools):
        tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
        logger.debug("Finish tool added")
        return tools

    def set_system_prompt(self, conversation, sys_prompt):
        if not conversation:
            conversation.append({"role": "system", "content": sys_prompt})
        else:
            conversation[0] = {"role": "system", "content": sys_prompt}
        return conversation

    def run_function_call(self, fcall_str,
                          return_message=False,
                          existing_tools_prompt=None,
                          message_for_call_agent=None,
                          call_agent=False,
                          call_agent_level=None,
                          temperature=None):
        try:
            function_call_json, message = self.tooluniverse.extract_function_call_json(
                fcall_str, return_message=return_message, verbose=False)
        except Exception as e:
            logger.error("Tool call parsing failed: %s", e)
            function_call_json = []
            message = fcall_str

        call_results = []
        special_tool_call = ''
        if function_call_json:
            if isinstance(function_call_json, list):
                for i in range(len(function_call_json)):
                    logger.info("Tool Call: %s", function_call_json[i])
                    if function_call_json[i]["name"] == 'Finish':
                        special_tool_call = 'Finish'
                        break
                    elif function_call_json[i]["name"] == 'CallAgent':
                        if call_agent_level < 2 and call_agent:
                            solution_plan = function_call_json[i]['arguments']['solution']
                            full_message = (
                                message_for_call_agent +
                                "\nYou must follow the following plan to answer the question: " +
                                str(solution_plan))
                            call_result = self.run_multistep_agent(
                                full_message, temperature=temperature,
                                max_new_tokens=512, max_token=131072,
                                call_agent=False, call_agent_level=call_agent_level)
                            if call_result is None:
                                call_result = "⚠️ No content returned from sub-agent."
                            else:
                                call_result = call_result.split('[FinalAnswer]')[-1].strip()
                        else:
                            call_result = "Error: CallAgent disabled."
                    else:
                        call_result = self.tooluniverse.run_one_function(function_call_json[i])
                    call_id = self.tooluniverse.call_id_gen()
                    function_call_json[i]["call_id"] = call_id
                    logger.info("Tool Call Result: %s", call_result)
                    call_results.append({
                        "role": "tool",
                        "content": json.dumps({"tool_name": function_call_json[i]["name"],
                                               "content": call_result,
                                               "call_id": call_id})
                    })
        else:
            call_results.append({
                "role": "tool",
                "content": json.dumps({"content": "Invalid or no function call detected."})
            })

        revised_messages = [{
            "role": "assistant",
            "content": message.strip(),
            "tool_calls": json.dumps(function_call_json)
        }] + call_results
        return revised_messages, existing_tools_prompt, special_tool_call

    def run_function_call_stream(self, fcall_str,
                                 return_message=False,
                                 existing_tools_prompt=None,
                                 message_for_call_agent=None,
                                 call_agent=False,
                                 call_agent_level=None,
                                 temperature=None,
                                 return_gradio_history=True):
        try:
            function_call_json, message = self.tooluniverse.extract_function_call_json(
                fcall_str, return_message=return_message, verbose=False)
        except Exception as e:
            logger.error("Tool call parsing failed: %s", e)
            function_call_json = []
            message = fcall_str

        call_results = []
        special_tool_call = ''
        if return_gradio_history:
            gradio_history = []
        if function_call_json:
            if isinstance(function_call_json, list):
                for i in range(len(function_call_json)):
                    if function_call_json[i]["name"] == 'Finish':
                        special_tool_call = 'Finish'
                        break
                    elif function_call_json[i]["name"] == 'DirectResponse':
                        call_result = function_call_json[i]['arguments']['respose']
                        special_tool_call = 'DirectResponse'
                    elif function_call_json[i]["name"] == 'RequireClarification':
                        call_result = function_call_json[i]['arguments']['unclear_question']
                        special_tool_call = 'RequireClarification'
                    elif function_call_json[i]["name"] == 'CallAgent':
                        if call_agent_level < 2 and call_agent:
                            solution_plan = function_call_json[i]['arguments']['solution']
                            full_message = (
                                message_for_call_agent +
                                "\nYou must follow the following plan to answer the question: " +
                                str(solution_plan))
                            sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                            call_result = yield from self.run_gradio_chat(
                                full_message, history=[], temperature=temperature,
                                max_new_tokens=512, max_token=131072,
                                call_agent=False, call_agent_level=call_agent_level,
                                conversation=None, sub_agent_task=sub_agent_task)
                            if call_result is not None and isinstance(call_result, str):
                                call_result = call_result.split('[FinalAnswer]')[-1]
                            else:
                                call_result = "⚠️ No content returned from sub-agent."
                        else:
                            call_result = "Error: CallAgent disabled."
                    else:
                        call_result = self.tooluniverse.run_one_function(function_call_json[i])
                    call_id = self.tooluniverse.call_id_gen()
                    function_call_json[i]["call_id"] = call_id
                    call_results.append({
                        "role": "tool",
                        "content": json.dumps({"tool_name": function_call_json[i]["name"],
                                               "content": call_result,
                                               "call_id": call_id})
                    })
                    if return_gradio_history and function_call_json[i]["name"] != 'Finish':
                        metadata = {"title": f"🧰 {function_call_json[i]['name']}",
                                    "log": str(function_call_json[i]['arguments'])}
                        gradio_history.append(ChatMessage(role="assistant",
                                                          content=str(call_result),
                                                          metadata=metadata))
        else:
            call_results.append({
                "role": "tool",
                "content": json.dumps({"content": "Invalid or no function call detected."})
            })

        revised_messages = [{
            "role": "assistant",
            "content": message.strip(),
            "tool_calls": json.dumps(function_call_json)
        }] + call_results
        if return_gradio_history:
            return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
        return revised_messages, existing_tools_prompt, special_tool_call

    def get_answer_based_on_unfinished_reasoning(self, conversation, temperature,
                                                 max_new_tokens, max_token, outputs=None):
        if conversation[-1]['role'] == 'assistant':
            conversation.append(
                {'role': 'tool',
                 'content': 'Errors occurred during function call; provide final answer with current information.'})
        finish_tools_prompt = self.add_finish_tools([])
        last_outputs_str = self.llm_infer(
            messages=conversation,
            temperature=temperature,
            tools=finish_tools_prompt,
            output_begin_string='[FinalAnswer]',
            skip_special_tokens=True,
            max_new_tokens=max_new_tokens,
            max_token=max_token)
        logger.info("Unfinished reasoning answer: %s", last_outputs_str[:100])
        return last_outputs_str

    def run_multistep_agent(self, message: str,
                            temperature: float,
                            max_new_tokens: int,
                            max_token: int,
                            max_round: int = 5,
                            call_agent=False,
                            call_agent_level=0):
        logger.info("Starting multistep agent for message: %s", message[:100])
        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
            call_agent, call_agent_level, message)
        conversation = self.initialize_conversation(message)
        outputs = []
        last_outputs = []
        next_round = True
        current_round = 0
        token_overflow = False
        enable_summary = False
        last_status = {}

        while next_round and current_round < max_round:
            current_round += 1
            if len(outputs) > 0:
                function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
                    last_outputs, return_message=True,
                    existing_tools_prompt=picked_tools_prompt,
                    message_for_call_agent=message,
                    call_agent=call_agent,
                    call_agent_level=call_agent_level,
                    temperature=temperature)
                if special_tool_call == 'Finish':
                    next_round = False
                    conversation.extend(function_call_messages)
                    content = function_call_messages[0]['content']
                    if content is None:
                        return "❌ No content returned after Finish tool call."
                    return content.split('[FinalAnswer]')[-1]

                if (self.enable_summary or token_overflow) and not call_agent:
                    enable_summary = True
                last_status = self.function_result_summary(
                    conversation, status=last_status, enable_summary=enable_summary)
                if function_call_messages:
                    conversation.extend(function_call_messages)
                    outputs.append(tool_result_format(function_call_messages))
                else:
                    next_round = False
                    conversation.extend([{"role": "assistant", "content": ''.join(last_outputs)}])
                    return ''.join(last_outputs).replace("</s>", "")

            last_outputs = []
            outputs.append("### TxAgent:\n")
            last_outputs_str, token_overflow = self.llm_infer(
                messages=conversation,
                temperature=temperature,
                tools=picked_tools_prompt,
                skip_special_tokens=False,
                max_new_tokens=2048,
                max_token=131072,
                check_token_status=True)
            if last_outputs_str is None:
                logger.warning("Token limit exceeded")
                if self.force_finish:
                    return self.get_answer_based_on_unfinished_reasoning(
                        conversation, temperature, max_new_tokens, max_token)
                return "❌ Token limit exceeded."
            last_outputs.append(last_outputs_str)

        if max_round == current_round:
            logger.warning("Max rounds exceeded")
            if self.force_finish:
                return self.get_answer_based_on_unfinished_reasoning(
                    conversation, temperature, max_new_tokens, max_token)
        return None

    def build_logits_processor(self, messages, llm):
        logger.warning("Logits processor disabled due to vLLM V1 limitation")
        return None

    def llm_infer(self, messages, temperature=0.1, tools=None,
                  output_begin_string=None, max_new_tokens=512,
                  max_token=131072, skip_special_tokens=True,
                  model=None, tokenizer=None, terminators=None,
                  seed=None, check_token_status=False):
        if model is None:
            model = self.model

        logits_processor = self.build_logits_processor(messages, model)
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            seed=seed if seed is not None else self.seed,
        )
        prompt = self.chat_template.render(
            messages=messages, tools=tools, add_generation_prompt=True)
        if output_begin_string is not None:
            prompt += output_begin_string

        if check_token_status and max_token is not None:
            token_overflow = False
            num_input_tokens = len(self.tokenizer.encode(prompt, add_special_tokens=False))
            logger.info("Input prompt tokens: %d, max_token: %d", num_input_tokens, max_token)
            if num_input_tokens > max_token:
                torch.cuda.empty_cache()
                gc.collect()
                logger.warning("Token overflow: %d > %d", num_input_tokens, max_token)
                return None, True

        output = model.generate(prompt, sampling_params=sampling_params)
        output_text = output[0].outputs[0].text
        output_tokens = len(self.tokenizer.encode(output_text, add_special_tokens=False))
        logger.debug("Inference output: %s (output tokens: %d)", output_text[:100], output_tokens)
        torch.cuda.empty_cache()
        gc.collect()
        if check_token_status and max_token is not None:
            return output_text, token_overflow
        return output_text

    def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
        logger.info("Starting self agent")
        conversation = self.set_system_prompt([], self.self_prompt)
        conversation.append({"role": "user", "content": message})
        return self.llm_infer(
            messages=conversation,
            temperature=temperature,
            tools=None,
            max_new_tokens=max_new_tokens,
            max_token=max_token)

    def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
        logger.info("Starting chat agent")
        conversation = self.set_system_prompt([], self.chat_prompt)
        conversation.append({"role": "user", "content": message})
        return self.llm_infer(
            messages=conversation,
            temperature=temperature,
            tools=None,
            max_new_tokens=max_new_tokens,
            max_token=max_token)

    def run_format_agent(self, message: str, answer: str, temperature: float,
                         max_new_tokens: int, max_token: int):
        logger.info("Starting format agent")
        if '[FinalAnswer]' in answer:
            possible_final_answer = answer.split("[FinalAnswer]")[-1]
        elif "\n\n" in answer:
            possible_final_answer = answer.split("\n\n")[-1]
        else:
            possible_final_answer = answer.strip()
        if len(possible_final_answer) == 1 and possible_final_answer in ['A', 'B', 'C', 'D', 'E']:
            return possible_final_answer
        elif len(possible_final_answer) > 1 and possible_final_answer[1] == ':' \
                and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
            return possible_final_answer[0]

        conversation = self.set_system_prompt(
            [], "Transform the agent's answer to a single letter: 'A', 'B', 'C', 'D'.")
        conversation.append({"role": "user", "content": message +
                             "\nAgent's answer: " + answer + "\nAnswer (must be a letter):"})
        return self.llm_infer(
            messages=conversation,
            temperature=temperature,
            tools=None,
            max_new_tokens=max_new_tokens,
            max_token=max_token)

    def run_summary_agent(self, thought_calls: str, function_response: str,
                          temperature: float, max_new_tokens: int, max_token: int):
        logger.info("Summarizing tool result")
        prompt = f"""Thought and function calls:
{thought_calls}

Function calls' responses:
\"\"\"
{function_response}
\"\"\"

Summarize the function calls' responses in one sentence with all necessary information.
"""
        conversation = [{"role": "user", "content": prompt}]
        output = self.llm_infer(
            messages=conversation,
            temperature=temperature,
            tools=None,
            max_new_tokens=max_new_tokens,
            max_token=max_token)
        if '[' in output:
            output = output.split('[')[0]
        return output

    def function_result_summary(self, input_list, status, enable_summary):
        if 'tool_call_step' not in status:
            status['tool_call_step'] = 0

        for idx in range(len(input_list)):
            pos_id = len(input_list) - idx - 1
            if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
                break

        status['step'] = status.get('step', 0) + 1
        if not enable_summary:
            return status

        status['summarized_index'] = status.get('summarized_index', 0)
        status['summarized_step'] = status.get('summarized_step', 0)
        status['previous_length'] = status.get('previous_length', 0)
        status['history'] = status.get('history', [])

        function_response = ''
        idx = status['summarized_index']
        this_thought_calls = None

        while idx < len(input_list):
            if ((self.summary_mode == 'step'
                 and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
                    or (self.summary_mode == 'length'
                        and status['previous_length'] > self.summary_context_length)):
                if input_list[idx]['role'] == 'assistant':
                    if function_response:
                        status['summarized_step'] += 1
                        result_summary = self.run_summary_agent(
                            thought_calls=this_thought_calls,
                            function_response=function_response,
                            temperature=0.1,
                            max_new_tokens=512,
                            max_token=131072)
                        input_list.insert(
                            last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                        status['summarized_index'] = last_call_idx + 2
                        idx += 1
                    last_call_idx = idx
                    this_thought_calls = input_list[idx]['content'] + input_list[idx]['tool_calls']
                    function_response = ''
                elif input_list[idx]['role'] == 'tool' and this_thought_calls is not None:
                    function_response += input_list[idx]['content']
                    del input_list[idx]
                    idx -= 1
            else:
                break
            idx += 1

        if function_response:
            status['summarized_step'] += 1
            result_summary = self.run_summary_agent(
                thought_calls=this_thought_calls,
                function_response=function_response,
                temperature=0.1,
                max_new_tokens=512,
                max_token=131072)
            tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
            for tool_call in tool_calls:
                del tool_call['call_id']
            input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
            input_list.insert(
                last_call_idx + 1, {'role': 'tool', 'content': result_summary})
            status['summarized_index'] = last_call_idx + 2

        return status

    def update_parameters(self, **kwargs):
        updated_attributes = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                updated_attributes[key] = value
        logger.info("Updated parameters: %s", updated_attributes)
        return updated_attributes

    def run_gradio_chat(self, message: str,
                        history: list,
                        temperature: float,
                        max_new_tokens: int = 2048,
                        max_token: int = 131072,
                        call_agent: bool = False,
                        conversation: gr.State = None,
                        max_round: int = 5,
                        seed: int = None,
                        call_agent_level: int = 0,
                        sub_agent_task: str = None,
                        uploaded_files: list = None):
        logger.info("Chat started, message: %s", message[:100])
        if not message or len(message.strip()) < 5:
            yield "Please provide a valid message or upload files to analyze."
            return

        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
            call_agent, call_agent_level, message)
        conversation = self.initialize_conversation(message, conversation, history)
        history = []
        last_outputs = []
        next_round = True
        current_round = 0
        enable_summary = False
        last_status = {}
        token_overflow = False

        try:
            while next_round and current_round < max_round:
                current_round += 1
                logger.debug("Starting round %d/%d", current_round, max_round)
                if last_outputs:
                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
                        last_outputs, return_message=True,
                        existing_tools_prompt=picked_tools_prompt,
                        message_for_call_agent=message,
                        call_agent=call_agent,
                        call_agent_level=call_agent_level,
                        temperature=temperature)
                    history.extend(current_gradio_history)
                    if special_tool_call == 'Finish':
                        logger.info("Finish tool called, ending chat")
                        yield history
                        next_round = False
                        conversation.extend(function_call_messages)
                        content = function_call_messages[0]['content']
                        if content:
                            return content
                        return "No content returned after Finish tool call."
                    elif special_tool_call in ['RequireClarification', 'DirectResponse']:
                        last_msg = history[-1] if history else ChatMessage(
                            role="assistant", content="Response needed.")
                        history.append(ChatMessage(role="assistant", content=last_msg.content))
                        logger.info("Special tool %s called, ending chat", special_tool_call)
                        yield history
                        next_round = False
                        return last_msg.content

                    if (self.enable_summary or token_overflow) and not call_agent:
                        enable_summary = True
                    last_status = self.function_result_summary(
                        conversation, status=last_status, enable_summary=enable_summary)
                    if function_call_messages:
                        conversation.extend(function_call_messages)
                        yield history
                    else:
                        next_round = False
                        conversation.append({"role": "assistant", "content": ''.join(last_outputs)})
                        logger.info("No function call messages, ending chat")
                        return ''.join(last_outputs).replace("</s>", "")

                last_outputs = []
                last_outputs_str, token_overflow = self.llm_infer(
                    messages=conversation,
                    temperature=temperature,
                    tools=picked_tools_prompt,
                    skip_special_tokens=False,
                    max_new_tokens=max_new_tokens,
                    max_token=max_token,
                    seed=seed,
                    check_token_status=True)
                if last_outputs_str is None:
                    logger.warning("Token limit exceeded")
                    if self.force_finish:
                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                            conversation, temperature, max_new_tokens, max_token)
                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
                        yield history
                        return last_outputs_str
                    error_msg = "Token limit exceeded."
                    history.append(ChatMessage(role="assistant", content=error_msg))
                    yield history
                    return error_msg

                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
                for msg in history:
                    if msg.metadata is not None:
                        msg.metadata['status'] = 'done'

                if '[FinalAnswer]' in last_thought:
                    parts = last_thought.split('[FinalAnswer]', 1)
                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                    yield history
                    history.append(ChatMessage(
                        role="assistant",
                        content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                    logger.info("Final answer provided: %s", final_answer[:100])
                    yield history
                    next_round = False  # Ensure we exit after final answer
                    return final_answer
                else:
                    history.append(ChatMessage(role="assistant", content=last_thought))
                    yield history

                last_outputs.append(last_outputs_str)

            if next_round:
                if self.force_finish:
                    last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                        conversation, temperature, max_new_tokens, max_token)
                    parts = last_outputs_str.split('[FinalAnswer]', 1)
                    final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                    yield history
                    history.append(ChatMessage(
                        role="assistant",
                        content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                    logger.info("Forced final answer: %s", final_answer[:100])
                    yield history
                    return final_answer
                else:
                    error_msg = "Reasoning rounds exceeded limit."
                    history.append(ChatMessage(role="assistant", content=error_msg))
                    yield history
                    return error_msg

        except Exception as e:
            logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
            error_msg = f"Error: {e}"
            history.append(ChatMessage(role="assistant", content=error_msg))
            yield history
            if self.force_finish:
                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                    conversation, temperature, max_new_tokens, max_token)
                parts = last_outputs_str.split('[FinalAnswer]', 1)
                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                yield history
                history.append(ChatMessage(
                    role="assistant",
                    content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                logger.info("Forced final answer after error: %s", final_answer[:100])
                yield history
                return final_answer
            return error_msg

    def run_gradio_chat_batch(self, messages: List[str],
                              temperature: float,
                              max_new_tokens: int = 2048,
                              max_token: int = 131072,
                              call_agent: bool = False,
                              conversation: List = None,
                              max_round: int = 5,
                              seed: int = None,
                              call_agent_level: int = 0):
        """Run batch inference for multiple messages."""
        logger.info("Starting batch chat for %d messages", len(messages))
        batch_results = []
        for message in messages:
            # Initialize conversation for each message
            conv = self.initialize_conversation(message, conversation, history=None)
            picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
                call_agent, call_agent_level, message)
            # Run single inference for simplicity (extend for multi-round if needed)
            output, token_overflow = self.llm_infer(
                messages=conv,
                temperature=temperature,
                tools=picked_tools_prompt,
                max_new_tokens=max_new_tokens,
                max_token=max_token,
                skip_special_tokens=False,
                seed=seed,
                check_token_status=True)
            if output is None:
                logger.warning("Token limit exceeded for message: %s", message[:100])
                batch_results.append("Token limit exceeded.")
            else:
                batch_results.append(output)
        logger.info("Batch chat completed for %d messages", len(messages))
        return batch_results
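

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the class). Assumptions: the model names,
# tool-file path, and example question below are illustrative placeholders,
# not values mandated by this module; adjust them to your deployment. The
# "new_tool" key is the one load_tool_desc_embedding expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = TxAgent(
        model_name="path/or/hub-id/of/txagent-model",        # placeholder
        rag_model_name="path/or/hub-id/of/toolrag-model",    # placeholder
        tool_files_dict={"new_tool": "./data/new_tool.json"},  # placeholder path
        enable_rag=False,
    )
    agent.init_model()

    # run_gradio_chat is a generator: it yields Gradio ChatMessage history
    # updates as reasoning rounds and tool calls complete.
    for update in agent.run_gradio_chat(
            message="Example question about a drug interaction.",
            history=[],
            temperature=0.3):
        print(update)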