Checkpointing
Synchronously writing multi-gigabyte neural network weights and optimizer states to disk can stall your training loop for several seconds, leaving GPU/TPU accelerators idle and reducing hardware utilization.
Roxxel's Checkpointer uses Orbax's CheckpointManager to offload PyTree serialization to background threads, so your training loop can resume JAX/Flax computation almost immediately.
Model & Optimizer Save/Restore Flow
The following example initializes a Flax NNX model, wraps an Optax optimizer in nnx.Optimizer, restores any existing checkpoint, and saves periodically during training:
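import jax.numpy as jnp
import optax
from flax import nnx
from roxxel.checkpoint import Checkpointer


# 1. Initialize the Flax NNX model and optimizer
class SimpleModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 5, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)


rngs = nnx.Rngs(42)
model = SimpleModel(rngs)
tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)


@nnx.jit
def train_step(model, optimizer, x, y):
    # Mean-squared-error step; updates model and optimizer state in place.
    def loss_fn(model):
        return jnp.mean((model(x) - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # Flax >= 0.11 signature
    return loss


# 2. Instantiate the Checkpointer
checkpointer = Checkpointer(
    checkpoint_path="./checkpoints",
    model=model,
    optimizer=optimizer,
    max_to_keep=3,
)

# 3. Restore an existing checkpoint if present (restore() returns 0 when none is found).
# Checkpoints below are keyed by the number of completed steps, so the returned
# value is also the index of the next step to run.
start_step = checkpointer.restore()
print(f"Resuming at step: {start_step}")

# Dummy batch for illustration
x, y = jnp.ones((8, 10)), jnp.zeros((8, 5))

# 4. Training loop
for step in range(start_step, 1000):
    loss = train_step(model, optimizer, x, y)

    # Save every 100 completed steps. save() returns without waiting for the
    # disk writes. Keying by step + 1 means a resumed run never re-saves the
    # step it restored from.
    if (step + 1) % 100 == 0:
        checkpointer.save(
            step=step + 1,
            metrics_dict={"loss": float(loss)},  # Python float rather than a JAX array
        )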
Core Features
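1. Low-Latency Background Offloading
When you call checkpointer.save(...), Roxxel splits out the Flax NNX model and optimizer state, hands it to Orbax, and moves serialization and storage I/O to a background thread pool. Your GPU/TPU training loop keeps running while the checkpoint is written. With Orbax's asynchronous checkpointing, the calling thread still waits briefly while arrays are copied from accelerator to host memory; only the slower serialization and writes are deferred.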
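The offloading pattern above can be sketched with the standard library alone. `BackgroundSaver` is a hypothetical class, not part of Roxxel or Orbax, and JSON files stand in for Orbax's array serialization. The key point is that `save()` does only one thing on the caller's thread: it takes a snapshot. It then hands the write to a worker thread.

```python
import copy
import json
import os
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor


class BackgroundSaver:
    """Hypothetical async checkpoint writer (not Roxxel's or Orbax's API)."""

    def __init__(self, directory: str):
        self.directory = directory
        self._pool = ThreadPoolExecutor(max_workers=1)  # one writer thread
        self._pending: list = []

    def save(self, step: int, state: dict) -> Future:
        # Runs on the caller's thread: take a snapshot so later in-place
        # updates by the training loop cannot leak into this checkpoint.
        snapshot = copy.deepcopy(state)
        future = self._pool.submit(self._write, step, snapshot)
        self._pending.append(future)
        return future  # returns immediately; the write happens later

    def _write(self, step: int, snapshot: dict) -> str:
        # Runs on the worker thread: serialization and file I/O.
        path = os.path.join(self.directory, f"{step}.json")
        with open(path, "w") as f:
            json.dump(snapshot, f)
        return path

    def wait_until_finished(self) -> None:
        # Block until all queued writes are done; re-raises any write error.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self) -> None:
        self.wait_until_finished()
        self._pool.shutdown()


# Usage: each file holds the state as it was when save() was called.
with tempfile.TemporaryDirectory() as ckpt_dir:
    saver = BackgroundSaver(ckpt_dir)
    params = {"w": [0.0, 0.0]}
    for step in range(3):
        params["w"][0] += 1.0  # stand-in for an in-place training update
        saver.save(step, params)
    saver.close()  # flush pending writes before reading them back
    with open(os.path.join(ckpt_dir, "0.json")) as f:
        first = json.load(f)
    with open(os.path.join(ckpt_dir, "2.json")) as f:
        last = json.load(f)
```

The files are only safe to read back after the pending writes have been flushed. Until `wait_until_finished()` returns, a checkpoint may still be in flight.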
2. NNX Topology Agnostic
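Roxxel's checkpointer does not save model architecture graphs to disk. Instead, on restoration it builds an abstract template (the structure, shapes, and dtypes of the state) from the model and optimizer objects you pass in, and loads the saved arrays into it. Because the checkpoint does not pin your Python class definition, you can refactor methods or change plain Python attributes without breaking compatibility with older saved weights. With a standard Orbax restore, changes that alter the parameter tree itself (for example, adding layers) still require the saved arrays to match the new template's paths, shapes, and dtypes.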
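A minimal sketch of template-driven restoration follows. Plain nested dictionaries stand in for NNX state, leaf lengths stand in for Orbax's shape and dtype metadata, and both function names are hypothetical. The checkpoint holds only leaf values keyed by path, and the structure comes from a template built from the current code at restore time. The sketch also shows where the decoupling ends: a template path the checkpoint never stored cannot be filled.

```python
# Hypothetical sketch: plain nested dicts stand in for NNX state, and leaf
# lengths stand in for the shape/dtype metadata of an abstract template.


def flatten(tree: dict, prefix: tuple = ()) -> dict:
    """Map each leaf's path (a tuple of keys) to its value."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def restore_into_template(template: dict, saved: dict, prefix: tuple = ()) -> dict:
    """Rebuild a nested tree by walking `template` and looking up each leaf
    path in the flat `saved` mapping. Saved paths the template does not
    mention are ignored in this sketch."""
    restored = {}
    for key, spec in template.items():
        path = prefix + (key,)
        if isinstance(spec, dict):
            restored[key] = restore_into_template(spec, saved, path)
            continue
        name = "/".join(path)
        if path not in saved:
            raise KeyError(f"checkpoint has no value for {name}")
        if len(saved[path]) != spec:
            raise ValueError(f"size mismatch at {name}: {len(saved[path])} != {spec}")
        restored[key] = saved[path]
    return restored


# What is on disk: leaf values keyed by path, no Python class information.
saved = flatten({"linear": {"kernel": [0.1, 0.2], "bias": [0.0]}})

# A template built from refactored code with the same parameter paths and
# sizes restores cleanly.
template = {"linear": {"kernel": 2, "bias": 1}}
state = restore_into_template(template, saved)

# A template with a new layer asks for a path the old checkpoint never stored;
# restore_into_template(grown, saved) raises KeyError.
grown = {"linear": {"kernel": 2, "bias": 1}, "linear2": {"kernel": 2}}
```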
3. Automated Best-Loss Tracking
Under the hood, the checkpointer records the metrics you pass via metrics_dict alongside each checkpoint. By default, it ranks checkpoints by their 'loss' value and preserves the one with the lowest loss (best_mode='min'), so your best model state is not discarded when older checkpoints are cleaned up.
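The retention behavior can be sketched as a ranking over recorded metrics. Orbax's `CheckpointManagerOptions` pairs `best_mode` with a `best_fn` that extracts a scalar from each checkpoint's metrics. When one is set, `max_to_keep` prunes the worst-ranked checkpoints rather than the oldest. That Roxxel configures Orbax this way is an assumption, and `checkpoints_to_keep` is a hypothetical helper.

```python
def checkpoints_to_keep(metrics: dict, max_to_keep: int, best_mode: str = "min") -> list:
    """Hypothetical retention rule: keep the `max_to_keep` best-ranked steps.

    `metrics` maps step -> tracked scalar (e.g. the 'loss' from metrics_dict).
    best_mode='min' treats lower values as better; 'max' treats higher as better.
    """
    if best_mode not in ("min", "max"):
        raise ValueError("best_mode must be 'min' or 'max'")
    # Sort steps best-first by their metric, then return the survivors in step order.
    ranked = sorted(metrics, key=metrics.get, reverse=(best_mode == "max"))
    return sorted(ranked[:max_to_keep])


# Loss rises late in training: the three best steps survive, not the three newest.
history = {100: 0.90, 200: 0.41, 300: 0.38, 400: 0.55, 500: 0.60}
kept = checkpoints_to_keep(history, max_to_keep=3)  # [200, 300, 400]
best = min(history, key=history.get)                # 300
```

In this example, the newest checkpoint (step 500) is pruned along with step 100, because its loss ranks fourth.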
API Reference
roxxel.checkpoint.Checkpointer
Asynchronous JAX/Flax NNX module and optimizer checkpointer.
Uses Orbax's CheckpointManager to serialize state on background threads, so disk writes do not block accelerator (GPU/TPU) training.
Supports topology-agnostic PyTree reconstruction and automated best-loss tracking.
Source code in roxxel/checkpoint.py
__init__(checkpoint_path, model, optimizer, max_to_keep=3, timeout=1000)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_path | str | Local or cloud storage path where checkpoints are written. | required |
| model | Module | The Flax NNX model whose state is saved and restored. | required |
| optimizer | Optimizer | The Flax NNX optimizer whose state is saved and restored. | required |
| max_to_keep | int | Maximum number of checkpoints to retain. | 3 |
| timeout | int | Timeout, in seconds, for background asynchronous operations. | 1000 |
restore()
Restores parameters and optimizer state directly into the model and optimizer objects passed to the constructor, rather than returning them as nested dictionaries.
Returns:
| Type | Description |
|---|---|
| int | The step index of the restored checkpoint, or 0 if no checkpoint was found. |
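This return contract can be sketched as a scan of the step-keyed subdirectories under checkpoint_path, assuming restore() resumes from the most recent step. `latest_step` is a hypothetical helper. In Orbax itself, `CheckpointManager.latest_step()` fills this role but returns `None` rather than 0 when nothing has been saved.

```python
import os
import tempfile


def latest_step(checkpoint_path: str) -> int:
    """Hypothetical helper: return the largest integer-named subdirectory of
    `checkpoint_path`, or 0 when there is none (or the path does not exist)."""
    if not os.path.isdir(checkpoint_path):
        return 0
    steps = [
        int(name)
        for name in os.listdir(checkpoint_path)
        if name.isdigit() and os.path.isdir(os.path.join(checkpoint_path, name))
    ]
    return max(steps, default=0)


with tempfile.TemporaryDirectory() as root:
    empty_result = latest_step(root)  # 0: nothing saved yet
    for step in (100, 200, 300):
        os.makedirs(os.path.join(root, str(step)))
    # A temporary directory from an unfinished write is not a valid step key.
    os.makedirs(os.path.join(root, "400.tmp"))
    found = latest_step(root)  # 300
```

Because 0 means "no checkpoint", the caller cannot tell that case apart from a checkpoint saved at step 0. The simplest way to avoid the ambiguity is to never save at step 0.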
save(step, metrics_dict)
Extracts the current model and optimizer state and passes it, together with metrics_dict, to Orbax for asynchronous storage and best-checkpoint tracking.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| step | int | The current training step number (used as the checkpoint subdirectory key). | required |
| metrics_dict | dict | Metrics to record with the checkpoint, such as {'loss': 0.52}; used to rank checkpoints for best-checkpoint tracking. | required |
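The save() contract (one subdirectory per step key, plus the recorded metrics) can be sketched with the standard library. `save_step`, the JSON files, and the directory layout are hypothetical stand-ins for Orbax's on-disk format. The sketch writes into a temporary directory and renames it only when complete. Orbax also uses a write-then-commit approach in some form. As a result, a crash mid-save never leaves a half-written step directory that a later restore might pick up.

```python
import json
import os
import tempfile


def save_step(checkpoint_path: str, step: int, state: dict, metrics_dict: dict) -> str:
    """Hypothetical step-keyed save: write into a temporary directory, then
    rename it to `<checkpoint_path>/<step>` once every file is complete."""
    os.makedirs(checkpoint_path, exist_ok=True)
    final_dir = os.path.join(checkpoint_path, str(step))
    if os.path.exists(final_dir):
        # Refuse to overwrite: each step key identifies exactly one checkpoint.
        raise FileExistsError(f"checkpoint for step {step} already exists")
    tmp_dir = tempfile.mkdtemp(prefix=f"{step}.tmp-", dir=checkpoint_path)
    with open(os.path.join(tmp_dir, "state.json"), "w") as f:
        json.dump(state, f)
    with open(os.path.join(tmp_dir, "metrics.json"), "w") as f:
        json.dump(metrics_dict, f)
    os.rename(tmp_dir, final_dir)  # atomic on POSIX local filesystems
    return final_dir


with tempfile.TemporaryDirectory() as root:
    path = save_step(root, 200, {"w": [0.5]}, {"loss": 0.41})
    with open(os.path.join(path, "metrics.json")) as f:
        recorded = json.load(f)
    contents = sorted(os.listdir(root))  # only the committed "200" directory
```

Refusing to overwrite an existing step is also why a resumed run should not save again under the step it restored from.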