
Checkpointing

Writing multi-gigabyte neural network weights and optimizer states to disk synchronously can leave your GPU/TPU accelerators idle for several seconds per save, reducing hardware utilization.

Roxxel's Checkpointer leverages Orbax Checkpoint Manager to offload PyTree serialization to background threads, allowing your training loop to continue JAX/Flax calculations immediately.


Model & Optimizer Save/Restore Flow

Here is a complete example showing how to initialize a model, configure an Optax optimizer using Flax NNX, and manage state restoration and periodic saving:

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from roxxel.checkpoint import Checkpointer

# 1. Initialize Flax NNX model and optimizer
class SimpleModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(10, 5, rngs=rngs)
    def __call__(self, x):
        return self.linear(x)

rngs = nnx.Rngs(42)
model = SimpleModel(rngs)
tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# 2. Instantiate the Checkpointer
checkpointer = Checkpointer(
    checkpoint_path="./checkpoints",
    model=model,
    optimizer=optimizer,
    max_to_keep=3
)

# 3. Restore existing checkpoint if present
start_step = checkpointer.restore()
print(f"Resumed from step: {start_step}")

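# 4. Training loop
for step in range(start_step, 1000):
    # Train step goes here; a constant stands in for the real loss.
    loss_val = 0.52

    # Save asynchronously (does not block JAX compilation or execution)
    if step % 100 == 0:
        checkpointer.save(
            step=step,
            metrics_dict={"loss": loss_val}
        )

# 5. Wait for pending background writes before the process exits
checkpointer.mngr.wait_until_finished()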

Core Features

1. Zero-Latency Background Offloading

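When you call checkpointer.save(...), Roxxel splits the model and optimizer into a graph definition and a state PyTree with nnx.split, wraps the state in an Orbax StandardSave argument, and hands it to the Orbax CheckpointManager, which performs the disk I/O on background threads. The training loop on your TPU/GPU does not wait for the write to finish.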
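The offloading pattern can be illustrated without Orbax. The following is a minimal sketch using only the standard library; BackgroundSaver, its methods, and the JSON file layout are hypothetical stand-ins, not Roxxel's or Orbax's implementation. It shows the two essential moves: snapshot the state on the caller's thread, then let a worker thread do the write.

```python
import copy
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


class BackgroundSaver:
    """Toy async checkpointer (hypothetical; not Roxxel's Checkpointer)."""

    def __init__(self, directory):
        self.directory = directory
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def save(self, step, state):
        # Snapshot on the caller's thread so later in-place updates
        # cannot leak into the file being written.
        snapshot = copy.deepcopy(state)
        self._pending.append(self._pool.submit(self._write, step, snapshot))

    def _write(self, step, snapshot):
        # Runs on the worker thread; this is the slow disk I/O.
        path = os.path.join(self.directory, f"step_{step}.json")
        with open(path, "w") as f:
            json.dump(snapshot, f)
        return path

    def wait_until_finished(self):
        # Block until every queued write has landed; re-raises worker errors.
        paths = [future.result() for future in self._pending]
        self._pending.clear()
        return paths


def run_demo():
    state = {"w": [0.0, 0.0], "step": 0}
    with tempfile.TemporaryDirectory() as tmp:
        saver = BackgroundSaver(tmp)
        for step in range(3):
            # "Training" keeps mutating state while earlier saves are written.
            state["w"] = [w + 1.0 for w in state["w"]]
            state["step"] = step
            saver.save(step, state)
        paths = saver.wait_until_finished()
        with open(paths[0]) as f:
            first = json.load(f)
    return [os.path.basename(p) for p in paths], first
```

Calling wait_until_finished() before the temporary directory is cleaned up mirrors the need to let pending writes finish before a training process exits.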

2. NNX Topology Agnostic

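Roxxel's checkpointer does not save the model architecture graph to disk; only the state PyTree is written. On restoration, it builds an abstract template (shapes and dtypes, no data) from your current model and optimizer and asks Orbax to fill it. This decoupling means changes to your Python model definition that leave the state tree untouched (e.g. renaming the class or modifying non-parameter attributes) do not break compatibility with older saved weights. Changes that alter the state tree, such as adding a layer, will likely need a migration step, because the template and the saved tree must line up.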
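Template-based restoration can be sketched with plain dictionaries. This is a toy illustration, not Orbax or NNX code: make_template and restore_into are hypothetical helpers, leaves are flat lists of floats, and the strict structure check is an assumption that may be stricter or looser than what a given Orbax configuration enforces.

```python
def make_template(state):
    """Map every leaf (a flat list of floats) to its length, keeping the nesting."""
    if isinstance(state, dict):
        return {key: make_template(value) for key, value in state.items()}
    return len(state)


def restore_into(template, saved):
    """Fill the template with saved values; raise if structure or sizes differ."""
    if isinstance(template, dict):
        if not isinstance(saved, dict) or template.keys() != saved.keys():
            raise ValueError("tree structure mismatch")
        return {key: restore_into(template[key], saved[key]) for key in template}
    if isinstance(saved, dict) or len(saved) != template:
        raise ValueError("leaf size mismatch")
    return list(saved)


# Values written by an earlier run.
saved = {"linear": {"kernel": [0.1, 0.2], "bias": [0.0]}}

# The current model has the same state tree, so the restore succeeds.
current_model = {"linear": {"kernel": [9.0, 9.0], "bias": [9.0]}}
restored = restore_into(make_template(current_model), saved)

# A model with an extra layer has a different tree; this strict toy rejects it.
grown_model = {"linear": {"kernel": [9.0, 9.0], "bias": [9.0]}, "head": {"kernel": [9.0]}}
try:
    restore_into(make_template(grown_model), saved)
    grown_ok = True
except ValueError:
    grown_ok = False
```

The template carries structure and sizes only, so nothing about how the model class is written in Python needs to be stored alongside the weights.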

3. Automated Best-Loss Tracking

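Under the hood, the checkpointer forwards the 'loss' entry of metrics_dict to Orbax as the checkpoint's metric. The manager is configured with a best_fn that reads that loss and with best_mode='min', so Orbax ranks checkpoints by lowest loss when deciding which ones to keep, and the best-scoring checkpoint survives pruning. If metrics_dict has no 'loss' key, the value falls back to 999.0, which ranks that checkpoint poorly.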
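The retention rule can be sketched in a few lines. This is a hedged illustration of the idea, not Orbax's code: retain_best is a hypothetical helper, the loss values are made up for the example, and it assumes that with a best_fn configured, pruning ranks checkpoints by metric rather than by recency.

```python
def retain_best(checkpoints, max_to_keep, mode="min"):
    """Return the steps that survive pruning, best metric first.

    checkpoints maps step -> loss. Hypothetical helper for illustration.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"unknown mode: {mode!r}")
    ranked = sorted(checkpoints.items(), key=lambda item: item[1], reverse=(mode == "max"))
    return [step for step, _ in ranked[:max_to_keep]]


# Illustrative loss history (invented numbers).
history = {0: 2.30, 100: 0.90, 200: 0.45, 300: 0.60, 400: 0.52}
kept = retain_best(history, max_to_keep=3)

# A save without a 'loss' entry gets the 999.0 fallback and is ranked last.
kept_with_fallback = retain_best({**history, 500: 999.0}, max_to_keep=3)
```

Under this rule the most recent checkpoint is not guaranteed to survive, which matters if you rely on the latest step for resuming.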


API Reference

roxxel.checkpoint.Checkpointer

Asynchronous JAX/Flax NNX module and optimizer checkpointer.

Uses Orbax Checkpoint Manager underneath to perform zero-overhead, multi-threaded state serialization on background threads, preventing disk writes from blocking accelerator (GPU/TPU) training.

Supports topology-agnostic PyTree reconstruction and automated best-loss tracking.

Source code in roxxel/checkpoint.py
import os

import orbax.checkpoint as ocp
from flax import nnx


class Checkpointer:
    """
    Asynchronous JAX/Flax NNX module and optimizer checkpointer.

    Uses Orbax Checkpoint Manager underneath to perform zero-overhead, multi-threaded
    state serialization on background threads, preventing disk writes from blocking 
    accelerator (GPU/TPU) training.

    Supports topology-agnostic PyTree reconstruction and automated best-loss tracking.
    """
    def __init__(self, checkpoint_path: str, model: nnx.Module, optimizer: nnx.Optimizer, max_to_keep: int = 3, timeout: int = 1000):
        """
        Args:
            checkpoint_path (str): The local or cloud storage path where checkpoints are written.
            model (flax.nnx.Module): The JAX/Flax NNX model state to serialize.
            optimizer (flax.nnx.Optimizer): The JAX/Flax NNX optimizer state containing optimizer parameters.
            max_to_keep (int, optional): The maximum number of checkpoints to retain; with best_fn set, Orbax ranks them by loss rather than recency. Defaults to 3.
            timeout (int, optional): The timeout in seconds for background asynchronous operations. Defaults to 1000.
        """
        self.checkpoint_path = os.path.abspath(checkpoint_path)
        self.model = model
        self.optimizer = optimizer
        self.graphdef, _ = nnx.split((self.model, self.optimizer))

        async_opts = ocp.options.AsyncOptions(
            timeout_secs=timeout,                
            create_directories_asynchronously=True
        )

        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, 
            create=True,
            async_options=async_opts,
            best_fn=lambda metrics: metrics['loss'], 
            best_mode='min'                           
        )
        self.mngr = ocp.CheckpointManager(self.checkpoint_path, options=options)

    def save(self, step: int, metrics_dict: dict):
        """Splits out the current model and optimizer state and hands it to Orbax together with the loss metric.

        Args:
            step (int): The current training step number (used as the checkpoint subdirectory key).
            metrics_dict (dict): A dictionary containing metrics (like 'loss') to track best checkpoints.
        """
        _, state = nnx.split((self.model, self.optimizer))
        loss_val = float(metrics_dict.get("loss", 999.0))

        # Orbax handles background threading, file validation, and rank-zero coordination.
        self.mngr.save(
            int(step),
            args=ocp.args.StandardSave(state),
            metrics={'loss': loss_val}
        )

    def restore(self) -> int:
        """Restores model parameters and optimizer state in place from the latest checkpoint (a single state tree, no wrapper dictionary).

        Returns:
            The step index of the restored checkpoint, or 0 if no checkpoint was found.
        """
        latest_step = self.mngr.latest_step()
        if latest_step is None:
            return 0

        def get_abstract_state_template():
            abstract_model = nnx.eval_shape(lambda: nnx.merge(self.graphdef, nnx.state((self.model, self.optimizer))))
            _, abstract_state = nnx.split(abstract_model)
            return abstract_state

        abstract_template = get_abstract_state_template()

        restored_state = self.mngr.restore(
            latest_step,
            args=ocp.args.StandardRestore(abstract_template)
        )

        nnx.update((self.model, self.optimizer), restored_state)
        return int(latest_step)

__init__(checkpoint_path, model, optimizer, max_to_keep=3, timeout=1000)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| checkpoint_path | str | The local or cloud storage path where checkpoints are written. | required |
| model | Module | The JAX/Flax NNX model state to serialize. | required |
| optimizer | Optimizer | The JAX/Flax NNX optimizer state containing optimizer parameters. | required |
| max_to_keep | int | The maximum number of checkpoints to retain; with best_fn set, Orbax ranks them by loss rather than recency. | 3 |
| timeout | int | The timeout in seconds for background asynchronous operations. | 1000 |

restore()

Restores model parameters and optimizer state in place from the latest checkpoint (a single state tree, no wrapper dictionary).

Returns:

| Type | Description |
| --- | --- |
| int | The step index of the restored checkpoint, or 0 if no checkpoint was found. |


save(step, metrics_dict)

Splits out the current model and optimizer state and hands it to Orbax together with the loss metric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| step | int | The current training step number (used as the checkpoint subdirectory key). | required |
| metrics_dict | dict | A dictionary containing metrics (like 'loss') to track best checkpoints. | required |