e representation, batched execution via JAX transformations, and deterministic stochasticity management.
Step 1: Define the State Schema
Game state must be represented as a flat, immutable data structure that can be batched across thousands of parallel instances. Instead of nested objects, we use a single named tuple or dictionary containing tensor arrays.
import jax
import jax.numpy as jnp
from typing import NamedTuple
class GameEngineState(NamedTuple):
    wall_tiles: jnp.ndarray     # (batch, tile_count)
    hand_state: jnp.ndarray     # (batch, players, hand_size)
    discard_pile: jnp.ndarray   # (batch, players, discard_size)
    round_counter: jnp.ndarray  # (batch,)
    is_terminal: jnp.ndarray    # (batch,)
    rng_key: jnp.ndarray        # (batch, 2)
Why this choice: JAX's JIT compiler requires static shapes and immutable data. A NamedTuple with fixed tensor dimensions allows jax.vmap to automatically parallelize operations across the batch dimension without Python-level loops. Keeping all state in a single structure prevents accidental mutation and simplifies state threading through jax.lax.scan.
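As a rough illustration of the schema in use, the sketch below builds a batched initial state. The helper name make_initial_state, the default sizes, and the uint8 tile encoding are assumptions made for this example, not part of the original engine.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class GameEngineState(NamedTuple):
    wall_tiles: jnp.ndarray     # (batch, tile_count)
    hand_state: jnp.ndarray     # (batch, players, hand_size)
    discard_pile: jnp.ndarray   # (batch, players, discard_size)
    round_counter: jnp.ndarray  # (batch,)
    is_terminal: jnp.ndarray    # (batch,)
    rng_key: jnp.ndarray        # (batch, 2)

def make_initial_state(key, batch_size, players=4, wall_size=136,
                       hand_size=14, discard_size=18):
    # Hypothetical helper: sizes and dtypes are illustrative assumptions.
    shuffle_key, state_key = jax.random.split(key)
    shuffle_keys = jax.random.split(shuffle_key, batch_size)
    tiles = jnp.arange(wall_size, dtype=jnp.uint8)
    # Shuffle one wall per instance, each with its own key.
    walls = jax.vmap(lambda k: jax.random.permutation(k, tiles))(shuffle_keys)
    return GameEngineState(
        wall_tiles=walls,
        hand_state=jnp.zeros((batch_size, players, hand_size), dtype=jnp.uint8),
        discard_pile=jnp.zeros((batch_size, players, discard_size), dtype=jnp.uint8),
        round_counter=jnp.zeros((batch_size,), dtype=jnp.int32),
        is_terminal=jnp.zeros((batch_size,), dtype=jnp.bool_),
        rng_key=jax.random.split(state_key, batch_size),
    )

state = make_initial_state(jax.random.PRNGKey(0), batch_size=8)
# A NamedTuple is a pytree, so one instance can be sliced out of every field at once.
first_game = jax.tree_util.tree_map(lambda x: x[0], state)
```

Because every field shares the leading batch axis, the same tree_map pattern also works for pulling a single game out of a batch when debugging.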
Step 2: Implement Vectorized Transitions
The core step function takes one game's state and action, applies game rules using tensor operations, and returns the updated state; jax.vmap later lifts it to the whole batch. Conditional logic (e.g., valid move checks) is replaced with jnp.where so the function keeps static shapes and traces cleanly under jit and vmap.
def advance_game(state: GameEngineState, action: jnp.ndarray, rng: jnp.ndarray) -> GameEngineState:
    # Written for ONE game instance (no batch axis); jax.vmap adds the batch dimension.
    # action == 0 means "draw"; other actions leave hand and wall unchanged in this sketch.
    draw_mask = (action == 0)
    # Draw the tile at the head of the wall. This assumes the wall was shuffled at
    # reset, so taking the head is equivalent to sampling without replacement.
    new_tile = state.wall_tiles[0]
    # Shapes must stay fixed, so write the drawn tile into the last hand slot
    # instead of concatenating. Simplification: every player's last slot is
    # written; real rules would select the current player only.
    drawn_hand = state.hand_state.at[:, -1].set(new_tile)
    updated_hand = jnp.where(draw_mask, drawn_hand, state.hand_state)
    # Remove the drawn tile: shift the wall left by one and zero-pad the tail.
    shifted_wall = jnp.concatenate(
        [state.wall_tiles[1:], jnp.zeros_like(state.wall_tiles[:1])], axis=-1
    )
    updated_wall = jnp.where(draw_mask, shifted_wall, state.wall_tiles)
    # Increment round and check terminal condition
    new_round = state.round_counter + 1
    terminal = new_round >= 34  # Example fixed length
    # The per-instance key is stored so later stochastic rules can consume it.
    return state._replace(
        hand_state=updated_hand,
        wall_tiles=updated_wall,
        round_counter=new_round,
        is_terminal=terminal,
        rng_key=rng
    )
Why this choice: A Python if/else on a traced value raises a tracer conversion error under jax.jit and jax.vmap, because the condition has no concrete value at trace time. jnp.where evaluates both branches and selects elementwise, so the computation graph stays static and fully vectorizable. The function is written for a single instance; the batch dimension is added by vmap, which we apply in the rollout loop.
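The difference between masking and branching can be seen in a small, self-contained sketch. The function names apply_draw and apply_draw_branching and the toy hand layout are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

def apply_draw(hand, tile, action):
    # Per-instance logic: hand is (hand_size,), tile and action are scalars.
    drawn = hand.at[-1].set(tile)
    # Both branches are computed; jnp.where selects one without Python control flow.
    return jnp.where(action == 0, drawn, hand)

def apply_draw_branching(hand, tile, action):
    # A Python `if` needs a concrete boolean, which a traced value cannot provide.
    if action == 0:
        return hand.at[-1].set(tile)
    return hand

batched_draw = jax.jit(jax.vmap(apply_draw))

hands = jnp.zeros((3, 14), dtype=jnp.int32)
tiles = jnp.array([5, 9, 27], dtype=jnp.int32)
actions = jnp.array([0, 1, 0])
out = batched_draw(hands, tiles, actions)  # last column becomes [5, 0, 27]

try:
    jax.jit(apply_draw_branching)(hands[0], tiles[0], actions[0])
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    branch_failed = True
```

The masked version pays for computing both outcomes, which is usually cheap for small per-tile updates; expensive branches may justify jax.lax.cond instead.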
Step 3: Batched Rollout Execution
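Training requires thousands of parallel episodes. We use jax.vmap to map the step function across the batch dimension, and jax.lax.scan to iterate for a fixed number of steps. scan cannot stop early, so instances that reach a terminal state before num_steps must be masked or reset inside the step function.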
def run_batched_rollout(initial_state: GameEngineState, policy_fn, num_steps: int):
    batch_size = initial_state.wall_tiles.shape[0]

    def step_fn(carry, _):
        state, rng = carry
        # One fresh key per consumer: the carry, the policy, and the transitions
        rng, policy_key, step_key = jax.random.split(rng, 3)
        # Policy outputs one action per game instance in the batch
        actions = policy_fn(state.hand_state, policy_key)
        # Split RNG per batch element for stochastic transitions
        batch_rng = jax.random.split(step_key, batch_size)
        new_state = jax.vmap(advance_game, in_axes=(0, 0, 0))(state, actions, batch_rng)
        # Carry the advanced key forward, never a key that was already consumed
        return (new_state, rng), new_state

    final_carry, trajectory = jax.lax.scan(
        step_fn, (initial_state, jax.random.PRNGKey(0)), None, length=num_steps
    )
    return trajectory
Why this choice: Under jax.jit, jax.lax.scan traces the loop body once and runs the whole rollout as a single XLA computation, eliminating per-step Python overhead. jax.vmap handles the batch dimension, while explicit RNG splitting keeps stochastic transitions reproducible and independent across parallel instances. Throughput grows with batch size until GPU memory or compute is saturated.
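Since jax.lax.scan always runs a fixed number of iterations, finished episodes need to be frozen explicitly. The sketch below shows one way to do that on a toy counter environment; toy_step and rollout_with_freeze are hypothetical names, and a real engine would apply the same jnp.where selection to every field of the state.

```python
import jax
import jax.numpy as jnp

def toy_step(counter, limit):
    # Stand-in for a per-instance game step: count up until a per-game limit.
    new_counter = counter + 1
    return new_counter, new_counter >= limit

def rollout_with_freeze(limits, num_steps):
    batch = limits.shape[0]

    def step_fn(carry, _):
        counter, done = carry
        stepped, now_done = jax.vmap(toy_step)(counter, limits)
        # Instances that already finished keep their old state.
        counter = jnp.where(done, counter, stepped)
        done = done | now_done
        return (counter, done), done

    init = (jnp.zeros((batch,), jnp.int32), jnp.zeros((batch,), jnp.bool_))
    (final_counter, _), done_history = jax.lax.scan(step_fn, init, None, length=num_steps)
    return final_counter, done_history

limits = jnp.array([2, 5, 3])
final_counter, done_history = rollout_with_freeze(limits, num_steps=6)
# final_counter is [2, 5, 3]: each game stops advancing once it terminates.
```

The done_history output doubles as a validity mask for the trajectory, so loss computation can ignore steps recorded after an episode ended.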
Step 4: Visualization and Debugging
Vectorized environments obscure individual game states, making debugging difficult. A decoupled visualization pipeline samples specific batch indices, reconstructs the game history, and renders it asynchronously without blocking the training loop.
def extract_debug_trace(trajectory: GameEngineState, batch_idx: int) -> list:
    # trajectory fields are stacked by scan as (num_steps, batch, ...)
    return [
        {
            "step": i,
            "hand": trajectory.hand_state[i, batch_idx],
            "wall": trajectory.wall_tiles[i, batch_idx],
            "terminal": trajectory.is_terminal[i, batch_idx]
        }
        for i in range(trajectory.round_counter.shape[0])
    ]
Why this choice: Keeping visualization separate from the core engine prevents I/O bottlenecks. Sampling specific indices allows developers to inspect edge cases, validate rule implementations, and verify reward calculations without sacrificing throughput.
Pitfall Guide
Building GPU-vectorized environments introduces subtle failure modes that rarely appear in sequential code. The following pitfalls are drawn from production deployments and benchmarking cycles.
1. Mutable State in Parallel Loops
Explanation: JAX arrays are immutable. Direct index assignment (x[i] = v) raises a TypeError, and mutating captured Python objects inside jax.lax.scan or vmap takes effect only once at trace time, which silently produces wrong results. JAX enforces functional purity.
Fix: Always return a new state object using _replace() or explicit reconstruction. Replace index assignment with x.at[i].set(v). Note that x += 1 only rebinds the local name to a new array, so the result must still be returned through the state.
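A minimal sketch of the functional update pattern follows. MiniState is a hypothetical two-field stand-in for the full engine state.

```python
import jax.numpy as jnp
from typing import NamedTuple

class MiniState(NamedTuple):
    hand: jnp.ndarray
    round_counter: jnp.ndarray

state = MiniState(
    hand=jnp.zeros((14,), jnp.int32),
    round_counter=jnp.array(0, jnp.int32),
)

# Direct index assignment is rejected because JAX arrays are immutable.
try:
    state.hand[0] = 7
    mutated = True
except TypeError:
    mutated = False

# Functional update: .at[].set() returns a new array, _replace() a new state.
new_state = state._replace(
    hand=state.hand.at[0].set(7),
    round_counter=state.round_counter + 1,
)
```

The original state is left untouched, which is what makes it safe to thread through scan and vmap.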
2. RNG State Leakage Across Batches
Explanation: Using a single PRNG key for the entire batch causes correlated randomness, breaking stochastic independence. This leads to biased reward estimation and unstable policy gradients.
Fix: Split the PRNG key per batch element using jax.random.split(key, batch_size) before passing to the step function. Thread the split keys through the scan loop.
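The effect of key sharing is easy to reproduce. In this sketch, draw_index is a hypothetical stand-in for any stochastic transition; the seed and batch size are arbitrary.

```python
import jax
import jax.numpy as jnp

def draw_index(key, wall_size=136):
    # Stand-in for a stochastic transition: pick a random wall position.
    return jax.random.randint(key, (), 0, wall_size)

key = jax.random.PRNGKey(42)
batch_size = 8

# Anti-pattern: every instance receives the same key and makes the same draw.
shared_keys = jnp.stack([key] * batch_size)
shared = jax.vmap(draw_index)(shared_keys)

# Fix: one key per instance, derived deterministically from the parent key.
per_instance_keys = jax.random.split(key, batch_size)
independent = jax.vmap(draw_index)(per_instance_keys)

# Re-deriving the keys from the same parent reproduces the same draws.
replay = jax.vmap(draw_index)(jax.random.split(key, batch_size))
```

Splitting keeps the rollout reproducible from a single seed while removing the correlation between instances.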
3. Over-Vectorizing Conditional Rule Checks
Explanation: Applying vmap to complex rule validation (e.g., tile matching, win detection) can explode memory usage or trigger dynamic shape errors. Not all game logic benefits from vectorization.
Fix: Vectorize only independent, tensor-friendly operations. Use jnp.where for branching, and isolate heavy rule checks to a post-processing step that runs on a reduced batch or CPU fallback when necessary.
4. Ignoring Memory Bandwidth Constraints
Explanation: High-dimensional state tensors (e.g., full discard history, tile counts, player positions) consume GPU memory rapidly. Excessive tensor size causes OOM errors or forces frequent host-device transfers, killing throughput.
Fix: Compress state representations. Store tile IDs as uint8 rather than the default int32 (136 tile IDs fit in one byte), pack boolean flags into bitmasks, and drop historical data that isn't required for policy input. Profile memory with jax.profiler.
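The savings can be estimated directly from array sizes. The shapes below reuse the batch and discard dimensions from the configuration template; the flag layout is an assumption for illustration.

```python
import jax.numpy as jnp

batch, players, discard_size = 4096, 4, 18

def size_in_bytes(x):
    return x.size * x.dtype.itemsize

discards_i32 = jnp.zeros((batch, players, discard_size), dtype=jnp.int32)
discards_u8 = discards_i32.astype(jnp.uint8)
bytes_i32 = size_in_bytes(discards_i32)  # 4 bytes per tile ID
bytes_u8 = size_in_bytes(discards_u8)    # 1 byte per tile ID

# Eight boolean flags per instance fit into a single byte.
flags = jnp.array([[True, False, True, True, False, False, False, True]])
packed = jnp.packbits(flags, axis=-1)                       # shape (1, 1), dtype uint8
unpacked = jnp.unpackbits(packed, axis=-1).astype(jnp.bool_)
```

Narrow dtypes usually need a cast back to int32 or float32 before they feed a policy network, so the compression applies to stored state rather than to network inputs.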
5. Visualization Blocking the Training Loop
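Explanation: Synchronous rendering or logging in the training loop forces device-to-host transfers and stalls dispatch. The GPU sits idle while the CPU handles I/O. Inside traced functions, print() runs only once at trace time and shows tracers rather than values.
Fix: Decouple visualization entirely. Sample trajectories at fixed intervals, copy only the sampled slice to the host with jax.device_get (a blocking call, so keep the slice small), and render in a separate process or thread. Never call print() or matplotlib inside traced functions; for occasional debugging, jax.debug.print works under jit.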
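One possible shape for a decoupled pipeline is a bounded queue drained by a background thread. Everything here is a sketch under assumptions: train_step is a dummy update, the rendering is replaced by appending to a list, and the interval is shortened for the example.

```python
import queue
import threading
import jax
import jax.numpy as jnp

trace_queue = queue.Queue(maxsize=8)
rendered = []

def render_worker():
    # Runs off the training thread; a None item signals shutdown.
    while True:
        item = trace_queue.get()
        if item is None:
            break
        step, hand = item
        rendered.append((step, hand.tolist()))  # stand-in for real rendering

worker = threading.Thread(target=render_worker, daemon=True)
worker.start()

@jax.jit
def train_step(hands):
    # Dummy update standing in for a rollout plus gradient step.
    return (hands + 1) % 34

DEBUG_SAMPLE_INTERVAL = 100
hands = jnp.zeros((4096, 14), jnp.int32)
for step in range(300):
    hands = train_step(hands)
    if step % DEBUG_SAMPLE_INTERVAL == 0:
        sample = jax.device_get(hands[0])  # copies one small slice to the host
        try:
            trace_queue.put_nowait((step, sample))
        except queue.Full:
            pass  # drop the sample rather than stall training

trace_queue.put(None)
worker.join()
```

Dropping samples when the queue is full keeps the training loop's latency independent of how slow the renderer is.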
6. Reward Shaping in Multi-Agent Settings
Explanation: Naive reward assignment (e.g., +1 for winning, -1 for losing) creates sparse signals and high variance in four-player games. Agents struggle to credit actions across long horizons.
Fix: Implement rank-based rewards with proper normalization. Use advantage estimation tailored for zero-sum or general-sum multi-agent settings. Smooth rewards with temporal discounting and baseline subtraction.
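A rank-based reward can be computed with two argsort calls. In the sketch below, the function name rank_rewards and the placement payouts are hypothetical values chosen to sum to zero; they are not taken from any specific ruleset.

```python
import jax.numpy as jnp

def rank_rewards(final_scores, placement_bonus=(1.0, 0.33, -0.33, -1.0)):
    # final_scores: (batch, players). Payouts are illustrative and zero-sum.
    order = jnp.argsort(-final_scores, axis=-1)  # player indices, best first
    ranks = jnp.argsort(order, axis=-1)          # placement of each player (0 = first)
    bonus = jnp.asarray(placement_bonus)
    return bonus[ranks]

scores = jnp.array([[32000, 18000, 41000, 9000]])
rewards = rank_rewards(scores)
# Player 2 finishes first, player 0 second, player 1 third, player 3 last.
```

Because the payouts sum to zero per game, the batch mean is already centered, which tends to make baseline subtraction less sensitive to score scale. Tied scores fall back to whatever order argsort produces, so a real implementation would need an explicit tie-break rule.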
7. JAX Tracing Overhead from Dynamic Shapes
Explanation: Variable-length discard piles and dynamic hand sizes cannot be expressed inside a jitted function, and each new input shape or static argument value forces JAX to retrace and recompile. Frequent recompilation negates JIT benefits and causes severe slowdowns.
Fix: Pad all tensors to fixed maximum dimensions. Use masking to ignore padded values. Mark rule-variant flags as static (static_argnums or static_argnames in jax.jit) so each variant compiles once and is then served from the cache.
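Padding, masking, and static rule flags can be combined as in the sketch below. The function discard_features, the flag red_fives, and the tile ID RED_FIVE_ID are hypothetical names and values used only for illustration.

```python
from functools import partial
import jax
import jax.numpy as jnp

RED_FIVE_ID = 35  # hypothetical tile ID for a red five

@partial(jax.jit, static_argnames=("red_fives",))
def discard_features(discards, lengths, red_fives):
    # discards: (batch, max_discards) padded with zeros; lengths: (batch,)
    positions = jnp.arange(discards.shape[-1])[None, :]
    valid = positions < lengths[:, None]  # True only for real discards
    num_valid = jnp.sum(valid, axis=-1)
    # The flag is static, so this Python branch is resolved at trace time.
    if red_fives:
        num_red = jnp.sum(valid & (discards == RED_FIVE_ID), axis=-1)
    else:
        num_red = jnp.zeros_like(num_valid)
    return num_valid, num_red

discards = jnp.zeros((2, 18), jnp.int32)
discards = discards.at[0, :3].set(jnp.array([35, 2, 35]))
discards = discards.at[1, 0].set(35)  # stale value beyond this game's length
lengths = jnp.array([3, 0])

num_valid, num_red = discard_features(discards, lengths, red_fives=True)
_, num_red_off = discard_features(discards, lengths, red_fives=False)
```

The second game shows why the mask matters: its buffer still holds a tile value, but a length of zero excludes it from both counts.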
Production Bundle
Action Checklist
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| Research prototype / rapid iteration | CPU-sequential with supervised pretraining | Faster setup, lower hardware requirements, immediate baselines | Low compute cost, high data dependency |
| Production RL pipeline / self-play | JAX-vectorized on single A100 | Balances throughput (~250k steps/sec) with development speed | Moderate GPU cost, high sample efficiency |
| Multi-agent tournament simulation | Multi-GPU JAX with pmap | Maximizes parallelism (1M–2M steps/sec across 8x A100s) | High infrastructure cost, fastest convergence |
| Rule variant experimentation (red/no-red) | Static rule flags with static_argnums | Prevents JAX retracing, maintains throughput across variants | Negligible cost, improves stability |
Configuration Template
# config/env_config.py
import jax
import jax.numpy as jnp
class RiichiEnvConfig:
    # Batch and parallelism settings
    BATCH_SIZE: int = 4096
    NUM_GPUS: int = 8
    USE_PMAP: bool = True
    # State dimensions (fixed for JAX tracing)
    WALL_SIZE: int = 136
    HAND_SIZE: int = 14
    DISCARD_SIZE: int = 18
    MAX_ROUNDS: int = 34
    # Stochasticity
    PRNG_SEED: int = 42
    RNG_SPLIT_COUNT: int = BATCH_SIZE
    # Training hyperparameters
    GAMMA: float = 0.99
    GAE_LAMBDA: float = 0.95
    REWARD_NORMALIZATION: bool = True
    RANK_BASED_REWARDS: bool = True
    # Debugging
    DEBUG_SAMPLE_INTERVAL: int = 1000
    DEBUG_BATCH_IDX: int = 0

    @classmethod
    def validate(cls):
        assert cls.BATCH_SIZE % cls.NUM_GPUS == 0, "Batch size must be divisible by GPU count"
        assert cls.WALL_SIZE > 0, "Wall size must be positive"
        return cls
Quick Start Guide
- Install dependencies: pip install jax jaxlib flax optax (ensure CUDA-compatible jaxlib for your GPU driver)
- Initialize the environment: Load RiichiEnvConfig, create batched initial states using jax.random.split, and JIT-compile the step function with @jax.jit
- Run a benchmark rollout: Execute run_batched_rollout with a dummy policy, measure steps/sec using jax.profiler, and verify terminal state detection
- Attach a policy network: Replace the dummy policy with a Flax/Equinox model, integrate optax for gradient updates, and enable pmap for multi-GPU scaling
- Validate with visualization: Sample trajectories at DEBUG_SAMPLE_INTERVAL, extract debug traces, and confirm rule compliance before launching full training
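For the benchmark step, a wall-clock measurement is often enough before reaching for a full profiler trace. The sketch below times a dummy rollout; dummy_rollout is a hypothetical stand-in for run_batched_rollout, and the constants mirror BATCH_SIZE and MAX_ROUNDS from the configuration template.

```python
import time
import jax
import jax.numpy as jnp

BATCH_SIZE = 4096  # mirrors RiichiEnvConfig.BATCH_SIZE
NUM_STEPS = 34     # mirrors RiichiEnvConfig.MAX_ROUNDS

@jax.jit
def dummy_rollout(round_counter):
    # Stand-in for the real rollout: only advances the round counter.
    def step_fn(counter, _):
        return counter + 1, None
    final, _ = jax.lax.scan(step_fn, round_counter, None, length=NUM_STEPS)
    return final

counters = jnp.zeros((BATCH_SIZE,), dtype=jnp.int32)
dummy_rollout(counters).block_until_ready()  # warm-up call triggers compilation

start = time.perf_counter()
final = dummy_rollout(counters).block_until_ready()
elapsed = time.perf_counter() - start
steps_per_sec = BATCH_SIZE * NUM_STEPS / elapsed
```

JAX dispatches work asynchronously, so block_until_ready is needed on both calls; without it the timer would measure dispatch time rather than execution time.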