
Trainer & ModelState

Roxxel's Trainer is a curriculum-aware training orchestrator specifically designed for JAX and Flax NNX.

It handles training loops, dynamic sequence and batch transitions, metric logging, Orbax checkpointing, and model evaluations with minimal boilerplate.


Easiest Trainer Configuration

With Roxxel, you do not need to write a custom training state or explicitly instantiate checkpointers and loggers. Supply your model, optimizer, curriculum, and loss_fn, along with a single save_path directory.

import jax
import optax
from flax import nnx
from roxxel import Roxxel, Phase, Curriculum, Trainer

# 1. Define Flax NNX model and optimizer
model = nnx.Linear(10, 5, rngs=nnx.Rngs(42))
tx = optax.sgd(0.01)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# 2. Define the curriculum
phases = [Phase(steps=1000, batch_size=4, seq_len=10)]
curriculum = Curriculum(primary_streamer=Roxxel("./data_*.rox"), phases=phases)

# 3. Define the loss function
def loss_fn(model, batch):
    logits = model(batch[:, :-1].astype(jax.numpy.float32))
    targets = batch[:, 1:].astype(jax.numpy.float32)
    return jax.numpy.mean((logits - targets) ** 2)

# 4. Initialize the Trainer
# Setting save_path automatically initializes the checkpointer and logger
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    curriculum=curriculum,
    loss_fn=loss_fn,
    save_path="./run_delta",
    checkpoint_every=100,
    log_every=10
)

# 5. Run curriculum training
trainer.run()

Core Features

1. Automated ModelState Creation

When you pass a Flax NNX model and optimizer separately, the trainer constructs a ModelState object internally. It maintains:

- state.model: Reference to the Flax NNX Module.
- state.optimizer: Reference to the Flax NNX Optimizer.
- state.step: An nnx.Variable holding the global optimization step.

If you already have a pre-constructed state object with model and optimizer attributes, pass it as the model argument; the trainer detects it automatically for backward compatibility.
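The detection is duck-typed: in the Trainer source listed in the API reference, the first argument is treated as a state object when it exposes both model and optimizer attributes. A minimal sketch of that check follows; split_state and the SimpleNamespace stand-ins are hypothetical illustration names, not part of the Roxxel API.

```python
from types import SimpleNamespace

def split_state(first_arg, optimizer=None):
    """Mirror the Trainer's duck-typed check (illustrative helper, not Roxxel API)."""
    if hasattr(first_arg, "model") and hasattr(first_arg, "optimizer"):
        # Pre-constructed state: reuse it; an explicit optimizer argument wins.
        chosen = optimizer if optimizer is not None else first_arg.optimizer
        return first_arg.model, chosen, True
    # Plain model: the Trainer would wrap model and optimizer in a new ModelState.
    return first_arg, optimizer, False

# A legacy-style state object, faked here with strings instead of real modules.
legacy_state = SimpleNamespace(model="my_model", optimizer="my_optimizer", step=0)
model, opt, was_state = split_state(legacy_state)

# A bare model plus a separate optimizer takes the second branch.
plain_model, plain_opt, plain_was_state = split_state("bare_model", "bare_optimizer")
```

Because the check only looks for the two attribute names, any object carrying both is treated as a state, whatever its class.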

2. Internal JIT Train Step Compilation

On initialization, the Trainer defines a standard training step decorated with @nnx.jit; JAX compiles it the first time it is called. Each call executes:

- A forward pass through your loss_fn.
- Gradient computation via nnx.value_and_grad.
- The optimizer's parameter update.
- An increment of the step counter.
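The four stages can be sketched in plain JAX. This is not the Trainer's code: it uses jax.jit and jax.value_and_grad instead of their nnx counterparts, a hand-written SGD update with an assumed learning rate of 0.1, and a toy linear model, purely to show the shape of such a step.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Mean squared error of a linear map; stands in for a user-supplied loss_fn.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, step, batch):
    # 1. Forward pass and 2. gradients, computed in a single call.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # 3. Parameter update (plain SGD with an assumed learning rate of 0.1).
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    # 4. Step counter increment, returned alongside the metrics.
    return params, step + 1, {"loss": loss, "ppl": jnp.exp(loss)}

params = {"w": jnp.zeros((3, 1))}
batch = {"x": jnp.ones((4, 3)), "y": jnp.ones((4, 1))}
step = jnp.array(0, dtype=jnp.int32)

# The first call triggers tracing and compilation; the second reuses the compiled step.
params, step, first = train_step(params, step, batch)
params, step, second = train_step(params, step, batch)
```

The functional style (returning new params and step) is a property of plain JAX; the NNX version in the Trainer updates the state object in place instead.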

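3. Robust Loss Wrapping

If your loss_fn returns multiple outputs, e.g. (loss, aux_data) or {"loss": loss, "accuracy": acc}, the train step passes it through an internal loss_wrapper so that only the scalar loss reaches nnx.value_and_grad. This avoids the JAX error raised when differentiating a non-scalar output. For a tuple or list, the first element is used; for a dict, the "loss" entry is used, falling back to the first value if there is no "loss" key. In the source listed in the API reference, the remaining outputs are not returned by the train step, which reports only the loss and its perplexity.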
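The selection rule can be written as a small standalone function. The sketch below mirrors the branching of loss_wrapper in the source; the name scalar_loss and the sample values are illustrative only.

```python
def scalar_loss(out):
    """Return only the scalar loss from a loss_fn output (mirrors loss_wrapper)."""
    if isinstance(out, (tuple, list)):
        return out[0]                      # (loss, aux) -> loss
    if isinstance(out, dict):
        if "loss" in out:
            return out["loss"]             # {"loss": ..., ...} -> loss
        return next(iter(out.values()))    # no "loss" key: the first value is used
    return out                             # already a scalar

from_tuple = scalar_loss((0.5, {"accuracy": 0.9}))
from_dict = scalar_loss({"loss": 0.25, "accuracy": 0.9})
from_unkeyed = scalar_loss({"mse": 0.75, "accuracy": 0.9})
from_scalar = scalar_loss(0.125)
```

Given the fallback in the last dict branch, naming the entry "loss" explicitly is the safer choice, since otherwise the result depends on insertion order.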

4. Automatic Resource Management

If save_path is passed, the trainer automatically initializes:

- A Checkpointer located in save_path/checkpoints.
- A Logger saving metrics and system logs directly inside save_path.

Alternatively, you can pass your own Checkpointer and Logger instances, or directory paths, through the checkpointer and logger arguments. Either argument, when given, takes precedence over the save_path default.
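The precedence between save_path and the two explicit arguments can be sketched as follows. resolve_targets is a hypothetical helper that reproduces only the defaulting logic from the Trainer source; it does not construct any Checkpointer or Logger.

```python
import os

def resolve_targets(save_path=None, checkpointer=None, logger=None):
    """Apply the save_path defaults without overriding explicit arguments."""
    if save_path is not None:
        if checkpointer is None:
            checkpointer = os.path.join(save_path, "checkpoints")
        if logger is None:
            logger = save_path
    return checkpointer, logger

# Only save_path: both targets are derived from it.
defaults = resolve_targets(save_path="./run_delta")

# An explicit logger path wins; the checkpointer still falls back to save_path.
overridden = resolve_targets(save_path="./run_delta", logger="./shared_logs")

# Nothing passed: both stay None, so no checkpointing or logging is set up.
nothing = resolve_targets()
```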

When the trainer created the logger itself (from save_path or a path), run() executes the whole training loop inside the logger's context manager, so that tracebacks are logged and output is flushed even if training crashes. A trainer-created checkpointer is likewise handled in the final cleanup step: the trainer waits for pending asynchronous saves to finish and then closes it.
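How a context manager can provide that guarantee is sketched below. SketchLogger is a hypothetical stand-in, not Roxxel's Logger, whose internals are not shown on this page; the sketch only demonstrates the general pattern of recording a traceback and flushing in __exit__.

```python
import traceback

class SketchLogger:
    """Hypothetical stand-in for a logger used as a context manager."""

    def __init__(self):
        self.lines = []
        self.flushed = False

    def log_message(self, text):
        self.lines.append(text)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            # Record the traceback before the exception propagates.
            self.log_message("".join(traceback.format_exception(exc_type, exc, tb)))
        self.flushed = True   # the flush happens on success and on failure
        return False          # never swallow the exception

def failing_run():
    raise RuntimeError("simulated training crash")

logger = SketchLogger()
crashed = False
try:
    with logger:
        failing_run()
except RuntimeError:
    crashed = True
```

Returning False from __exit__ keeps the crash visible to the caller while still leaving a record of it in the log.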


API Reference

Trainer

roxxel.trainer.Trainer

Curriculum-aware pre-training orchestrator designed for JAX/Flax NNX.

Accepts the Curriculum schedule (which wraps the Roxxel dataset streamers) and manages the pre-training loop execution, boundary transitions, hot-swapping, asynchronous logging, evaluations, and Orbax checkpointing.

Source code in roxxel/trainer.py
class Trainer:
    """
    Curriculum-aware pre-training orchestrator designed for JAX/Flax NNX.

    Accepts the Curriculum schedule (which wraps the Roxxel dataset streamers) and
    manages the pre-training loop execution, boundary transitions, hot-swapping,
    asynchronous logging, evaluations, and Orbax checkpointing.
    """
    def __init__(
        self,
        model,
        optimizer,
        curriculum: Curriculum,
        loss_fn,
        save_path=None,
        checkpointer=None,
        logger=None,
        eval_fn=None,
        eval_every: int = 500,
        checkpoint_every: int = 100,
        log_every: int = 100,
        seed: int = 42,
        mesh=None,
        data_sharding=None,
        max_to_keep: int = 3,
        timeout: int = 1000,
        async_queue_depth: int = 2,
    ):
        """
        Args:
            model (flax.nnx.Module): The JAX training state / model instance. If a pre-constructed state
                object containing `model` and `optimizer` attributes is passed, the trainer
                automatically detects it for backward compatibility.
            optimizer (flax.nnx.Optimizer): The Optax optimizer/Flax NNX optimizer instance. Can be None if a
                pre-constructed state is passed as the first argument.
            curriculum (Curriculum): The curriculum schedule object.
            loss_fn (callable): The loss function: loss_fn(model, batch) -> scalar or tuple (loss, aux).
            save_path (str, optional): The root directory where checkpoints and logs are saved.
                If provided, `checkpointer` defaults to `save_path/checkpoints` and `logger`
                defaults to `save_path`.
            checkpointer (Checkpointer, str, optional): Asynchronous Checkpointer instance
                or directory path to automatically initialize it.
            logger (Logger, str, optional): Asynchronous Logger instance or directory
                path to automatically initialize it.
            eval_fn (callable, optional): Callback for periodic evaluations: eval_fn(state) -> str/None.
            eval_every (int, optional): Run evaluations every N steps. Defaults to 500.
            checkpoint_every (int, optional): Save checkpoint every N steps. Defaults to 100.
            log_every (int, optional): Log training metrics every N steps. Defaults to 100.
            seed (int, optional): Base random seed for stream replication. Defaults to 42.
            mesh (jax.sharding.Mesh, optional): JAX hardware mesh sharding specification.
            data_sharding (jax.sharding.NamedSharding, optional): JAX named sharding specification.
            max_to_keep (int, optional): Max checkpoints to keep when initializing checkpointer path. Defaults to 3.
            timeout (int, optional): Timeout for async operations when initializing checkpointer path. Defaults to 1000.
            async_queue_depth (int, optional): Maximum number of asynchronous steps to queue on device
                before blocking host to prevent memory buildup. Defaults to 2.
        """
        # Check if the first parameter is actually a state object
        if hasattr(model, "model") and hasattr(model, "optimizer"):
            self.state = model
            self.model = model.model
            self.optimizer = optimizer if optimizer is not None else model.optimizer
        else:
            self.model = model
            self.optimizer = optimizer
            self.state = ModelState(model, optimizer)

        self.curriculum = curriculum
        self.loss_fn = loss_fn
        self.eval_fn = eval_fn
        self.eval_every = eval_every
        self.checkpoint_every = checkpoint_every
        self.log_every = log_every
        self.seed = seed
        self.mesh = mesh
        self.data_sharding = data_sharding
        self.async_queue_depth = async_queue_depth

        # Merge checkpointer and logger if save_path is provided
        if save_path is not None:
            if checkpointer is None:
                checkpointer = os.path.join(save_path, "checkpoints")
            if logger is None:
                logger = save_path

        # Handle Checkpointer initialization
        self._own_checkpointer = False
        if isinstance(checkpointer, str):
            from roxxel.checkpoint import Checkpointer
            self.checkpointer = Checkpointer(
                checkpoint_path=checkpointer,
                model=self.model,
                optimizer=self.optimizer,
                max_to_keep=max_to_keep,
                timeout=timeout
            )
            self._own_checkpointer = True
        else:
            self.checkpointer = checkpointer

        # Handle Logger initialization
        self._own_logger = False
        if isinstance(logger, str):
            from roxxel.logging import Logger
            self.logger = Logger(log_dir=logger)
            self._own_logger = True
        else:
            self.logger = logger

        # Build and JIT compile the training step internally
        @nnx.jit
        def train_step(state, batch):
            def loss_wrapper(model):
                out = self.loss_fn(model, batch)
                # Ensure only the scalar loss is returned for gradients
                if isinstance(out, (tuple, list)):
                    return out[0]
                elif isinstance(out, dict):
                    if "loss" in out:
                        return out["loss"]
                    return next(iter(out.values()))
                return out

            loss, grads = nnx.value_and_grad(loss_wrapper)(state.model)
            try:
                state.optimizer.update(state.model, grads)
            except TypeError:
                state.optimizer.update(grads)
            try:
                state.step[...] += 1
            except (TypeError, ValueError, AttributeError, KeyError):
                state.step.value += 1
            return {"loss": loss, "ppl": jnp.exp(loss)}

        self.train_step_fn = train_step

    def run(self):
        """
        Executes the curriculum training loop, automatically handling skips, resumptions,
        blending weights, and dynamic shape transitions at phase boundaries.
        """
        if self._own_logger and self.logger:
            with self.logger:
                self._run()
        else:
            self._run()

    def _run(self):
        # 1. Restore checkpoints if available
        start_step = 0
        if self.checkpointer:
            start_step = self.checkpointer.restore()

        # Update JAX-state step counter
        if hasattr(self.state, "step"):
            try:
                self.state.step[...] = jnp.array(start_step, dtype=jnp.int32)
            except (TypeError, ValueError, AttributeError, KeyError):
                if hasattr(self.state.step, "value"):
                    self.state.step.value = jnp.array(start_step, dtype=jnp.int32)

        # 2. Determine initial curriculum configuration window
        accumulated_steps = 0
        completed_phases_ledger = []

        current_seq_len = None
        current_batch_size = None
        current_phase_total_steps = None
        current_weights = None

        for idx, phase in enumerate(self.curriculum.phases):
            p_steps = phase.steps
            p_batch = phase.batch_size
            p_seq = phase.seq_len
            p_weights = phase.weights

            if start_step >= accumulated_steps + p_steps:
                # Add fully completed phases to the historical ledger
                completed_phases_ledger.append((p_steps, p_batch, p_seq))
                accumulated_steps += p_steps
            else:
                # Active phase configuration branch located
                current_batch_size = p_batch
                current_seq_len = p_seq
                current_phase_total_steps = p_steps
                current_weights = p_weights
                break

        # Calculate remaining target steps for the active streaming window session
        steps_already_done_in_current_phase = start_step - accumulated_steps
        remaining_steps_for_session = current_phase_total_steps - steps_already_done_in_current_phase

        total_train_steps = sum(p.steps for p in self.curriculum.phases)

        if self.logger:
            self.logger.log_message(f"🎯 Total Optimization Horizon: {total_train_steps} global steps.")
            self.logger.log_message(f"♻️ Resuming active phase layout: [SEQ: {current_seq_len} | BATCH: {current_batch_size}]")
            self.logger.log_message(f"📊 Remaining steps for this configuration window: {remaining_steps_for_session}")

        dataset = self.curriculum.primary_streamer
        if not dataset._is_open:
            dataset.open()

        try:
            # Helper function to construct sharded dataset streams
            def make_stream(seq_len, batch_size, step, ledger, steps_limit, weights):
                return dataset.stream(
                    seq_len=seq_len,
                    batch_size=batch_size,
                    seed=self.seed,
                    start_step=step,
                    completed_phases=ledger,
                    total_steps=steps_limit,
                    mesh=self.mesh,
                    data_sharding=self.data_sharding,
                    mix_datasets=self.curriculum.mix_streamers,
                    weights=weights
                )

            loader_stream = make_stream(
                current_seq_len,
                current_batch_size,
                start_step,
                completed_phases_ledger,
                remaining_steps_for_session,
                current_weights
            )

            from collections import deque
            metrics_buffer = deque()

            def drain_buffer():
                while metrics_buffer:
                    oldest_m = metrics_buffer.popleft()
                    if isinstance(oldest_m, dict) and "loss" in oldest_m:
                        try:
                            oldest_m["loss"].block_until_ready()
                        except Exception:
                            pass

            curr_step = start_step
            while curr_step < total_train_steps:
                for batch in loader_stream:
                    metrics = self.train_step_fn(self.state, batch)

                    # Prevent JAX asynchronous dispatch queue buildup and activation memory leaks
                    # by bounding the maximum queue depth while keeping execution pipelined/async.
                    if self.async_queue_depth is not None and self.async_queue_depth > 0:
                        metrics_buffer.append(metrics)
                        if len(metrics_buffer) >= self.async_queue_depth:
                            oldest = metrics_buffer.popleft()
                            if isinstance(oldest, dict) and "loss" in oldest:
                                try:
                                    oldest["loss"].block_until_ready()
                                except Exception:
                                    pass

                    if hasattr(self.state, "step"):
                        try:
                            curr_step = int(self.state.step[...])
                        except (TypeError, ValueError, AttributeError, KeyError):
                            try:
                                curr_step = int(self.state.step.value)
                            except (TypeError, ValueError, AttributeError):
                                try:
                                    curr_step = int(self.state.step)
                                except (TypeError, ValueError, AttributeError):
                                    curr_step += 1
                    else:
                        curr_step += 1

                    # 1. Asynchronous system logging
                    if curr_step % self.log_every == 0 and self.logger:
                        loss_val = float(metrics["loss"])
                        ppl = float(metrics["ppl"])
                        self.logger.log_message(f"S{curr_step} | Loss: {loss_val:.4f} | PPL: {ppl:.2f}")
                        self.logger.log_metrics_summary(step=curr_step, metrics={"loss": loss_val, "perplexity": ppl})

                    # 2. Asynchronous checkpointing
                    if curr_step % self.checkpoint_every == 0 and self.checkpointer:
                        self.checkpointer.save(curr_step, metrics_dict={"loss": float(metrics["loss"])})

                    # 3. Model sampling/evaluation
                    if self.eval_fn and curr_step % self.eval_every == 0:
                        drain_buffer()
                        if self.logger:
                            self.logger.log_message(f"🧪 Running Evaluation Check at Step {curr_step}...")
                        story = self.eval_fn(self.state)
                        if self.logger and story:
                            self.logger.log_message(f"EVALUATION OUTPUT:\n{story}\n")

                    # 4. Extensible phase transition swap
                    phase_boundary_accumulator = 0
                    for phase_idx, phase in enumerate(self.curriculum.phases[:-1]):
                        phase_boundary_accumulator += phase.steps

                        if curr_step == phase_boundary_accumulator:
                            drain_buffer()
                            next_phase = self.curriculum.phases[phase_idx + 1]
                            next_steps = next_phase.steps
                            next_batch = next_phase.batch_size
                            next_seq = next_phase.seq_len
                            next_weights = next_phase.weights

                            if self.logger:
                                self.logger.log_message(f"🎯 Step {curr_step} hit! Swapping dynamically to Phase {phase_idx + 2} [SEQ: {next_seq} | BATCH: {next_batch}]...")

                            # Expand historical ledger
                            completed_phases_ledger = [
                                (p.steps, p.batch_size, p.seq_len)
                                for p in self.curriculum.phases[:phase_idx + 1]
                            ]

                            # Re-instantiate JAX stream with updated shape configurations
                            loader_stream = make_stream(
                                next_seq,
                                next_batch,
                                curr_step,
                                completed_phases_ledger,
                                next_steps,
                                next_weights
                            )
                            break

                    if curr_step >= total_train_steps:
                        drain_buffer()
                        if self.logger:
                            self.logger.log_message(f"🏁 Curriculum complete: {curr_step}/{total_train_steps} steps finished successfully.")
                        break
        finally:
            drain_buffer()
            dataset.close()
            if self._own_checkpointer and self.checkpointer:
                # Wait for any pending async checkpoint saves to finish
                if hasattr(self.checkpointer, "mngr") and hasattr(self.checkpointer.mngr, "wait_until_finished"):
                    try:
                        self.checkpointer.mngr.wait_until_finished()
                    except Exception:
                        pass
                if hasattr(self.checkpointer, "mngr") and hasattr(self.checkpointer.mngr, "close"):
                    try:
                        self.checkpointer.mngr.close()
                    except Exception:
                        pass
            if self.logger:
                self.logger.log_message("✅ Global Multi-Phase Execution Complete. Roxxel Instance Closed Safely.")

__init__(model, optimizer, curriculum, loss_fn, save_path=None, checkpointer=None, logger=None, eval_fn=None, eval_every=500, checkpoint_every=100, log_every=100, seed=42, mesh=None, data_sharding=None, max_to_keep=3, timeout=1000, async_queue_depth=2)

Parameters:

- model (Module, required): The Flax NNX model to train. If a pre-constructed state object containing model and optimizer attributes is passed instead, the trainer automatically detects it for backward compatibility.
- optimizer (Optimizer, required): The Flax NNX optimizer instance (wrapping an Optax transformation). Can be None if a pre-constructed state is passed as the first argument.
- curriculum (Curriculum, required): The curriculum schedule object.
- loss_fn (callable, required): The loss function, loss_fn(model, batch), returning a scalar, a tuple (loss, aux), or a dict with a "loss" entry.
- save_path (str, default None): The root directory where checkpoints and logs are saved. If provided, checkpointer defaults to save_path/checkpoints and logger defaults to save_path.
- checkpointer (Checkpointer or str, default None): Asynchronous Checkpointer instance, or a directory path from which one is initialized automatically.
- logger (Logger or str, default None): Asynchronous Logger instance, or a directory path from which one is initialized automatically.
- eval_fn (callable, default None): Callback for periodic evaluations, eval_fn(state), returning a string to log or None.
- eval_every (int, default 500): Run evaluations every N steps.
- checkpoint_every (int, default 100): Save a checkpoint every N steps.
- log_every (int, default 100): Log training metrics every N steps.
- seed (int, default 42): Base random seed for stream replication.
- mesh (Mesh, default None): JAX device mesh (jax.sharding.Mesh), passed to the dataset stream.
- data_sharding (NamedSharding, default None): JAX named sharding specification, passed to the dataset stream.
- max_to_keep (int, default 3): Maximum number of checkpoints to keep. Used only when the checkpointer is initialized from a path.
- timeout (int, default 1000): Timeout for asynchronous operations. Used only when the checkpointer is initialized from a path.
- async_queue_depth (int, default 2): Maximum number of asynchronous steps to queue on device before the host blocks, which prevents memory buildup.
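The effect of async_queue_depth can be illustrated with a bounded queue of in-flight results. The sketch below follows the append-then-block pattern of the training loop in the source, but FakeResult is a hypothetical stand-in for a device array and no real JAX dispatch takes place.

```python
from collections import deque

class FakeResult:
    """Hypothetical stand-in for a device array with block_until_ready()."""

    def __init__(self, step):
        self.step = step
        self.ready = False

    def block_until_ready(self):
        self.ready = True
        return self

def run_steps(num_steps, async_queue_depth=2):
    in_flight = deque()
    max_in_flight = 0
    results = []
    for step in range(num_steps):
        result = FakeResult(step)       # stands in for a dispatched train step
        results.append(result)
        in_flight.append(result)
        max_in_flight = max(max_in_flight, len(in_flight))
        if len(in_flight) >= async_queue_depth:
            # The host waits on the oldest step before dispatching more work.
            in_flight.popleft().block_until_ready()
    pending = len(in_flight)
    # Drain whatever is still queued, as a final cleanup would.
    while in_flight:
        in_flight.popleft().block_until_ready()
    return max_in_flight, pending, results

max_in_flight, pending, results = run_steps(num_steps=5, async_queue_depth=2)
```

With a depth of 2, the queue never holds more than two results, and one step is still pending when the loop ends, which is why a final drain is needed.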

run()

Executes the curriculum training loop, automatically handling skips, resumptions, blending weights, and dynamic shape transitions at phase boundaries.
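Resumption depends on finding which phase a restored step falls into. The sketch below reproduces that bookkeeping from the _run source in simplified form; SketchPhase and locate_phase are hypothetical names, and the return value for an already-finished curriculum is a choice made for this sketch only.

```python
from dataclasses import dataclass

@dataclass
class SketchPhase:
    """Hypothetical stand-in exposing the fields the loop reads from a Phase."""
    steps: int
    batch_size: int
    seq_len: int

def locate_phase(phases, start_step):
    """Return (active phase index, steps remaining in it, completed-phase ledger)."""
    accumulated = 0
    ledger = []
    for idx, phase in enumerate(phases):
        if start_step >= accumulated + phase.steps:
            # Fully completed phase: record it and move on.
            ledger.append((phase.steps, phase.batch_size, phase.seq_len))
            accumulated += phase.steps
        else:
            remaining = phase.steps - (start_step - accumulated)
            return idx, remaining, ledger
    return None, 0, ledger   # sketch-only choice: every phase already finished

phases = [SketchPhase(1000, 4, 128), SketchPhase(500, 8, 256), SketchPhase(250, 16, 512)]

# Resuming from a checkpoint saved at global step 1200.
idx, remaining, ledger = locate_phase(phases, start_step=1200)

# Global steps at which the loop swaps to the next phase (the last phase has none).
boundaries = []
total = 0
for phase in phases[:-1]:
    total += phase.steps
    boundaries.append(total)
```

Here step 1200 lands 200 steps into the second phase, so 300 steps remain before the swap at global step 1500.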

ModelState

roxxel.trainer.ModelState

Bases: Module

Unified JAX/Flax NNX state module containing the model, optimizer, and step counter. Created internally by the Trainer.

Source code in roxxel/trainer.py
class ModelState(nnx.Module):
    """
    Unified JAX/Flax NNX state module containing the model, optimizer,
    and step counter. Created internally by the Trainer.
    """
    def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer):
        self.model = model
        self.optimizer = optimizer
        self.step = nnx.Variable(jnp.array(0, dtype=jnp.int32))