# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
RLVE-Gym Environment Implementation.
"""

from typing import Optional, Tuple
import random

from openenv_core.env_server.interfaces import Environment

from models import RlveGymState, RlveGymAction, RlveGymObservation
from server.Gym.environment import VerifiableEnvironment
from server.Gym.parameter_controller import ParameterController
from server.Gym.environments import identifier2environment
from server.Gym.parameter_controllers import identifier2controller


class RlveGymEnvironment(Environment):
    """
    Wrap any verifiable environment from RLVE-Gym behind the OpenEnv ``Environment`` API.

    ``reset`` generates a problem from the configured environment identifier,
    difficulty and current seed; ``step`` passes a model output to that
    problem's verifier and accumulates the reported accuracy in the state.
    """

    def __init__(
        self,
        environment_identifier: str = "Multiplication",
        difficulty: int = 0,
        answer_markers: Optional[Tuple[str, str]] = None,
        initial_seed: int = 0,
    ):
        """
        Initialize the RLVE_Gym environment.

        Args:
            environment_identifier: Key looked up in both ``identifier2environment``
                and ``identifier2controller`` when ``reset`` is called.
            difficulty: Number of times the parameter controller's ``update()``
                is called before a parameter is chosen.
            answer_markers: Forwarded unchanged to the environment constructor.
            initial_seed: Seed used by the first ``reset``; the stored seed is
                incremented after every generation attempt.
        """
        self._state = RlveGymState(
            seed=initial_seed,
            problem_input=None,
            num_samples=0,
            sum_accuracy=0,
        )

        self.environment_identifier = environment_identifier
        self.difficulty = difficulty
        self.answer_markers = answer_markers

        # The currently active problem; None until a reset succeeds.
        self.problem: Optional[VerifiableEnvironment] = None

    @staticmethod
    def _failure_observation(
        message: str, problem_input: Optional[str] = None
    ) -> RlveGymObservation:
        """Build an unsuccessful observation carrying ``message``."""
        return RlveGymObservation(
            problem_input=problem_input,
            verifier_result=None,
            success=False,
            message=message,
            reward=None,
        )

    def _reset_failure(self, message: str) -> RlveGymObservation:
        """Drop any active problem and build the observation for a failed reset."""
        # A failed reset must not leave an earlier problem in place, otherwise
        # step() would keep verifying outputs against it.
        self.problem = None
        self._state.problem_input = None
        return self._failure_observation(message)

    def reset(self) -> RlveGymObservation:
        """
        Reset the environment.

        Validates the configuration, builds the environment and its parameter
        controller, and tries to generate a problem for the current seed.
        Once generation has been attempted (whether or not it succeeds) the
        seed is incremented and the sample/accuracy counters are zeroed.
        Any failure clears the active problem.

        Returns:
            problem_input: The generated problem input string (or None if generation failed)
            verifier_result: None
            success: Boolean indicating if the reset was successful
            message: Message indicating the result of the reset
        """
        if (self.environment_identifier not in identifier2environment) or (
            self.environment_identifier not in identifier2controller
        ):
            return self._reset_failure("Invalid environment identifier.")
        if not (isinstance(self.difficulty, int) and self.difficulty >= 0):
            return self._reset_failure("Difficulty should be a non-negative integer.")
        if not (isinstance(self._state.seed, int) and self._state.seed >= 0):
            return self._reset_failure("Seed should be a non-negative integer.")

        try:
            problem: VerifiableEnvironment = identifier2environment[self.environment_identifier](
                answer_markers=self.answer_markers
            )
        except Exception as e:
            return self._reset_failure(f"Failed to initialize environment: {e}")

        seed = self._state.seed
        generated = False
        prompt = None
        generation_error: Optional[Exception] = None
        try:
            controller: ParameterController = identifier2controller[self.environment_identifier]()
            for _ in range(self.difficulty):
                controller.update()
            # Reseeding the module-level RNG makes the parameter choice
            # depend only on the seed and the controller's parameter list.
            random.seed(seed)
            parameter = random.choice(controller.get_parameter_list())
            if problem.generator(seed=seed, parameter=parameter):
                prompt = problem.prompt_generator()
                generated = True
        except Exception as e:
            # Report the error through the observation instead of letting it
            # escape, mirroring how environment construction errors are handled.
            generation_error = e

        self.problem = problem if generated else None
        self._state.problem_input = prompt if generated else None

        # Advance the seed after every attempt so a retry uses a new seed.
        self._state.seed += 1
        self._state.num_samples = self._state.sum_accuracy = 0

        if generated:
            return RlveGymObservation(
                problem_input=self._state.problem_input,
                verifier_result=None,
                success=True,
                message="Problem generated successfully.",
                reward=None,
            )
        if generation_error is not None:
            return self._reset_failure(
                f"Problem generation failed with error: {generation_error}"
            )
        return self._reset_failure(
            "Problem generation failed. Please try decreasing difficulty or changing seed."
        )

    def step(self, action: RlveGymAction) -> RlveGymObservation:  # type: ignore[override]
        """
        Execute a step in the environment by verifying the model output.

        The state counters are only updated when the verifier returns a
        result that contains both ``"accuracy"`` and ``"reward"``.

        Args:
            action: RlveGymAction containing the output to verify

        Returns:
            problem_input: The problem input string from the current state
            verifier_result: Result of the verification containing accuracy and other metrics
            success: Boolean indicating if the step was successful
            message: Message indicating the result of the step
        """
        if self.problem is None:
            return self._failure_observation(
                "Problem not ready. Please reset the environment."
            )

        try:
            verifier_result = self.problem.verifier(action.output)
            # Read both fields before touching the state so a malformed result
            # cannot leave num_samples incremented without a matching accuracy.
            accuracy = verifier_result["accuracy"]
            reward = verifier_result["reward"]
        except Exception as e:
            return self._failure_observation(
                f"Verification failed with error: {e}",
                problem_input=self._state.problem_input,
            )

        self._state.num_samples += 1
        self._state.sum_accuracy += accuracy

        return RlveGymObservation(
            problem_input=self._state.problem_input,
            verifier_result=verifier_result,
            success=True,
            message="Verification completed.",
            reward=reward,
        )

    @property
    def state(self) -> RlveGymState:
        """
        Get the current environment state.

        Returns:
            seed: The current random seed value for problem generation
            problem_input: The generated problem input string (or None if generation failed)
            num_samples: Number of samples taken so far
            sum_accuracy: Sum of accuracies from verifications so far
        """
        return self._state