End-to-End Tutorial: Multi-Phase Curriculum Learning
This tutorial builds an end-to-end distributed pre-training pipeline with Roxxel. It defines a multi-phase curriculum that moves from short to long context lengths, then runs it with Roxxel's built-in Trainer and Curriculum classes.
🌟 Architecture Overview
Roxxel's training infrastructure separates data streaming from execution across three decoupled layers:
- Streamer (Roxxel): The core class representing virtualized, memory-mapped shards of dataset blocks.
- Curriculum: Manages the training roadmap (phases, sequence lengths, batch sizes, dataset weights, and dataset blending). It wraps the primary and optional secondary dataset streamers.
- Trainer: The orchestrator class. It accepts the Curriculum schedule, JAX model/optimizer, JIT-compiled train step, async Logger, and Checkpointer. It runs the training loop, automatically hot-swapping streams and reshaping JAX arrays at step boundaries, saving checkpoints, and executing evaluations.
graph TD
A[Roxxel Streamers] --> B[Curriculum]
B --> C[Trainer]
D[Flax NNX State & Optimizer] --> C
E[Checkpointer & Logger] --> C
C --> F[Optimized Loop Execution]
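To make the Curriculum's role concrete, here is a minimal sketch of how a list of phases can be mapped to a global training step. It is not Roxxel's implementation: `PhaseSpec`, `phase_at`, and `tokens_per_step` are hypothetical names introduced for illustration, and the step counts (1000 and 200) are placeholders. Only the batch sizes and sequence lengths mirror the cookbook.

```python
from bisect import bisect_right
from dataclasses import dataclass
from itertools import accumulate


@dataclass(frozen=True)
class PhaseSpec:
    """Hypothetical stand-in for a curriculum phase (not Roxxel's Phase class)."""
    steps: int
    batch_size: int
    seq_len: int  # includes the extra token consumed by the causal shift


def phase_at(phases, global_step):
    """Return (phase_index, step_within_phase) for a 0-based global step."""
    # Cumulative step count at which each phase ends, e.g. [1000, 1200]
    ends = list(accumulate(p.steps for p in phases))
    if not 0 <= global_step < ends[-1]:
        raise ValueError(f"step {global_step} is outside the curriculum")
    idx = bisect_right(ends, global_step)
    start = ends[idx - 1] if idx else 0
    return idx, global_step - start


def tokens_per_step(phase):
    """Predicted tokens per optimizer step: seq_len - 1 targets per row."""
    return phase.batch_size * (phase.seq_len - 1)


# Illustrative step counts; batch sizes and sequence lengths follow the cookbook
phases = [
    PhaseSpec(steps=1000, batch_size=16, seq_len=1025),
    PhaseSpec(steps=200, batch_size=1, seq_len=32769),
]

print(phase_at(phases, 999))   # last step of the first phase
print(phase_at(phases, 1000))  # first step of the second phase
print([tokens_per_step(p) for p in phases])
```

The lookup shows where a stream swap and batch reshape would have to happen: the first global step at which the phase index changes. The token counts also show that moving from batch size 16 at 1,025 tokens to batch size 1 at 32,769 tokens doubles the tokens processed per step under these settings.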
Complete Curriculum Pre-Training Cookbook
The following script combines dataset compilation, multi-phase curriculum streaming, asynchronous logging, JAX device sharding, and Orbax asynchronous checkpointing. It uses a randomly generated toy dataset so that it stays self-contained:
import os
import jax
import jax.numpy as jnp
import optax
import numpy as np
from flax import nnx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils
from roxxel import Roxxel, Phase, Curriculum, Trainer
# --- 1. DEFINE ARCHITECTURE ---
class Xenron(nnx.Module):
    """Xenron model architecture."""

    def __init__(self, num_layers: int, rngs: nnx.Rngs):
        # num_layers is accepted but unused in this minimal embed -> linear model
        self.embed = nnx.Embed(10000, 256, rngs=rngs)
        self.linear = nnx.Linear(256, 10000, rngs=rngs)

    def __call__(self, x):
        return self.linear(self.embed(x))

# --- 2. COMPILE TOY DATASET ---
def token_generator():
    """Generates continuous tokenized integer sequences."""
    for _ in range(10000):
        yield np.random.randint(0, 10000, size=(128,), dtype=np.int32)

DATASET_PATTERN = "./wiki_*.rox"
rox = Roxxel(DATASET_PATTERN)
# Compile raw token generator into uniform 4KB block archives
rox.write(token_generator(), separator=b"\x00", block_size=4096, max_shard_bytes=1024**3)

# --- 3. TRAINING HYPERPARAMETERS ---
GLOBAL_SEED = 42
LR = 3e-4
checkpoint_path = "./checkpoints"
# --- 4. SAMPLING AND TRAINING FUNCTIONS ---
def sample_now(state) -> str:
    """Mock sampling function for step evaluation."""
    return f"[Decoded Text from step {int(state.step[...])}]: Once upon a time in a JAX device cluster..."

# Define loss function
def loss_fn(model, batch):
    # Predict next tokens (causal shift)
    logits = model(batch[:, :-1])
    targets = batch[:, 1:]
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    return loss

# --- 5. MAIN TRAINING EXECUTION ---
def main():
    print("🚀 Initializing Distributed Pre-training Cluster...")

    # Initialize model
    rngs = nnx.Rngs(GLOBAL_SEED)
    model = Xenron(4, rngs)

    # Distributed hardware sharding paths
    devices = jax.devices()
    mesh = Mesh(mesh_utils.create_device_mesh((len(devices),)), axis_names=('data',))
    data_sharding = NamedSharding(mesh, P('data', None))

    # 1. Setup our Roxxel dataset streamers
    with Roxxel(filepath=DATASET_PATTERN) as init_ds:
        phase1_steps = init_ds.estimate_steps(seq_len=1025, batch_size=16)
        phase2_full_steps = init_ds.estimate_steps(seq_len=32769, batch_size=1)
    phase2_steps = int(phase2_full_steps * 0.20)  # 20% of long-context epoch

    # 2. Define the curriculum schedule
    # Format: Phase(steps, batch_size, seq_len, optional_weights)
    phases = [
        Phase(steps=phase1_steps, batch_size=16, seq_len=1025),  # Phase 1: Base Pre-training
        Phase(steps=phase2_steps, batch_size=1, seq_len=32769),  # Phase 2: Context Extension
    ]

    # Instantiate primary dataset curriculum
    curriculum = Curriculum(
        primary_streamer=Roxxel(DATASET_PATTERN),
        phases=phases
    )

    # 3. Calculate total optimizer tracking steps
    total_train_steps = sum(p.steps for p in phases)

    # Continuous decay schedule spanning the full curriculum duration
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.nadamw(
            learning_rate=optax.warmup_cosine_decay_schedule(
                init_value=1e-7,
                peak_value=LR,
                warmup_steps=int(total_train_steps * 0.05),
                decay_steps=total_train_steps,
                end_value=LR * 0.01
            ),
            weight_decay=0.01
        )
    )
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    # 4. Define Trainer orchestrator
    # The Trainer automatically initializes the async Checkpointer, Logger, and ModelState internally
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        curriculum=curriculum,
        loss_fn=loss_fn,
        save_path="run_delta",
        eval_fn=sample_now,
        eval_every=500,
        checkpoint_every=100,
        log_every=100,
        seed=GLOBAL_SEED,
        mesh=mesh,
        data_sharding=data_sharding
    )

    # 5. Execute training
    trainer.run()

if __name__ == "__main__":
    main()
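
The optimizer in the cookbook uses one learning-rate schedule sized to total_train_steps, so the decay continues across the phase boundary rather than restarting. The sketch below reimplements the same warmup-plus-cosine shape in plain Python so that it can be inspected without JAX. `warmup_cosine` and `lr_at` are hypothetical helpers, and the step counts (1000 and 200) are illustrative placeholders, not values returned by estimate_steps.

```python
import math


def warmup_cosine(step, init_value, peak_value, warmup_steps, decay_steps, end_value):
    """Linear warmup to peak_value, then cosine decay to end_value by decay_steps."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return init_value + frac * (peak_value - init_value)
    # Cosine decay runs over the steps that remain after warmup
    span = max(decay_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / span, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


LR = 3e-4
PHASE1_STEPS = 1000  # illustrative placeholder
PHASE2_STEPS = 200   # illustrative placeholder
TOTAL_STEPS = PHASE1_STEPS + PHASE2_STEPS


def lr_at(step):
    """Learning rate at a global step, using the cookbook's schedule settings."""
    return warmup_cosine(
        step,
        init_value=1e-7,
        peak_value=LR,
        warmup_steps=int(TOTAL_STEPS * 0.05),
        decay_steps=TOTAL_STEPS,
        end_value=LR * 0.01,
    )


# Inspect the schedule around the warmup end, the phase boundary, and the final step
for step in (0, 60, PHASE1_STEPS - 1, PHASE1_STEPS, TOTAL_STEPS):
    print(f"step {step:>5}: lr = {lr_at(step):.3e}")
```

With these placeholder counts, the long-context phase starts with the learning rate already well below its peak. Whether that is desirable depends on the run; a schedule with a separate warmup per phase is an alternative design that this tutorial does not use.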