
Logging

Roxxel's Logger implements a high-performance, queue-based logging architecture. To prevent system I/O (like writes to stdout, files, and CSVs) from slowing down high-throughput TPU/GPU training loops, all log writing is offloaded to background worker threads.
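The offloading pattern described above can be sketched with only the standard library's `logging.handlers.QueueHandler` and `QueueListener`, which the Logger source in the API reference also uses. This is a minimal illustration, not Roxxel's code. The logger name `"demo"`, the function `run_demo`, and the file path are placeholders.

```python
import logging
import os
import sys
import tempfile
from logging.handlers import QueueHandler, QueueListener
from queue import Queue


def run_demo(log_path: str) -> None:
    """Log a few records through a queue and let a listener thread write them."""
    log_queue: Queue = Queue(-1)  # unbounded, so put() never blocks the caller

    formatter = logging.Formatter("[%(levelname)s] %(message)s")
    file_handler = logging.FileHandler(log_path, encoding="utf-8")
    file_handler.setFormatter(formatter)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)

    # The listener owns the slow handlers and runs them on its own thread.
    listener = QueueListener(log_queue, stdout_handler, file_handler)
    listener.start()

    # The training-loop side only ever touches the cheap QueueHandler.
    logger = logging.getLogger("demo")
    logger.setLevel(logging.INFO)
    logger.handlers.clear()
    logger.addHandler(QueueHandler(log_queue))
    logger.propagate = False

    for step in range(3):
        logger.info("step %d done", step)

    # stop() enqueues a sentinel and joins the listener thread, so every
    # record queued before it is handled before stop() returns.
    listener.stop()
    file_handler.close()
    logger.handlers.clear()


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "demo_system.log")
    run_demo(path)
    with open(path, encoding="utf-8") as f:
        print(f.read())
```

The main thread still formats each record inside `QueueHandler`, but the terminal and file writes happen on the listener thread.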


Why the Context Manager is Critical

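Because log records are written asynchronously by background threads, any record still waiting in a queue when the process dies is lost. If your TPU/GPU throws an Out of Memory (OOM) error or JAX crashes outside a managed context:

1. The exception propagates out of the main thread and the interpreter starts shutting down.
2. The background writer threads are daemon threads, so they are stopped before they finish draining their queues.
3. The most critical debug messages and tracebacks from the end of the run never reach the log files.

Using the with context manager prevents this for any failure that surfaces as a Python exception: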

import time
from roxxel import Logger

# Initialize the async logger context
with Logger(log_dir="./run_directory") as logger:
    logger.log_message("Initializing deep pre-training cluster...")

    # Under the hood, any exceptions raised here are caught by the context manager.
    # It logs the traceback, drains/flushes the async queue to disk, and then propagates the error.
    time.sleep(1)
    raise RuntimeError("TPU Device Out of Memory!")

# Both background threads are drained and joined here, before the RuntimeError propagates.

Automatic Crash Traceback Capture

When an exception escapes the body of a Logger context on rank 0:

1. __exit__ receives the exception and formats its traceback with traceback.format_exception.
2. The traceback is logged at ERROR level, so it is written to both stdout and {log_dir}/{filename_prefix}_system.log.
3. close() then blocks until both background queues are fully drained and their threads have exited, so every queued log line and metrics row is on disk before the exception is re-raised.
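The three steps above map onto a small `__exit__` method. The sketch below is a stripped-down, stdlib-only stand-in: the class name `CrashSafeLog` and its in-memory `records` list are hypothetical, chosen so the behavior is easy to inspect. It is not the Roxxel implementation, which appears in the API reference.

```python
import threading
import traceback
from queue import Queue


class CrashSafeLog:
    """Hypothetical minimal async logger: a worker thread drains a queue into a list."""

    def __init__(self):
        self.records = []  # stands in for stdout / the system log file
        self._queue = Queue(-1)
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def _worker(self):
        while True:
            item = self._queue.get()
            if item is None:  # shutdown sentinel
                break
            self.records.append(item)

    def error(self, message):
        self._queue.put(f"ERROR {message}")

    def close(self):
        self._queue.put(None)
        self._thread.join()  # returns only after every earlier item was consumed

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            # Steps 1-2: format the traceback and enqueue it (split into single lines here).
            for chunk in traceback.format_exception(exc_type, exc_val, exc_tb):
                for line in chunk.rstrip().splitlines():
                    self.error(line)
        # Step 3: block until the queue is drained and the worker has exited.
        self.close()
        return False  # do not swallow the exception


if __name__ == "__main__":
    log = None
    try:
        with CrashSafeLog() as log:
            raise RuntimeError("TPU Device Out of Memory!")
    except RuntimeError:
        pass
    print(log.records[-1])
```

Returning `False` from `__exit__` is what lets the original error reach the caller after the flush.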


Multi-Host TPU Rank-Zero Filter

When scaling JAX code across TPU Pods or multi-node GPU clusters, every worker process executes the same logging statements, so terminals fill with duplicate lines and processes writing to the same log files can interleave or overwrite each other's output.

Roxxel's Logger detects the JAX process rank automatically by calling jax.process_index() when it is constructed (if JAX is not installed, the process is treated as rank 0):

* Only rank 0 writes messages to stdout, log files, or CSVs.
* All other ranks (1 to N-1) execute logging calls as no-ops, preventing file conflicts and terminal pollution.
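Rank detection amounts to reading `jax.process_index()` once and falling back to rank 0 when JAX is missing, mirroring the try/except in `Logger.__init__`. A minimal sketch: the helper `detect_rank_zero`, its `process_index_fn` override, and the `RankZeroPrinter` class are hypothetical, added so the logic can be exercised without a multi-host cluster.

```python
from typing import Callable, List, Optional


def detect_rank_zero(process_index_fn: Optional[Callable[[], int]] = None) -> bool:
    """Return True if this process should write logs (hypothetical helper)."""
    if process_index_fn is None:
        try:
            import jax
        except ImportError:
            # Single-process fallback: without JAX, treat this process as rank 0.
            return True
        process_index_fn = getattr(jax, "process_index", lambda: 0)
    return process_index_fn() == 0


class RankZeroPrinter:
    """Hypothetical sink whose calls become no-ops on non-zero ranks."""

    def __init__(self, is_rank_zero: bool):
        self.is_rank_zero = is_rank_zero
        self.lines: List[str] = []

    def log_message(self, message: str) -> None:
        if self.is_rank_zero:
            self.lines.append(message)
```

Since the rank is read once, at construction time, it is presumably safest to construct the Logger after `jax.distributed.initialize()` on setups that need that call.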


Asynchronous Metrics CSV Logging

You can record training metrics (like loss, learning rate, and perplexity) directly to a CSV file without blocking JAX JIT execution:

from roxxel import Logger

with Logger(log_dir="./logs", filename_prefix="run_alpha") as logger:
    for step in range(100):
        # Your JAX training loop here
        loss = 2.5 - (step * 0.01)
        lr = 3e-4

        # Only enqueues (step, metrics); the CSV write happens on the background thread
        logger.log_metrics_summary(
            step=step,
            metrics={"loss": loss, "lr": lr}
        )
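This writes logs/run_alpha_metrics.csv. The header row (step followed by the metric names) is taken from the first metrics dict logged. If the file already exists, for example when resuming a run, its existing header is reused and new rows are appended, so columns stay aligned. Keys missing from a row are left blank, and keys that are not in the header are dropped.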
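The resume behavior described above can be sketched as: read the existing header if the file exists, otherwise take the columns from the first metrics dict, then write every row in header order. This stdlib sketch swaps in the `csv` module for quoting (the Roxxel worker joins strings by hand, so it does not quote values containing commas), and the function name `append_metrics_row` is hypothetical. For simplicity it re-reads the header on every call.

```python
import csv
import os
from typing import Dict, List, Optional


def append_metrics_row(path: str, step: int, metrics: Dict[str, object]) -> List[str]:
    """Append one row, creating the header on first use. Returns the metric columns."""
    keys: Optional[List[str]] = None
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path, newline="", encoding="utf-8") as f:
            header = next(csv.reader(f), None)
        if header:
            keys = header[1:]  # skip the leading "step" column

    if keys is None:
        keys = list(metrics.keys())
        with open(path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(["step", *keys])

    # Missing keys become empty cells; keys not in the header are dropped.
    row = [step, *(metrics.get(k, "") for k in keys)]
    with open(path, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow(row)
    return keys
```

A resumed run that logs keys in a different order, or adds a new key, still lands in the original columns.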


API Reference

roxxel.logging.Logger

A high-performance, non-blocking, asynchronous logger designed for distributed JAX/Flax pre-training clusters (e.g. multi-host TPU/GPU Pods).

Offloads heavy I/O operations (stdout writes, system logs, and metrics CSV writes) to background threads, so a logging call on the training loop thread only builds a small record and puts it on a queue.

Restricts output to rank zero (only rank 0 writes to files, stdout, or the metrics CSV) to avoid terminal spam and contention between processes writing the same files.

Can be used as a context manager (with statement): if the wrapped code raises any Python exception, including a JAX out-of-memory error or a KeyboardInterrupt, all logging queues are drained to disk and the threads are joined before the exception propagates. A process that is killed outright (e.g. by SIGKILL) or whose interpreter crashes never reaches __exit__, so records still queued at that moment are lost.

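Source code in roxxel/logging.py
import logging
import os
import sys
import threading
import traceback
from logging.handlers import QueueHandler, QueueListener
from queue import Queue


class Logger: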
    """
    A high-performance, non-blocking, asynchronous logger designed for 
    distributed JAX/Flax pre-training clusters (e.g. multi-host TPU/GPU Pods).

    Offloads heavy I/O operations (stdout writes, system logs, and metrics CSV writes)
    to background threads; the main training loop thread only enqueues records.

    Guarantees rank-zero execution (only rank 0 logs to files/stdout/metrics) to avoid
    terminal spam and multi-process file locking contention.

    Supports context manager ('with' statements) so that if the wrapped code raises
    (e.g. a TPU/GPU OOM error or a KeyboardInterrupt), all logging queues are drained
    to disk and threads are joined before the exception propagates.
    """
    def __init__(self, log_dir: str, filename_prefix: str = "roxxel", logger_name: str = "RoxxelCore"):
        """
        Args:
            log_dir (str): Directory where standard log files and metric CSV files will be saved.
            filename_prefix (str, optional): Prefix for generated log files. Defaults to "roxxel".
            logger_name (str, optional): Name of the underlying Python logger. Defaults to "RoxxelCore".
        """
        self.log_dir = log_dir

        try:
            import jax
            # Use getattr to safely handle any potential future JAX modifications
            process_index_fn = getattr(jax, "process_index", lambda: 0)
            self.is_rank_zero = (process_index_fn() == 0)
        except ImportError:
            self.is_rank_zero = True

        if self.is_rank_zero:
            os.makedirs(self.log_dir, exist_ok=True)

            # 1. Asynchronous System Logger Queue Setup
            self.log_queue = Queue(-1)
            self.sys_log_path = os.path.join(self.log_dir, f"{filename_prefix}_system.log")
            text_formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

            stdout_worker = logging.StreamHandler(sys.stdout)
            stdout_worker.setFormatter(text_formatter)
            file_worker = logging.FileHandler(self.sys_log_path, encoding="utf-8")
            file_worker.setFormatter(text_formatter)

            self.listener = QueueListener(self.log_queue, stdout_worker, file_worker)
            self.listener.start()

            self.queue_handler = QueueHandler(self.log_queue)
            self.logger = logging.getLogger(logger_name)
            self.logger.setLevel(logging.INFO)
            self.logger.handlers.clear()
            self.logger.addHandler(self.queue_handler)
            self.logger.propagate = False

            # 2. Asynchronous Metrics CSV Writer Setup
            self.metrics_csv_path = os.path.join(self.log_dir, f"{filename_prefix}_metrics.csv")
            self.metrics_queue = Queue(-1)
            self.metrics_thread = threading.Thread(target=self._metrics_writer_worker, daemon=True)
            self.metrics_thread.start()

    def _metrics_writer_worker(self):
        """Background worker thread that serializes metric dictionaries to the CSV file sequentially."""
        header_written = os.path.exists(self.metrics_csv_path)
        metrics_keys = None

        # If a CSV file already exists (e.g. on step resumption), read its header to keep columns aligned
        if header_written:
            try:
                with open(self.metrics_csv_path, "r", encoding="utf-8") as f:
                    first_line = f.readline().strip()
                    if first_line:
                        metrics_keys = first_line.split(",")[1:]  # Skip the first column ("step")
            except Exception:
                header_written = False

        while True:
            item = self.metrics_queue.get()
            if item is None:  # Shutdown sentinel
                self.metrics_queue.task_done()
                break

            step, metrics = item
            if metrics_keys is None:
                metrics_keys = list(metrics.keys())
                # Write header row
                with open(self.metrics_csv_path, "w", newline="", encoding="utf-8") as f:
                    f.write("step," + ",".join(metrics_keys) + "\n")
                header_written = True

            # Python floats are written with 5 decimal places; everything else via str()
            vals = []
            for k in metrics_keys:
                val = metrics.get(k, "")  # missing keys become empty cells
                if isinstance(val, float):
                    vals.append(f"{val:.5f}")
                else:
                    vals.append(str(val))

            # Only this worker thread touches the CSV file, so plain appends need no locking
            with open(self.metrics_csv_path, "a", newline="", encoding="utf-8") as f:
                f.write(f"{step}," + ",".join(vals) + "\n")

            self.metrics_queue.task_done()

    def __enter__(self):
        """Returns the logger instance itself when entering the 'with' block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Guarantees the queue is drained and threads are safely stopped on exit or crash.
        Automatically logs uncaught tracebacks to system logs on Rank 0 if a crash occurs.
        """
        if exc_type is not None and self.is_rank_zero:
            tb_lines = traceback.format_exception(exc_type, exc_val, exc_tb)
            self.logger.error("❌ CRITICAL: Uncaught exception occurred during execution!")
            for line in tb_lines:
                self.logger.error(line.rstrip())

        self.close()
        # Return False to let the exception bubble up normally after flushing logs
        return False

    def log_message(self, message: str, level: int = logging.INFO):
        """Passes a string to the system log queue; the actual write happens on a background thread.

        Args:
            message (str): The log message string.
            level (int, optional): The log level (e.g. logging.INFO, logging.WARNING). Defaults to logging.INFO.
        """
        if self.is_rank_zero:
            self.logger.log(level, message)

    def log_metrics_summary(self, step: int, metrics: dict):
        """Appends arbitrary metric dictionary data asynchronously to a persistent CSV file.

        Args:
            step (int): The current training step number.
            metrics (dict): Dict of metrics to write (e.g. {'loss': 0.1, 'accuracy': 0.9}).
        """
        if self.is_rank_zero:
            self.metrics_queue.put((step, metrics))

    def close(self):
        """Forces all background asynchronous write threads to drain and complete disk writes."""
        # Only rank 0 owns threads. A second call (e.g. an explicit close() followed by
        # __exit__) is a no-op, since QueueListener.stop() is not safe to call twice on
        # older Python versions.
        if not self.is_rank_zero or getattr(self, "_closed", False):
            return
        self._closed = True

        # 1. Stop metrics writer and wait for queue to drain completely
        if hasattr(self, 'metrics_queue') and hasattr(self, 'metrics_thread'):
            self.metrics_queue.put(None)
            self.metrics_thread.join()

        # 2. Stop system logs listener
        if hasattr(self, 'listener'):
            self.listener.stop()  # Drains log queue completely to disk before closing
            for handler in self.listener.handlers:
                if hasattr(handler, 'close'):
                    handler.close()

        # 3. Cleanly remove handler from singleton logger
        if hasattr(self, 'logger') and hasattr(self, 'queue_handler'):
            self.logger.removeHandler(self.queue_handler)

__enter__()

Returns the logger instance itself when entering the 'with' block.


__exit__(exc_type, exc_val, exc_tb)

Guarantees the queue is drained and threads are safely stopped on exit or crash. Automatically logs uncaught tracebacks to system logs on Rank 0 if a crash occurs.


__init__(log_dir, filename_prefix='roxxel', logger_name='RoxxelCore')

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| log_dir | str | Directory where standard log files and metric CSV files will be saved. | required |
| filename_prefix | str | Prefix for generated log files. | 'roxxel' |
| logger_name | str | Name of the underlying Python logger. | 'RoxxelCore' |

close()

Forces all background asynchronous write threads to drain and complete disk writes.
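For the metrics writer, close() relies on a standard sentinel handshake: it enqueues `None` and joins the worker thread, and because the queue is FIFO the worker consumes every earlier item before it sees the sentinel. A stdlib sketch of that handshake, with a hypothetical in-memory list (`written`) standing in for the CSV file and a hypothetical function name `run_writer`:

```python
import threading
from queue import Queue
from typing import List, Tuple


def run_writer(items: List[Tuple[int, dict]]) -> List[int]:
    """Feed items to a worker thread, then shut it down with a None sentinel."""
    q: Queue = Queue(-1)
    written: List[int] = []  # hypothetical stand-in for rows appended to the CSV

    def worker() -> None:
        while True:
            item = q.get()
            if item is None:  # sentinel: everything queued before it is already done
                q.task_done()
                break
            step, _metrics = item
            written.append(step)
            q.task_done()

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    for item in items:
        q.put(item)  # the producer never waits on the worker
    q.put(None)
    thread.join()  # blocks until the worker has drained the queue and exited
    return written
```

Because the worker is a daemon thread, skipping this step lets the interpreter exit with rows still queued.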


log_message(message, level=logging.INFO)

Passes a string to the system log queue. The stdout and file writes happen on a background thread, so the calling thread only builds the record and enqueues it.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| message | str | The log message string. | required |
| level | int | The log level (e.g. logging.INFO, logging.WARNING). | logging.INFO |
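Because `Logger.__init__` sets the underlying logger's level to `logging.INFO`, a message passed with a lower level such as `logging.DEBUG` is discarded before it ever reaches the queue. The stdlib sketch below reproduces that filtering with a plain `logging.Logger` and a hypothetical list-collecting handler (`ListHandler`), so the effect can be checked without Roxxel.

```python
import logging


class ListHandler(logging.Handler):
    """Hypothetical handler that keeps formatted records in memory."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(f"{record.levelname}: {record.getMessage()}")


logger = logging.getLogger("level-demo")
logger.setLevel(logging.INFO)  # same threshold the Logger source sets
logger.propagate = False
handler = ListHandler()
logger.addHandler(handler)

# Equivalent of log_message(message, level=...) on rank 0.
logger.log(logging.DEBUG, "per-step gradient norms")  # below INFO: dropped
logger.log(logging.INFO, "epoch 1 finished")
logger.log(logging.WARNING, "loss spike detected")
```

If DEBUG output is wanted, one option, assuming the default logger_name, is to lower the level after construction with `logging.getLogger("RoxxelCore").setLevel(logging.DEBUG)`; this is plain standard-library behavior and has not been checked against Roxxel.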

log_metrics_summary(step, metrics)

Appends arbitrary metric dictionary data asynchronously to a persistent CSV file.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| step | int | The current training step number. | required |
| metrics | dict | Dict of metrics to write (e.g. {'loss': 0.1, 'accuracy': 0.9}). | required |
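Each value is rendered by the worker's formatting rule: Python floats are written with five decimal places, everything else goes through `str()`, and missing keys become empty cells. A small sketch of that rule makes the consequences concrete; for example, a learning rate of 3e-4 is written as 0.00030 and one of 1e-6 as 0.00000. The helper names `format_cell` and `format_row` are hypothetical.

```python
def format_cell(val) -> str:
    """Render one CSV cell the way the metrics worker does (hypothetical helper)."""
    if isinstance(val, float):
        return f"{val:.5f}"  # fixed five decimal places
    return str(val)  # ints, bools, strings, arrays, and "" for missing keys


def format_row(step: int, metrics: dict, keys: list) -> str:
    """Join a step and its metrics in header order, leaving missing keys blank."""
    return f"{step}," + ",".join(format_cell(metrics.get(k, "")) for k in keys)
```

Very small metrics such as late-schedule learning rates may therefore be better logged as strings or rescaled values if their precision matters.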