NexRL is a large-scale distributed reinforcement learning training framework designed for modern RL applications. It provides a scalable, modular architecture that supports a variety of training and inference backends.
- Multiple Launch Mode Support: Runs in both local and Ray distributed modes
- Modular Design: Clean separation of concerns with well-defined interfaces and extensible components
- Training-as-a-Service & Rollout-as-a-Service: Unified API architecture that supports different training and inference frameworks through service abstraction
- Resource Management: Intelligent placement and co-location of services for optimal performance
- Activity Tracking: Comprehensive monitoring and health checking system for production deployments
- Error Handling: Centralized error reporting and recovery mechanisms
NexRL follows a modular architecture where components communicate through explicit interfaces and APIs.
- NexRLController: Main orchestrator that initializes and coordinates all components
- DataLoader: Provides input data for rollout workers (training and validation)
- RolloutWorkers: Execute environment interactions and generate trajectories
- TrajectoryPool: Collects and batches trajectories from rollout workers
- AlgorithmProcessor: Processes trajectory batches for training
- TrainBatchPool: Manages training batches for models
- TrainWorker: Performs actual model training via training service clients
- WeightSyncController: Manages model weights and synchronization
- Validator: Collects validation trajectories and computes metrics
- ActivityTracker: Monitors system health and activity, coordinates experiment logging
- RayResourceManager: Handles distributed resource allocation and actor co-location
Core type definitions used throughout the framework.
```python
ModelTag = str  # Type alias for model identification
```
Used to identify different models within the system.
```python
Trajectory = dict[str, Any]
```
Represents a single trajectory containing environment interaction data. Common keys include:
- `prompt`: Input prompt for LLM
- `response`: LLM response
- `finish_reason`: Completion status
- `model_tag`: Associated model identifier
```python
@dataclass
class Batch:
    values: dict[str, Any]    # Tensor or data arrays, length = metadata['batch_size']
    metadata: dict[str, Any]  # Batch metadata including 'batch_size'
```
Methods:
- `__len__() -> int`: Returns batch size from metadata
- `copy() -> Batch`: Creates a deep copy of the batch
- `to_dict() -> dict[str, Any]`: Converts batch to a single dictionary (metadata keys overwrite values keys)
- `remove_redundant_left_padding(data, pad_token_id, fields, anchor_field, max_strip_threshold) -> Batch`: Static method that removes redundant left padding tokens common across all sequences
- `remove_redundant_right_padding(data, pad_token_id, fields, anchor_field, max_strip_threshold) -> Batch`: Static method that removes redundant right padding tokens common across all sequences
- `to_nextrainer_batch() -> dict[str, Any]`: Converts batch to NexTrainer format with separated tensor/non-tensor values and metadata
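The core behavior above can be sketched with a minimal stand-in (a hypothetical simplification; the real class also handles tensors and padding removal):

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any

@dataclass
class Batch:
    values: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def __len__(self) -> int:
        # Batch size is carried in metadata, not inferred from values
        return self.metadata["batch_size"]

    def copy(self) -> "Batch":
        return Batch(deepcopy(self.values), deepcopy(self.metadata))

    def to_dict(self) -> dict[str, Any]:
        # On key collision, metadata keys overwrite values keys
        return {**self.values, **self.metadata}

b = Batch({"reward": [1.0, 0.5]}, {"batch_size": 2, "model_tag": "default"})
```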
Defines the different component roles used for resource pool mapping.
```python
class NexRLRole(Enum):
    ROLLOUT_WORKER = "rollout_worker"
    TRAIN_WORKER = "train_worker"
    ALGORITHM_PROCESSOR = "algorithm_processor"
    TRAJECTORY_POOL = "trajectory_pool"
    TRAIN_BATCH_POOL = "train_batch_pool"
    WEIGHT_SYNC_CONTROLLER = "weight_sync_controller"
    DATA_LOADER = "data_loader"
    VALIDATE_DATALOADER = "validate_dataloader"
    VALIDATOR = "validator"
```
Base class for all NexRL components, enabling Ray colocation compatibility.
```python
class NexRLModule(ABC):
    def __init__(self):
        self._module_name: str = "invalid"
        self._activity_tracker: ActivityTrackerProxy = None
```
Purpose: Provides a common interface for all NexRL modules to work with Ray resource management and activity tracking.
Methods:
- `set_activity_tracker(tracker: ActivityTrackerProxy)`: Sets the activity tracker for this module
- `set_module_name(module_name: str)`: Sets the name of this module
- `get_module_name() -> str`: Gets the name of this module
- `health_check() -> bool`: Verifies the module is alive and responsive; used during initialization and monitoring
- `easy_dump(value, keys, value_formatter)`: Convenience method to dump values with automatic module context for debugging
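A minimal sketch of the base-class contract, with a hypothetical subclass (illustrative only; the real class also wires in the activity tracker and easy_dump):

```python
from abc import ABC

class NexRLModule(ABC):
    """Simplified stand-in for the base class described above."""
    def __init__(self):
        self._module_name: str = "invalid"
        self._activity_tracker = None

    def set_module_name(self, module_name: str) -> None:
        self._module_name = module_name

    def get_module_name(self) -> str:
        return self._module_name

    def health_check(self) -> bool:
        # A responsive module simply returns True
        return True

class MyProcessor(NexRLModule):
    """Hypothetical user-defined module."""
    pass

m = MyProcessor()
m.set_module_name("my_processor")
```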
The main orchestrator responsible for initializing, coordinating, and monitoring all framework components.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Hydra configuration containing all module settings
Functionality:
- Initializes all framework modules based on launch mode
- Sets up Ray resources in distributed mode
- Establishes inter-module references
`def run() -> None`
Starts the training process by launching all components and entering the monitoring loop.
Process:
- Initializes train workers with final configuration
- Loads initial checkpoint (or resumes from existing)
- Optionally runs validation before training
- Starts all worker components asynchronously
- Monitors system health and activity
- Checks for weight sync validation triggers
- Checks for completion conditions
- Handles graceful shutdown
`def _stop()`
Gracefully stops all components and waits for activity completion.
Features:
- Signals all workers to stop
- Waits for quiescence with timeout
- Logs remaining activities on timeout
`def _check_finish() -> bool`
Determines if training should stop based on:
- Maximum training steps reached
- System quiescence (all pools empty, no active work)
`def _check_module_liveness(self, timeout: float = 5.0) -> bool`
Parameters:
timeout: Ray operation timeout in seconds
Returns: True if all modules are alive, False if any are dead
`def _check_module_exceptions(self) -> bool`
Returns: True if the system is healthy, False if critical errors are detected
`def _load_initial_checkpoint()`
Loads the initial checkpoint or prepares for training from scratch. Creates the sync weight buffer and performs an initial weight sync to the inference service.
`def _load_resume_checkpoint()`
Loads a checkpoint based on the resume configuration (auto or from_path). Supports automatic detection of the latest checkpoint or an explicit path specification.
`def _find_latest_checkpoint(self, checkpoint_folder: str) -> str | None`
Finds the latest checkpoint in the given folder by parsing global_step_* directories.
`def _run_validate(self, model_tag: ModelTag)`
Runs a validation cycle after a weight sync event. Switches workers to validation mode, waits for completion, computes metrics, and switches back to training mode.
`def _start_validate(self, model_tag: ModelTag)`
Starts validation by switching rollout workers to validation mode.
`def _end_validate(self, model_tag: ModelTag)`
Ends validation by computing metrics, logging results, switching workers back to training mode, and notifying the weight sync controller.
Abstract base class for data input components.
`def __init__(self, config: DictConfig, is_validate: bool = False)`
Parameters:
- `config`: Configuration for the data loader
- `is_validate`: Whether this dataloader is for validation (affects behavior and tracking)
`def __len__(self) -> int`
Returns: Number of remaining data items
`def __getitem__(self, index: int) -> dict[str, Any]`
Parameters:
- `index`: Index of the data item to retrieve
Returns: Single data item as a dictionary
`def get_next_item(self) -> dict[str, Any] | None`
Returns: Next data item in sequence, or None if exhausted
`def is_finished(self) -> bool`
Returns: True if no more data is available
`def can_return_item(self) -> bool`
Returns: True if the data loader can currently return an item
`def reset() -> None`
Resets the data loader to its initial state (used for validation cycles)
`def add_item(self, item: dict[str, Any]) -> None`
Parameters:
- `item`: Data item to add (appended to the end by default)
`def add_item_front(self, item: dict[str, Any]) -> None`
Parameters:
- `item`: Data item to add to the beginning of the queue
`def add_item_back(self, item: dict[str, Any]) -> None`
Parameters:
- `item`: Data item to add to the end of the queue
Service client for interacting with LLM APIs, encapsulating OpenAI client functionality.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Configuration containing LLM settings
Initializes:
- OpenAI client with API key and base URL
- Model tag and weight sync coordination settings
`def completion(self, prompt: str, **kwargs) -> dict[str, Any]`
Parameters:
- `prompt`: Input text prompt
- `**kwargs`: Additional completion parameters
Returns: Dictionary containing:
- `prompt`: Original input prompt
- `response`: LLM-generated text
- `finish_reason`: Completion status
- Additional passed kwargs
Features:
- Automatic retry logic with configurable max_retries
- Weight sync coordination (blocks if weight sync is in progress)
- Error handling and logging
`def generate(self, messages: list[dict[str, Any]], **kwargs) -> dict[str, Any]`
Parameters:
- `messages`: List of message dictionaries for chat completion
- `**kwargs`: Additional generation parameters
Returns: Dictionary containing:
- `messages`: Original input messages
- `response`: Generated response text
- `tool_calls`: Any tool calls made
- `finish_reason`: Completion status
- Additional passed kwargs
`def set_weight_sync_controller(self, controller: WeightSyncController)`
Parameters:
- `controller`: Weight synchronization controller reference
Abstract base class for rollout execution workers that interact with LLM services.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Worker configuration including LLM settings
Initializes:
- LLMServiceClient for LLM interactions
- Threading components for async execution
- Module references (set via `set_module_references`)
`def set_module_references(self, trajectory_pool: TrajectoryPool, dataloader: BaseDataLoader, weight_sync_controller: WeightSyncController, validate_dataloader: BaseDataLoader, validator: Validator)`
Parameters:
- `trajectory_pool`: Reference to trajectory collection pool
- `dataloader`: Reference to data source
- `weight_sync_controller`: Reference to weight synchronization controller
- `validate_dataloader`: Reference to validation data source
- `validator`: Reference to validation trajectory collector
`def set_activity_tracker(self, tracker: ActivityTrackerProxy)`
Parameters:
- `tracker`: Activity monitoring proxy
`def run()`
Starts the worker thread and begins the main processing loop.
Preconditions:
- Module references must be set
- Activity tracker must be set
`def stop()`
Gracefully stops the worker and waits for thread completion.
`def begin_validate()`
Switches the worker to validation mode. The worker will use the validation dataloader and send trajectories to the validator.
`def end_validate()`
Switches the worker back to training mode. The worker will use the training dataloader and send trajectories to the trajectory pool.
`def step(self, task: dict[str, Any]) -> str | None`
Parameters:
- `task`: Single task to process
Returns:
- `"success"`: Trajectory processed and added successfully
- `"fail"`: Failed to process trajectory
- `"re-rollout"`: Should retry processing (weight sync in progress)
- `None`: Processing failed before trajectory creation
`step` is abstract and must be implemented by derived classes to define specific worker behavior.
Workers access LLM functionality through the _llm_client (LLMServiceClient instance):
`def _llm_client.completion(self, prompt: str, **kwargs) -> dict[str, Any]`
Parameters:
- `prompt`: Input text prompt
- `**kwargs`: Additional completion parameters (model, max_tokens, temperature, etc.)
Returns: Dictionary containing:
- `prompt`: Original input prompt
- `response`: LLM-generated text
- `finish_reason`: Completion status
- Additional passed kwargs
Features:
- Automatic retry logic with configurable max_retries
- Weight sync coordination (blocks during sync)
- Error handling and logging
`def _llm_client.generate(self, messages: list[dict[str, Any]], **kwargs) -> dict[str, Any]`
Parameters:
- `messages`: List of message dictionaries for chat completion
- `**kwargs`: Additional generation parameters
Returns: Dictionary containing:
- `messages`: Original input messages
- `response`: Generated response text
- `tool_calls`: Any tool calls made
- `finish_reason`: Completion status
- Additional passed kwargs
`def _put_trajectory(self, trajectory: Trajectory) -> str`
Parameters:
- `trajectory`: Completed trajectory to submit
Returns:
- `"success"`: Trajectory submitted successfully
- `"fail"`: Failed to submit trajectory
- `"re-rollout"`: Should retry (weight sync in progress)
`def _get_rollout_task(self) -> dict[str, Any] | None`
Returns: Next task from the dataloader, or None if none available
Features:
- Automatic sleep to prevent busy waiting
- Non-blocking operation
`def _put_rollout_task(self, task: dict[str, Any]) -> bool`
Parameters:
- `task`: Task to return to the dataloader for reprocessing
Returns: True if successfully returned, False otherwise
Concrete implementation of BaseRolloutWorker with basic LLM completion functionality.
`def __init__(self, config: DictConfig)`
Inherits from BaseRolloutWorker.
`def step(self, task: dict[str, Any]) -> str | None`
Parameters:
- `task`: Task dictionary containing a `prompt` field
Returns:
- `"success"`: Trajectory processed and submitted successfully
- `"fail"`: Failed to submit trajectory
- `"re-rollout"`: Should retry processing
- `None`: Processing failed (missing prompt)
Process:
- Extracts the prompt from the task
- Calls LLMServiceClient completion
- Creates a trajectory with the prompt, response, and task metadata
- Submits the trajectory to the trajectory pool
- Returns the submission result
Error Handling:
- Returns None if the prompt is missing
- Propagates the result from trajectory submission
Advanced rollout worker implementation for agent-based tasks with tool calling and multi-turn interaction support.
`def __init__(self, config: DictConfig)`
Inherits from BaseRolloutWorker and adds agent-specific functionality.
- Supports chat-based interactions with message history
- Tool calling capabilities through LLM generate method
- Multi-turn conversation management
- Agent-specific trajectory formatting
This worker type is designed for more complex agent tasks that require stateful interactions and tool usage.
Multi-store trajectory pool that manages separate TrajectoryPoolInstance objects for different models, providing flexible batching strategies and weight synchronization coordination.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Pool configuration including grouping and batching settings
Configuration Options:
- `key_list`: List of keys for grouping trajectories
- `group_size`: Number of trajectories per group
- `batch_size`: Default batch size for retrieval
- `check_batch_ready_function`: Batch readiness criteria ("batch_size", "loaded_batch_finished")
`def put_trajectory(self, trajectory: Trajectory) -> str`
Parameters:
- `trajectory`: Trajectory data to store
Returns:
- `"success"`: Trajectory stored successfully
- `"fail"`: Failed to store trajectory
- `"re-rollout"`: Should retry (weight sync in progress)
Process:
- Extracts the ModelTag from the trajectory (defaults to "default")
- Creates or retrieves the appropriate TrajectoryPoolInstance
- Adds the trajectory to the instance (may block during weight sync)
`def get_batch(self, batch_size: int | None = None, model_tag: ModelTag | None = None) -> Batch | None`
Parameters:
- `batch_size`: Number of trajectories to include
- `model_tag`: Specific model to get a batch from
Returns: Batch of trajectories, or None if there are insufficient samples
Behavior:
- If `model_tag` is None, tries any available store
- If the specified model_tag has no store, returns None
`def get_batch_any(self, batch_size: int | None = None) -> Batch | None`
Parameters:
- `batch_size`: Number of trajectories to retrieve
Returns: Batch from any store with sufficient samples, or None
`def is_empty(self, model_tag: ModelTag | None = None) -> bool`
Parameters:
- `model_tag`: Specific model to check, or None for all models
Returns: True if the specified store (or all stores) is empty
`def get_model_tags(self) -> list[ModelTag]`
Returns: List of all ModelTags with active stores
Individual pool instance managing trajectories for a single model with weight synchronization coordination.
`def set_module_references(self, dataloader: BaseDataLoader, weight_sync_controller: WeightSyncController, activity_tracker: ActivityTrackerProxy)`
Parameters:
- `dataloader`: Reference to data loader
- `weight_sync_controller`: Weight synchronization controller
- `activity_tracker`: Activity tracking proxy
`def put_trajectory(self, trajectory: Trajectory) -> str`
Parameters:
- `trajectory`: Trajectory to add
Returns:
- `"success"`: Added successfully
- `"fail"`: Failed to add
- `"re-rollout"`: Weight sync in progress, should retry
`def notify_weight_sync_starting()`
Blocks new trajectory additions during weight synchronization.
`def unlock_for_weight_sync()`
Unblocks trajectory additions after weight synchronization completes.
TrajectoryPoolInstance automatically creates appropriate stores based on configuration:
- Use Case: No grouping required
- Behavior: Directly adds trajectories to finished samples
- Configuration: Empty `key_list`

- Use Case: Single-level grouping (e.g., by user ID)
- Behavior: Groups trajectories by the specified key, releases a group when it reaches the target size
- Configuration: Single item in `key_list`

- Use Case: Multi-level grouping (e.g., by user ID then session ID)
- Behavior: Creates a nested hierarchy, releases leaf groups when complete
- Configuration: Multiple items in `key_list`
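The single-level grouping behavior can be sketched as follows (names such as `SingleKeyGroupStore` are illustrative, not the framework's real classes):

```python
from typing import Any

class SingleKeyGroupStore:
    """Hedged sketch: buffer trajectories by a key and release a group once it
    reaches group_size, as described for single-level grouping above."""

    def __init__(self, key: str, group_size: int):
        self._key = key
        self._group_size = group_size
        self._groups: dict[Any, list[dict]] = {}
        self.finished: list[list[dict]] = []

    def add(self, trajectory: dict) -> None:
        group = self._groups.setdefault(trajectory[self._key], [])
        group.append(trajectory)
        if len(group) == self._group_size:
            # Group is complete: release it to the finished samples
            self.finished.append(self._groups.pop(trajectory[self._key]))

store = SingleKeyGroupStore(key="user_id", group_size=2)
store.add({"user_id": "a", "reward": 1.0})
store.add({"user_id": "b", "reward": 0.0})
store.add({"user_id": "a", "reward": 0.5})
```

Multi-level grouping would nest stores of this shape, one level per entry in `key_list`.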
Abstract base class for processing trajectories into training batches.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Processor configuration including batch_size
`def set_module_references(self, trajectory_pool: TrajectoryPool, train_batch_pool: TrainBatchPool) -> None`
Parameters:
- `trajectory_pool`: Source of trajectory batches
- `train_batch_pool`: Destination for processed batches
`def set_activity_tracker(self, tracker: ActivityTrackerProxy)`
Parameters:
- `tracker`: Activity monitoring proxy
`def run()`
Starts the processor thread and begins the main processing loop.
`def stop()`
Gracefully stops the processor.
`def _fit(self, batch: Batch, update_fn: str)`
Parameters:
- `batch`: Trajectory batch to process
- `update_fn`: Update function identifier
Purpose: Processes trajectories through model services to compute advantages, logprobs, etc.
`def _get_batch(self, batch_size: int | None = None) -> Batch`
Parameters:
- `batch_size`: Number of trajectories to fetch
Returns: Batch from the trajectory pool
`def _put_batch(self, batch: Batch, update_fn: str) -> bool`
Parameters:
- `batch`: Processed batch to submit
- `update_fn`: Update function identifier
Returns: True if successfully submitted
Manages training batches organized by model tags.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Pool configuration
`def put_batch(self, batch: Batch, update_fn: str) -> bool`
Parameters:
- `batch`: Training batch to store
- `update_fn`: Update function identifier
Returns: True if successfully stored
Process:
- Extracts model_tag from batch metadata (defaults to "default")
- Creates a model queue if needed
- Appends the batch to the appropriate queue
`def get_batch(self, model: ModelTag | None = "default") -> Batch | None`
Parameters:
- `model`: Model identifier to get a batch for
Returns: Training batch for the specified model, or None if unavailable
`def is_empty(self) -> bool`
Returns: True if all model queues are empty
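The queue-per-model-tag layout described above can be sketched as a minimal stand-in (batches are shown as plain dicts here rather than the real Batch class):

```python
from collections import deque
from typing import Any

class TrainBatchPoolSketch:
    """Hedged sketch of a pool keeping one FIFO queue per model tag."""

    def __init__(self):
        self._queues: dict[str, deque] = {}

    def put_batch(self, batch: dict[str, Any]) -> bool:
        # model_tag comes from batch metadata, defaulting to "default"
        tag = batch.get("metadata", {}).get("model_tag", "default")
        self._queues.setdefault(tag, deque()).append(batch)
        return True

    def get_batch(self, model: str = "default"):
        queue = self._queues.get(model)
        return queue.popleft() if queue else None

    def is_empty(self) -> bool:
        return all(len(q) == 0 for q in self._queues.values())

pool = TrainBatchPoolSketch()
pool.put_batch({"values": {"loss_mask": [1, 1]}, "metadata": {"batch_size": 2}})
```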
Collects validation trajectories, computes metrics, and logs results. Unlike TrajectoryPool, this component focuses on simple collection without batching logic.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Validator configuration
`def set_module_references(self, validate_dataloader: BaseDataLoader)`
Parameters:
- `validate_dataloader`: Reference to validation data loader
`def put_trajectory(self, trajectory: Trajectory) -> str`
Parameters:
- `trajectory`: Validation trajectory to store
Returns: "success" to match the TrajectoryPool.put_trajectory signature
Stores a validation trajectory for later metric computation.
`def is_complete(self) -> bool`
Returns: True if all validation trajectories have been collected
Checks that the validation dataloader is drained and rollout workers are quiescent.
`def compute_and_log_metrics(self) -> dict[str, float]`
Returns: Dictionary of computed metrics with a "val/" prefix
Computes the mean of each score key across all trajectories and logs results via the activity tracker.
`def clear()`
Clears all stored validation trajectories.
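The metric computation described above (mean per score key, reported under a "val/" prefix) can be sketched as a pure function; the argument names are illustrative:

```python
def compute_metrics(trajectories: list[dict], score_keys: list[str]) -> dict[str, float]:
    """Hedged sketch: mean of each score key across all trajectories,
    keyed with a "val/" prefix as the Validator does."""
    metrics: dict[str, float] = {}
    for key in score_keys:
        scores = [t[key] for t in trajectories if key in t]
        if scores:
            metrics[f"val/{key}"] = sum(scores) / len(scores)
    return metrics
```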
Concrete implementation for model training workers that interfaces with training services (e.g., NexTrainer).
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Worker configuration including total_train_steps and train_service settings
Initializes:
- Train service client for model training
- Training statistics tracking
- Threading components for async execution
- Checkpoint saving support
`def set_module_references(self, train_batch_pool: TrainBatchPool, weight_sync_controller: WeightSyncController)`
Parameters:
- `train_batch_pool`: Source of training batches
- `weight_sync_controller`: Reference to weight synchronization controller
`def set_activity_tracker(self, tracker: ActivityTrackerProxy) -> None`
Parameters:
- `tracker`: Activity monitoring proxy
`def initialize_workers() -> None`
Initializes training service workers with the final configuration. For the NexTrainer backend, this sends the actor config to workers and initializes the model on GPU.
`def run() -> None`
Starts the training worker thread and begins the main training loop.
`def stop() -> None`
Gracefully stops the training worker and waits for any ongoing checkpoint saves to complete.
`def get_train_step(self) -> int`
Returns: Current training step count
`def set_train_step(self, step: int) -> None`
Parameters:
- `step`: The training step to set
Sets the current training step (used for resuming from a checkpoint).
`def _step(self, batch: Batch)`
Parameters:
- `batch`: Training batch to process
Trains the model on a single batch, updates metrics, saves checkpoints if needed, and notifies the weight sync controller.
`def _get_batch(self) -> Batch | None`
Returns: Next training batch, or None if unavailable
Manages model weights and synchronization coordination across the system. Supports multiple synchronization modes and coordinates with trajectory pools and rollout services.
`def __init__(self, config: DictConfig)`
Parameters:
- `config`: Weight manager configuration
Configuration Options:
- `sync_mode`: Synchronization mode ("sync", "fully-async", "batch-async")
- `staleness_threshold`: Maximum staleness allowed in async modes
- `checkpoint_manager`: Checkpoint manager configuration
`def set_module_references(self, dataloader: BaseDataLoader, trajectory_pool: TrajectoryPool) -> None`
Parameters:
- `dataloader`: Reference to data loader
- `trajectory_pool`: Reference to trajectory pool
`def check_rollout_service_status(self, model_tag: ModelTag) -> Literal["continue", "block"]`
Parameters:
- `model_tag`: Model to check status for
Returns:
- `"continue"`: Rollout service can continue processing
- `"block"`: Rollout service should block for weight sync
`def trajectory_pool_notify_batch_ready(self, model_tag: ModelTag) -> None`
Parameters:
- `model_tag`: Model with a ready batch
Coordinates weight synchronization when training batches are ready.
`def train_worker_notify_weight_update(self, worker_name: str, model_tag: ModelTag) -> None`
Parameters:
- `worker_name`: Name of the training worker
- `model_tag`: Model that was trained
Handles training completion and performs synchronous weight synchronization. May trigger validation if configured.
`def sync_weight_to_rollout_service(self, model_tag: ModelTag) -> None`
Parameters:
- `model_tag`: Model to sync weights for
Synchronizes model weights from the training service to the rollout/inference service.
`def get_rollout_model_version(self, model_tag: ModelTag) -> int`
Parameters:
- `model_tag`: Model to get the version for
Returns: Current rollout model version number
`def is_waiting_for_validation(self) -> bool`
Returns: True if weight sync is waiting for validation to complete
Used by the controller to trigger validation cycles after weight updates.
`def end_validate(self, model_tag: ModelTag) -> None`
Parameters:
- `model_tag`: Model that completed validation
Called after validation completes to unlock weight synchronization.
`sync`:
- Blocks all workers until all sync to the newest version
- Ensures strict consistency across all components
`fully-async`:
- No blocking; workers sync opportunistically
- Allows maximum throughput with potential staleness
`batch-async`:
- Blocks individual workers when staleness exceeds the threshold
- Balances consistency and throughput
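The per-mode gating decision can be sketched as a pure function. The argument names (`train_version`, `rollout_version`) are illustrative, not the controller's real signature:

```python
def check_rollout_service_status(train_version: int, rollout_version: int,
                                 sync_mode: str, staleness_threshold: int) -> str:
    """Hedged sketch of the three synchronization modes described above."""
    staleness = train_version - rollout_version
    if sync_mode == "sync":
        # Strict consistency: block unless the rollout service is fully caught up
        return "continue" if staleness == 0 else "block"
    if sync_mode == "fully-async":
        return "continue"  # never blocks; workers sync opportunistically
    if sync_mode == "batch-async":
        # Block only once staleness exceeds the configured threshold
        return "continue" if staleness <= staleness_threshold else "block"
    raise ValueError(f"unknown sync_mode: {sync_mode}")
```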
Checkpoint management is handled through the TrainServiceClient interface, which provides methods for saving and loading checkpoints. The WeightSyncController coordinates checkpoint operations but delegates actual checkpoint I/O to the training service.
Key Operations:
- Checkpoints are saved via `TrainServiceClient.save_checkpoint()` during training
- Initial checkpoint loading is handled by the controller during startup
- Resume functionality supports automatic checkpoint discovery or explicit path specification
- Weight synchronization uses a dedicated sync weight buffer path
Centralized tracker for monitoring in-flight work across all modules and coordinating experiment logging.
`def __init__(self, config: DictConfig, max_errors: int = 1000) -> None`
Parameters:
- `config`: Configuration containing project and experiment names for logging
- `max_errors`: Maximum errors to retain in memory
Attributes:
- `experiment_logger`: Tracking instance for logging metrics to wandb/etc.
`def start(self, module: str, work: str) -> str`
Parameters:
- `module`: Module name starting work
- `work`: Description of work being performed
Returns: Unique token for this work item
`def end(self, token: str) -> None`
Parameters:
- `token`: Token from the corresponding start() call
`def is_quiescent(self) -> bool`
Returns: True if no work is currently in progress
`def wait_quiescent(self, timeout: float | None = None) -> bool`
Parameters:
- `timeout`: Maximum time to wait, or None for indefinite
Returns: True if quiescence is achieved, False on timeout
`def get_running_status_summary(self) -> str`
Returns: Human-readable summary of current activity
`def register_module(self, module_name: str, module_ref: Any, is_rollout_worker: bool = False) -> None`
Parameters:
- `module_name`: Name for health checking
- `module_ref`: Reference to module (local object or Ray actor)
- `is_rollout_worker`: Whether this module is a rollout worker (for specialized monitoring)
`def is_rollout_worker_quiescent(self) -> bool`
Returns: True if all registered rollout workers are idle
`def check_module_liveness(self, timeout: float = 5.0) -> bool`
Parameters:
- `timeout`: Timeout for Ray operations
Returns: True if all registered modules are alive
`def report_exception(self, module: str, work: str, exception: Exception, severity: ErrorSeverity | None = None) -> str`
Parameters:
- `module`: Module where the exception occurred
- `work`: Work context
- `exception`: Exception that was raised
- `severity`: Error severity (defaults to ERROR)
Returns: Unique error ID
`def get_error_health_status(self) -> dict[str, Any]`
Returns: Dictionary with health status information
`def set_training_step(self, step: int) -> None`
Parameters:
- `step`: Current training step
Updates the current training step for logging and monitoring purposes.
`def get_training_step(self) -> int`
Returns: Current training step
`def experiment_logger_post(self, backend: str, **kwargs)`
Parameters:
- `backend`: Logging backend ("wandb", etc.)
- `**kwargs`: Backend-specific parameters (e.g., data, step, content, title)
Posts metrics or messages to the specified logging backend through the experiment_logger.
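The start/end/is_quiescent token protocol can be sketched as a minimal in-memory tracker (hypothetical simplification; the real tracker also handles locking, timeouts, module registration, and Ray actors):

```python
import uuid

class ActivityTrackerSketch:
    """Hedged sketch of the in-flight-work bookkeeping described above."""

    def __init__(self):
        self._in_flight: dict[str, tuple[str, str]] = {}

    def start(self, module: str, work: str) -> str:
        token = uuid.uuid4().hex  # unique token identifying this work item
        self._in_flight[token] = (module, work)
        return token

    def end(self, token: str) -> None:
        self._in_flight.pop(token, None)

    def is_quiescent(self) -> bool:
        return not self._in_flight

    def get_running_status_summary(self) -> str:
        return ", ".join(f"{m}:{w}" for m, w in self._in_flight.values()) or "idle"

tracker = ActivityTrackerSketch()
```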
Local proxy that forwards activity tracking calls to a central ActivityTracker.
`def __init__(self, central_tracker: Any)`
Parameters:
- `central_tracker`: Reference to central ActivityTracker (local or Ray actor)
`def track(self, module: str, work: str, auto_report_errors: bool = True) -> "_ProxyTrackCtx"`
Parameters:
- `module`: Module name performing work
- `work`: Description of work
- `auto_report_errors`: Whether to automatically report exceptions
Returns: Context manager for activity tracking
`def is_rollout_worker_quiescent(self) -> bool`
Returns: True if all registered rollout workers are idle
`def set_training_step(self, step: int) -> None`
Parameters:
- `step`: Current training step
Forwards the training step update to the central tracker.
`def get_training_step(self) -> int`
Returns: Current training step from the central tracker
`def experiment_logger_post(self, backend: str, **kwargs)`
Parameters:
- `backend`: Logging backend ("wandb", etc.)
- `**kwargs`: Backend-specific parameters
Forwards the logging request to the central tracker's experiment_logger.
Usage Example:
```python
with activity_tracker.track("MyModule", "processing_batch"):
    # Perform work here; the activity is automatically
    # tracked and errors are reported
    process_batch(batch)
```
Centralized error reporting and aggregation system.
`def __init__(self, max_errors: int = 1000)`
Parameters:
- `max_errors`: Maximum errors to keep in memory
`def report_exception(self, module_name: str, work_context: str, exception: Exception, severity: ErrorSeverity = ErrorSeverity.ERROR, details: dict[str, Any] | None = None) -> str`
Parameters:
- `module_name`: Module where the error occurred
- `work_context`: Context of work being performed
- `exception`: Exception instance
- `severity`: Error severity level
- `details`: Additional error details
Returns: Unique error ID
`def report_error(self, module_name: str, work_context: str, message: str, severity: ErrorSeverity = ErrorSeverity.ERROR, details: dict[str, Any] | None = None) -> str`
Parameters:
- `module_name`: Module where the error occurred
- `work_context`: Context of work being performed
- `message`: Error message
- `severity`: Error severity level
- `details`: Additional error details
Returns: Unique error ID
`def get_health_status(self) -> dict[str, Any]`
Returns: Dictionary containing:
- `status`: Overall health status ("healthy", "warning", "error")
- `message`: Summary message
- `recent_error_count`: Total recent errors
- `error_level_count`: Recent error-level issues
- `warning_level_count`: Recent warning-level issues
Enumeration of error severity levels.
```python
class ErrorSeverity(Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
```
Data class containing error information.
```python
@dataclass
class ErrorInfo:
    error_id: str                    # Unique error identifier
    timestamp: float                 # Error occurrence time
    module_name: str                 # Module where error occurred
    work_context: str                # Work context
    severity: ErrorSeverity          # Severity level
    message: str                     # Error message
    details: dict[str, Any]          # Additional details
    exception_type: str | None       # Exception type name
    exception_traceback: str | None  # Full traceback
```
Manages Ray actors and handles actor creation with co-location support.
`def __init__(self)`
Initializes the resource manager with empty role registrations.
`def register_role(self, role: NexRLRole, cls: type, config: DictConfig, count: int, colocation_group: str | None = None)`
Parameters:
- `role`: NexRL role being registered
- `cls`: Class to instantiate for this role
- `config`: Configuration for the role
- `count`: Number of instances to create
- `colocation_group`: Optional group name for co-location (None = standalone actor)
Colocation Behavior:
- Roles with the same `colocation_group` share a single Ray actor
- Roles with `colocation_group=None` get dedicated actors
- Methods are prefixed with the role name for co-located actors
`def create_all_actors()`
Creates and deploys all registered actors based on role registrations and colocation groups.
`def get_actor_wrapper(self, role: NexRLRole) -> list[Any]`
Parameters:
- `role`: Role to get actor wrappers for
Returns: List of actor wrappers for the role
Wrapper that enables elegant access to co-located Ray actors.
def __init__(self, actor: ActorHandle, actor_class: type, role: NexRLRole, is_colocated: bool = True)Parameters:
actor: Ray actor handleactor_class: Original actor classrole: NexRL role of this componentis_colocated: Whether actor is co-located with others
Functionality:
- Automatically rebinds methods from actor to wrapper
- Handles role-prefixed method names for co-located actors
- Provides transparent access to actor methods
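The rebinding described above can be illustrated without Ray. In this minimal sketch, `CombinedActor`, `ActorWrapperSketch`, and the `rollout_` prefix are hypothetical stand-ins for the framework's actual classes and naming scheme:

```python
class CombinedActor:
    """Stand-in for a co-located actor whose methods carry role prefixes."""

    def rollout_step(self, task):
        return f"rollout handled {task}"

    def train_fit(self, batch):
        return f"train handled {batch}"


class ActorWrapperSketch:
    """Rebinds role-prefixed methods so callers use the unprefixed names."""

    def __init__(self, actor, role: str, is_colocated: bool = True):
        prefix = f"{role}_" if is_colocated else ""
        for name in dir(actor):
            if name.startswith(prefix) and not name.startswith("_"):
                # Expose e.g. `rollout_step` as `step` on the wrapper
                setattr(self, name[len(prefix):], getattr(actor, name))


wrapper = ActorWrapperSketch(CombinedActor(), role="rollout")
print(wrapper.step("task-1"))  # → "rollout handled task-1"
```

Each role sees only its own methods, so callers of a co-located actor look identical to callers of a dedicated one.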
def execute(func: Any, *args, **kwargs) -> Any

Parameters:
- func: Function to execute (local or Ray remote)
- *args, **kwargs: Function arguments
Returns: Function result
Behavior:
- Local mode: Always executes locally
- Ray mode: Auto-detects Ray remote methods and uses ray.get()
def execute_async(func: Any, *args, **kwargs) -> Any

Parameters:
- func: Function to execute
- *args, **kwargs: Function arguments
Returns: Immediate result (local) or ObjectRef (Ray)
Purpose: Enables asynchronous execution patterns
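A minimal sketch of the dispatch both helpers share; the `hasattr` check for a `.remote` attribute is an assumption about how Ray remote methods are detected, and `execute_sketch` is not the framework's actual implementation:

```python
def execute_sketch(func, *args, **kwargs):
    """Run func, unwrapping Ray remote methods when one is passed in."""
    if hasattr(func, "remote"):  # duck-typing check for a Ray remote method
        import ray
        return ray.get(func.remote(*args, **kwargs))
    # Local mode: a plain call
    return func(*args, **kwargs)


def double(x):
    return 2 * x

print(execute_sketch(double, 21))  # → 42
```

The async variant would return the `ObjectRef` from `func.remote(...)` instead of blocking on `ray.get()`.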
def set_logging_basic_config(level)

Parameters:
- level: Logging level (e.g., logging.DEBUG)
Purpose: Sets up consistent logging format across the framework
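A sketch of what such a helper might do; the format string is illustrative, not the framework's actual format:

```python
import logging


def set_logging_basic_config(level: int) -> None:
    """Apply one shared logging format (format string is illustrative)."""
    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        force=True,  # reconfigure even if a handler is already installed
    )


set_logging_basic_config(logging.DEBUG)
print(logging.getLogger().level == logging.DEBUG)  # → True
```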
NexRL uses Hydra for configuration management. The main configuration structure includes:
- launch_mode: "local" or "ray"
- project_name: Project name for experiment tracking
- experiment_name: Experiment name for logging
- data: DataLoader configuration
- rollout_worker: RolloutWorker configuration
- trajectory_pool: TrajectoryPool configuration
- algorithm: AlgorithmProcessor configuration
- train_batch_pool: TrainBatchPool configuration
- train_worker: TrainWorker configuration
- weight: WeightSyncController configuration
- service: Service configurations (train_service, inference_service)
- validate: Validation configuration
- resume: Resume configuration
- logger: Logging backend configuration
- runtime_monitor: Runtime monitoring configuration
- type: Loader type ("mock", "torch")
- seed: Random seed for data loading
- Additional configuration depends on loader type
- type: Worker type ("mock", "simple", "single_turn_math")
- num_workers: Total number of rollout workers
- type: Pool type ("default")
- batch_size: Batch size for trajectory processing
- group_size: Size of trajectory groups
- key_list: Keys for hierarchical grouping
- check_batch_ready_function: Batch readiness criteria
- type: Processor type ("mock", "grpo")
- Additional configuration depends on processor type
- type: Pool type ("default")
- batch_size: Batch size for training
- type: Worker type ("default")
- total_train_steps: Maximum training steps
- num_workers: Number of training workers (usually 1)
- checkpoint_path: Path to save checkpoints
- sync_weight_path: Path for weight synchronization buffer
- save_freq: Checkpoint save frequency (in steps)
- remove_previous_ckpt: Whether to remove previous checkpoints
- type: Controller type ("default")
- sync_mode: Synchronization mode ("sync", "fully-async", "batch-async")
- staleness_threshold: Maximum staleness in async modes
- train_service: Training service configuration
  - backend: Service backend ("mock", "nextrainer")
  - url: Service URL
  - model_tag: Model identifier
  - identifier: Optional service identifier
- inference_service: Inference service configuration
  - backend: Service backend ("vllm", etc.)
  - url: Service URL
  - model_tag: Model identifier
  - api_key: API key for service
  - max_retries: Retry attempts
  - freeze_for_weight_sync: Whether to block during weight sync
- validate_before_train: Run validation before starting training
- data: Validation dataloader configuration (same structure as main data config)
- eval: Validator configuration
  - type: Validator type ("default")
- mode: Resume mode ("disable", "auto", "from_path")
- resume_path: Path to checkpoint for "from_path" mode
- backend: Logging backend ("wandb", etc.)
- exception_handling:
  - enabled: Enable exception monitoring
  - check_interval: Exception check interval (seconds)
  - policy: Error handling policy ("stop_on_error", "continue", "stop_on_critical")
- health_check:
  - enabled: Enable module liveness monitoring
  - check_interval: Health check interval (seconds)
  - timeout: Health check timeout (Ray mode only)
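Pulling the options above together, an illustrative `rl_train.yaml` might look like the following (values are examples, not defaults):

```yaml
launch_mode: local
project_name: nexrl-demo
experiment_name: grpo-math

data:
  type: torch
  seed: 42

rollout_worker:
  type: single_turn_math
  num_workers: 8

trajectory_pool:
  type: default
  batch_size: 64
  group_size: 4

algorithm:
  type: grpo

train_batch_pool:
  type: default
  batch_size: 64

train_worker:
  type: default
  total_train_steps: 1000
  save_freq: 100

weight:
  type: default
  sync_mode: sync

service:
  train_service:
    backend: nextrainer
    url: http://localhost:8000
    model_tag: policy
  inference_service:
    backend: vllm
    url: http://localhost:8001
    model_tag: policy

logger:
  backend: wandb
```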
import hydra
from omegaconf import DictConfig
from nexrl import NexRLController

@hydra.main(config_path="config", config_name="rl_train")
def main(config: DictConfig):
    config.launch_mode = "local"
    controller = NexRLController(config)
    controller.run()

if __name__ == "__main__":
    main()

import ray
import hydra
from omegaconf import DictConfig
from nexrl import NexRLController

@hydra.main(config_path="config", config_name="rl_train")
def main(config: DictConfig):
    config.launch_mode = "ray"
    ray.init()

    # Create controller as Ray actor
    ControllerActor = ray.remote(NexRLController)
    controller_actor = ControllerActor.remote(config)
    ray.get(controller_actor.run.remote())

    ray.shutdown()

if __name__ == "__main__":
    main()

from nexrl import BaseRolloutWorker
from typing import Any

class MyRolloutWorker(BaseRolloutWorker):
    def step(self, task: dict[str, Any]) -> str | None:
        # Extract task data
        if "prompt" not in task:
            return None
        prompt = task["prompt"]

        # Custom processing logic using LLMServiceClient
        result = self._llm_client.completion(prompt, temperature=0.7)

        # Create trajectory
        trajectory = {
            "prompt": prompt,
            "response": result["response"],
            "finish_reason": result["finish_reason"],
            "custom_metadata": task.get("metadata", {}),
            "model_tag": self._llm_client._model_tag,
        }

        # Submit to trajectory pool and return result
        return self._put_trajectory(trajectory)

from nexrl import BaseAlgorithmProcessor
class MyAlgorithmProcessor(BaseAlgorithmProcessor):
    def _fit(self, batch: Batch, update_fn: str):
        # Process trajectories through model services
        processed_values = {}
        for key, values in batch.values.items():
            # Apply custom processing logic
            processed_values[key] = self._process_values(values)

        # Create processed batch
        processed_batch = Batch(
            values=processed_values,
            metadata=batch.metadata,
        )

        # Submit to training pool
        self._put_batch(processed_batch, update_fn)

# In any module with activity tracker
with self._activity_tracker.track("MyModule", "processing_data"):
    # Perform work that should be tracked
    result = process_data(data)
    # Automatic activity tracking and error reporting

# Check rollout worker specific status
if self._activity_tracker.is_rollout_worker_quiescent():
    logger.info("All rollout workers are idle")

# Log metrics to experiment tracking
self._activity_tracker.experiment_logger_post(
    backend="wandb",
    data={"metric_name": value},
    step=training_step,
)

# Validation is automatically triggered by the controller when:
# 1. validate_before_train is enabled (runs before training starts)
# 2. Weight sync completes and validation frequency is configured
# Custom rollout workers should handle validation mode:
def step(self, task: dict[str, Any]) -> str | None:
    # Check if in validation mode
    if self._is_running_validate:
        # Use validation dataloader and validator
        validate_task = self._get_validate_task()
        trajectory = self._process_task(validate_task)
        return self._put_validate_trajectory(trajectory)
    else:
        # Normal training mode
        trajectory = self._process_task(task)
        return self._put_trajectory(trajectory)

- Always inherit from NexRLModule for Ray compatibility
- Implement proper cleanup in stop() methods
- Use activity tracking for long-running operations
- Handle exceptions gracefully and report through activity tracker
- Define clear resource pool mappings based on workload
- Use co-location for related services to reduce communication overhead
- Monitor resource usage through activity tracker
- Plan GPU allocation based on model requirements
- Use structured error reporting through ErrorReporter
- Implement proper retry logic for transient failures
- Monitor system health through activity tracker
- Define clear error policies for different scenarios
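For the retry guidance above, a hedged sketch of a backoff helper; `call_with_retries` and the `ConnectionError` filter are placeholders for whatever the service clients actually raise:

```python
import time


def call_with_retries(fn, max_retries: int = 3, base_delay: float = 0.1):
    """Retry fn with exponential backoff on transient errors."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except ConnectionError:
            if attempt == max_retries:
                raise  # exhausted retries: surface the error
            time.sleep(base_delay * (2 ** attempt))


calls = []

def flaky():
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("transient")
    return "ok"

print(call_with_retries(flaky))  # → "ok"
```

Non-transient exceptions should fall through to the ErrorReporter rather than being retried.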
- Use Hydra for configuration management
- Define environment-specific overrides
- Validate configurations before deployment
- Document configuration options clearly
- Use appropriate batching strategies for trajectory collection
- Monitor activity tracker for bottlenecks
- Optimize resource pool configurations
- Use async execution patterns where appropriate
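An async execution pattern can be sketched with a thread pool standing in for execute_async; `rollout_step` here is a hypothetical placeholder for an environment interaction or inference call:

```python
from concurrent.futures import ThreadPoolExecutor


def rollout_step(task_id: int) -> str:
    # Placeholder for an environment interaction / inference call
    return f"trajectory-{task_id}"


with ThreadPoolExecutor(max_workers=4) as pool:
    # Submit all tasks first, then collect results in submission order
    futures = [pool.submit(rollout_step, i) for i in range(4)]
    results = [f.result() for f in futures]

print(results)  # → ['trajectory-0', 'trajectory-1', 'trajectory-2', 'trajectory-3']
```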
- Check Ray cluster status with ray status
- Verify placement group creation
- Monitor resource allocation
- Check activity tracker health status
- Verify module references are set correctly
- Monitor error reporter for communication errors
- Check activity tracker for bottlenecks
- Monitor resource utilization
- Verify batching configurations
- Validate configuration syntax
- Check module type compatibility
- Verify resource specifications
- Ensure validation dataloader is properly configured
- Check that rollout workers support validation mode (begin_validate/end_validate)
- Verify validator is receiving trajectories
- Check validation frequency configuration in weight sync controller
- Verify checkpoint paths are accessible
- Check resume mode configuration (disable/auto/from_path)
- Ensure checkpoint directory structure matches expected format (global_step_*)
- Verify train service has proper checkpoint save/load permissions
# Get current system status
status = controller.activity_tracker.get_running_status_summary()
print(f"System status: {status}")

# Check health
health = controller.activity_tracker.get_error_health_status()
print(f"Health status: {health}")

# Check actor wrappers for each role
for role in NexRLRole:
    wrappers = resource_manager.get_actor_wrapper(role)
    print(f"{role}: {len(wrappers)} actors")

# Check module health
for module_name, module_ref in controller.activity_tracker._module_refs.items():
    is_alive = controller.activity_tracker.check_module_liveness(module_name)
    print(f"{module_name}: {'alive' if is_alive else 'dead'}")

# Get recent errors
health_status = error_reporter.get_health_status()
if health_status["status"] != "healthy":
    print(f"System unhealthy: {health_status['message']}")

This developer guide provides comprehensive coverage of the NexRL framework's architecture, components, and usage patterns. For specific implementation details, refer to the source code and configuration examples.
