
ReMDM Planner for MiniHack

PyTorch implementation of ReMDM (Remasking Discrete Diffusion Model) for action-sequence planning in MiniHack navigation environments. A dual-stream transformer generates 64-step action plans by iteratively denoising masked token sequences, conditioned on a 9x9 local crop and the full 21x79 dungeon map.

The primary training method is DAgger with BFS oracle supervision: the model is trained from scratch, with the buffer seeded by pure expert trajectories on the first iteration. A standalone offline BC mode is also available as an independent baseline trained on pre-collected datasets. The paper compares both methods head-to-head; neither depends on the other. An offline BC checkpoint can optionally warm-start DAgger, but this is not used in the paper. The model generalises zero-shot from 4 in-distribution environments to 3 out-of-distribution environments.


Pipeline

[Primary]  DAgger online training          main.py --mode dagger
               |  (seed buffer with oracle demos on iter 0,
               |   collect with model, label with oracle,
               |   efficiency filter, curriculum sampling)
               v  checkpoint
[Evaluate] ID + OOD evaluation             main.py --mode inference --checkpoint iter8000.pth
Other modes:

[Collect]     Collect oracle demonstrations main.py --mode collect
[Offline BC]  Train on pre-collected data   main.py --mode offline --data dataset.pt
[Smoke test]  Quick end-to-end check        main.py --mode smoke

DAgger trains from scratch and is the recommended pipeline. Offline BC (`--mode collect` + `--mode offline`) is an independent training method compared against DAgger in the paper. An offline BC checkpoint can optionally warm-start DAgger via `--checkpoint`, but this was not used in the paper results.

Environments

In-distribution (training):

Environment Description
MiniHack-Room-Random-5x5-v0 Small random room
MiniHack-Room-Random-15x15-v0 Large random room
MiniHack-Corridor-R2-v0 Two-room corridor
MiniHack-MazeWalk-9x9-v0 Small maze

Out-of-distribution (zero-shot evaluation):

Environment Description
MiniHack-Room-Dark-15x15-v0 Dark room (limited visibility)
MiniHack-Corridor-R5-v0 Five-room corridor
MiniHack-MazeWalk-45x19-v0 Large maze

Installation

Prerequisites

Python 3.12+ is required.

macOS (arm64): Install cmake via Homebrew (needed to compile nle from source):

brew install cmake

Linux (x86_64): Pre-built wheels are available, but if building from source:

sudo apt-get install build-essential cmake bison flex libbz2-dev

Setup

uv sync

This installs all dependencies from the lockfile, including nle>=1.2.0 (from the maintained NetHack-LE fork), minihack, torch>=2.11.0, wandb, polars, orjson, and scipy.

GPU support (optional)

By default PyTorch runs on CPU. For NVIDIA CUDA 12:

uv pip install torch --index-url https://download.pytorch.org/whl/cu121

Verify GPU is detected:

uv run python -c "import torch; print(torch.cuda.is_available())"

Usage

All modes share a single entry point. Defaults load from configs/defaults.yaml; any value can be overridden via key=value pairs.

python main.py --mode <MODE> [--config PATH] [key=value ...]

Smoke test

Collects a few oracle trajectories, trains under a tiny 5k env-step budget, and prints ID evaluation results.

python main.py --mode smoke

Collect oracle demonstrations

Run the BFS oracle across all 4 ID environments and save the trajectories as a .pt dataset for offline BC training. Uses multiprocessing for parallelism.

# Default: 5000 episodes per env, output to data/dataset.pt
python main.py --mode collect

# Custom episode count and output
python main.py --mode collect collect_episodes_per_env=2000 \
    collect_output=data/small_dataset.pt

# Fewer workers (default: 8)
python main.py --mode collect collect_num_workers=4

# Reproducible with fixed seed
python main.py --mode collect seed=42

The output .pt file is directly consumable by --mode offline:

python main.py --mode collect
python main.py --mode offline --data data/dataset.pt

Offline BC (optional)

Train the diffusion model on pre-collected oracle demonstrations. The run length is controlled by total_timesteps: each env-step of the unified budget corresponds to one dataset sample, so total gradient steps = total_timesteps // offline_batch_size.

Periodic ID + OOD evaluation runs during training on the cadence defined by id_eval_every_timesteps / ood_eval_every_timesteps (env-step units, converted internally to grad-step deltas via // offline_batch_size), mirroring the DAgger eval pattern. Results are logged to eval_id/ and eval_ood/ W&B namespaces.
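The env-step-to-grad-step conversion above is simple integer arithmetic; a sketch using the default config values quoted later in this README (illustrative only):

```python
# Default config values from this README, converted from env-step units
# (the unified budget) to gradient-step units used internally by offline BC.
total_timesteps = 2_000_000
offline_batch_size = 3584
id_eval_every_timesteps = 25_000

total_grad_steps = total_timesteps // offline_batch_size
eval_every_grad_steps = id_eval_every_timesteps // offline_batch_size

print(total_grad_steps)       # 558
print(eval_every_grad_steps)  # 6
```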

python main.py --mode offline --data path/to/dataset.pt

# Shorter / longer run (the same knob the DAgger and SB3 baselines use):
python main.py --mode offline --data dataset.pt total_timesteps=500000

# Resume from a step-level checkpoint (restores optimizer, scheduler,
# step counter, and W&B run)
python main.py --mode offline --data path/to/dataset.pt \
    --checkpoint checkpoints/offline_step2000.pth

Step-level checkpoints are written every checkpoint_every_timesteps env-step equivalents (converted internally to grad steps by dividing by offline_batch_size). Set to 0 to disable:

python main.py --mode offline --data dataset.pt checkpoint_every_timesteps=0

Compute-match overrides (paper-fair BC vs DAgger)

For research comparisons against a specific DAgger checkpoint, four optional offline-only overrides bypass the env-step budget derivation. Because the sample-to-grad-step ratio between the two modes is roughly 50x, a single shared total_timesteps budget is unfair to one side; these knobs pin offline metrics in grad-step units instead. All default to null (backwards compatible).

Key Purpose
offline_total_grad_steps Pin gradient budget. Overrides total_timesteps // offline_batch_size. Use to match a DAgger iteration count (e.g. 60000 = 600 iters x 100 grad_steps_per_iter).
offline_eval_every_grad_steps ID/OOD eval cadence in grad-step units. Without this, env-step cadence applied to BC's dense per-sample budget yields hundreds of evals.
offline_checkpoint_every_grad_steps Checkpoint cadence in grad-step units. Same motivation.
offline_buffer_capacity Distinct from buffer_capacity (sized for DAgger's small FIFO). The full BC dataset has ~500k–1M sliding windows; using DAgger's cap silently truncates.

Example: train a fair offline BC baseline matched to DAgger@iter600 (60k AdamW updates x batch 2048):

python main.py --mode offline --data data/oracle_bc_qmul.pt \
    --config configs/final_qmul_gpu.yaml

The final_qmul_gpu.yaml and final_ucl_gpu.yaml configs both ship with these overrides pre-set and with cross-cluster-identical training hyperparameters (only collection-worker counts and output paths differ).

DAgger online training

Full DAgger loop: seed buffer with oracle data, collect with model, label with BFS oracle, filter by efficiency, train on buffer.

# From scratch (seeds buffer with oracle data automatically)
python main.py --mode dagger

# Resume from local checkpoint
python main.py --mode dagger --checkpoint checkpoints/iter3000.pth

# Resume from a W&B artifact
python main.py --mode dagger \
    --wandb-artifact entity/project/checkpoint-iter3000:latest

# Skip warm-start from checkpoint (reinitialise model, keep config)
python main.py --mode dagger --checkpoint checkpoints/iter3000.pth --no-warm-start

# Override hyperparameters (total_timesteps is the unified run-length knob)
python main.py --mode dagger total_timesteps=1000000 dagger_lr=0.0001

# Use a GPU-optimised config (paper run, QMUL H200)
python main.py --mode dagger --config configs/final_qmul_gpu.yaml

Inference

Evaluate a checkpoint on specified environments. Accepts either --checkpoint (local path) or --wandb-artifact (W&B artifact reference).

# All ID + OOD environments
python main.py --mode inference --checkpoint checkpoints/iter8000.pth

# From a W&B artifact
python main.py --mode inference \
    --wandb-artifact entity/project/checkpoint-iter8000:latest

# Specific environments, save JSON
python main.py --mode inference \
    --checkpoint checkpoints/iter8000.pth \
    --envs MiniHack-Room-Random-5x5-v0 MiniHack-MazeWalk-45x19-v0 \
    --episodes 100 \
    --output results.json

# Custom .des scenario files
python main.py --mode inference \
    --checkpoint checkpoints/iter8000.pth \
    --des environments/custom_level.des

# Local-only ablation (zero out global map)
python main.py --mode inference \
    --checkpoint checkpoints/iter8000.pth --blind-global

# Use training weights instead of EMA
python main.py --mode inference --checkpoint iter8000.pth --no-ema

Baselines (SB3 + Decision Transformer)

Train and evaluate the head-to-head baselines used in the paper comparison. Six algorithms are wired in: standard discrete-action RL via Stable-Baselines3 (ppo, a2c, dqn, ppo-rnn), Behavioural Cloning (bc) on oracle demonstrations, and a causal Decision Transformer (dt) with target-return conditioning. All six share the unified cfg.total_timesteps budget so the numbers are directly comparable to DAgger and offline BC.

Hyperparameters live under the baselines_* namespace in configs/defaults.yaml (BC epochs / batch / LR, DT context length / depth / width, oracle episodes per env, eval cadence, DQN replay buffer, parallel SubprocVecEnv count, etc.). The runner writes per-seed checkpoints, SB3 logs, and an aggregated results JSON under cfg.baselines_output_dir (default outputs/baselines/); W&B runs land in a separate project (cfg.baselines_wandb_project, default remdm-baselines) so they don't pollute the main training leaderboards.

# PPO on the 4 ID maps for the unified env-step budget, 1 seed
python main.py --mode baselines --algo ppo

# DQN with a custom budget and 3 seeds
python main.py --mode baselines --algo dqn \
    --seeds 0 1 2 \
    total_timesteps=1000000

# Behavioural Cloning baseline (oracle demos -> SB3 ActorCriticPolicy)
python main.py --mode baselines --algo bc --n-seeds 3

# Decision Transformer (causal R/s/a transformer with target-return)
python main.py --mode baselines --algo dt --seeds 0 1 2

# Override the aggregated-results JSON destination
python main.py --mode baselines --algo ppo --output results/ppo_smoke.json

# Paper-fair comparison against the ReMDM online budget (~5.65M env-steps)
python main.py --mode baselines --algo ppo total_timesteps=5650000

The BC and DT defaults (50 epochs, 5000 oracle trajectories per ID env, 64-token DT context, 256-D DT embedding) are tuned to match the data and compute scale of the offline BC and ReMDM runs reported in the paper.

CLI flags

Flag Description
--mode Required. One of smoke, collect, offline, dagger, inference, baselines
--config PATH Config file (default: configs/defaults.yaml)
--algo NAME Baseline algorithm (ppo, a2c, dqn, ppo-rnn, bc, dt); required with --mode baselines
--seeds N [N ...] Explicit seed list for --mode baselines
--n-seeds N Number of seeds starting from 0 (alternative to --seeds)
--data PATH Dataset .pt file (offline mode)
--checkpoint PATH Checkpoint .pth file
--wandb-artifact REF W&B artifact reference (e.g. entity/project/name:latest)
--no-warm-start Skip model warm-start from checkpoint (DAgger)
--no-ema Use training weights instead of EMA for inference
--envs ENV [ENV ...] Override evaluation environments
--des PATH [PATH ...] Custom .des scenario files for evaluation
--episodes N Episodes per environment (default: 50)
--output PATH Save evaluation results / aggregated baselines JSON
--blind-global Zero out global map observations (local-only ablation)

Architecture

LocalDiffusionPlannerWithGlobal (~5.2M parameters):

Local stream:   9x9 glyphs -> Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
Global stream:  21x79 glyphs -> Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4) -> 8 tokens
                Goal head: mean(global) -> MLP -> [B,2] staircase coords (aux loss)
                Gate: sigmoid(learnable scalar, init=-3.0) * global_tokens
Action stream:  Embedding(14, 256) + timestep_emb(100, 256) + position_emb(64, 256)
Transformer:    concat [1 + 8 + 64 = 73 tokens] -> 4-layer encoder (256D, 4 heads, pre-norm)
Output head:    last 64 tokens -> Linear(256, 12) -> action logits

The model takes (local_obs, global_obs, noisy_action_seq, t_discrete) and returns {"actions": [B,64,12], "goal_pred": [B,2]}.

A LocalDiffusionPlanner variant (no global stream, no goal head) is also available for ablation studies.
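The gating mechanism in the diagram is a single learnable logit squashed through a sigmoid; a minimal sketch of just that scalar (the real module scales the 8 global tokens by this value):

```python
import math

def gate_value(gate_logit: float) -> float:
    """Sigmoid of the learnable gate logit that scales the global tokens."""
    return 1.0 / (1.0 + math.exp(-gate_logit))

# With the default init of -3.0 the gate starts nearly closed, so the
# global stream contributes almost nothing early in training.
print(round(gate_value(-3.0), 3))  # 0.047
```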


Diffusion

Forward process (MDLM): Each action token is independently replaced with MASK (token 12) with probability 1 - alpha(t), where alpha(t) follows a linear or cosine schedule. PAD tokens (13) are never masked.
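A hedged pure-Python sketch of this forward process, assuming the linear schedule alpha(t) = 1 - t (the actual implementation operates on batched tensors):

```python
import random

MASK, PAD = 12, 13  # token ids from the description above

def alpha_linear(t: float) -> float:
    """Linear schedule: alpha(0) = 1 (nothing masked) -> alpha(1) = 0 (all masked)."""
    return 1.0 - t

def forward_mask(tokens, t, rng=random):
    """Independently replace each non-PAD token with MASK w.p. 1 - alpha(t)."""
    p_mask = 1.0 - alpha_linear(t)
    return [tok if tok == PAD or rng.random() >= p_mask else MASK
            for tok in tokens]

seq = [3, 7, 0, PAD, PAD]
print(forward_mask(seq, t=1.0))  # [12, 12, 12, 13, 13]: PAD survives full noise
```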

Loss: Cross-entropy on masked positions only, averaged globally across the batch. By default uses a flat average (matching the reference implementation). Optional SUBS importance weighting w(t) = -alpha'(t) / (1 - alpha(t)), clipped to [0, 1000], can be enabled via use_importance_weighting: true. Optional label smoothing via label_smoothing (default 0.0).
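For the linear schedule alpha(t) = 1 - t, the optional SUBS weight reduces to 1/t. A tiny sketch of just that scalar (illustrative values, not taken from a training run):

```python
def subs_weight(t: float, clip: float = 1000.0) -> float:
    """SUBS weight w(t) = -alpha'(t) / (1 - alpha(t)); for the linear
    schedule alpha(t) = 1 - t this is 1/t, clipped to [0, clip]."""
    return min(1.0 / max(t, 1e-12), clip)

print(subs_weight(0.5))   # 2.0
print(subs_weight(1e-6))  # 1000.0 (hits the loss_weight_clip bound)
```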

Reverse sampling (ReMDM): Over K denoising steps (default 10):

  1. Model predicts logits; apply temperature scaling and top-K filtering.
  2. Sample predictions; compute per-token confidence.
  3. MaskGIT unmask: commit the n_unmask highest-confidence masked positions.
  4. ReMDM remask: stochastically re-mask committed positions to allow refinement.
  5. Final step: commit all remaining positions.

Greedy sampling: Used during DAgger data collection for deterministic rollouts. Same MaskGIT progressive unmasking loop but with argmax decoding (no temperature, no top-K, no remasking). Uses fewer denoising steps (diffusion_steps_collect: 5) for faster collection.
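The MaskGIT unmask step (step 3 above) can be sketched in a few lines; the predictions and confidences here are toy values standing in for real model outputs:

```python
MASK = 12

def unmask_step(tokens, preds, confidence, n_unmask):
    """Commit the n_unmask highest-confidence masked positions."""
    masked = [i for i, tok in enumerate(tokens) if tok == MASK]
    masked.sort(key=lambda i: confidence[i], reverse=True)
    out = list(tokens)
    for i in masked[:n_unmask]:
        out[i] = preds[i]
    return out

tokens = [MASK, 4, MASK, MASK]       # position 1 already committed
preds  = [1, 4, 2, 3]                # toy per-position predictions
conf   = [0.9, 1.0, 0.2, 0.7]        # toy per-position confidences
print(unmask_step(tokens, preds, conf, n_unmask=2))  # [1, 4, 12, 3]
```

Positions 0 and 3 are committed (highest confidence among masked positions); position 2 stays masked for a later denoising step.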

Remasking strategies

Strategy Formula Description
rescale p = eta * sigma_max Proportional to noise level
cap p = min(eta, sigma_max) Fixed upper bound
conf p = eta * sigma_max * (1 - confidence) Low-confidence tokens remasked more
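The three formulas map directly to code; eta, sigma_max, and confidence below are illustrative values, not defaults from a run:

```python
def remask_prob(strategy, eta, sigma_max, confidence=None):
    """Per-token remasking probability for each strategy in the table."""
    if strategy == "rescale":
        return eta * sigma_max
    if strategy == "cap":
        return min(eta, sigma_max)
    if strategy == "conf":
        return eta * sigma_max * (1.0 - confidence)
    raise ValueError(f"unknown strategy: {strategy}")

print(round(remask_prob("rescale", eta=0.15, sigma_max=0.4), 4))              # 0.06
print(remask_prob("cap", eta=0.15, sigma_max=0.4))                            # 0.15
print(round(remask_prob("conf", eta=0.15, sigma_max=0.4, confidence=0.9), 4)) # 0.006
```

Note how conf drives the remask probability toward zero for tokens the model is already sure about, concentrating refinement on uncertain positions.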

Configuration

Key hyperparameters

Model

Parameter Default Description
n_embd 256 Transformer hidden dimension
n_head 4 Attention heads
n_layer 4 Transformer blocks
n_global_tokens 8 Global stream context tokens
seq_len 64 Action plan length
dropout 0.0 Transformer dropout (0.0 -- forward masking regularises)
ema_decay 0.999 EMA smoothing for inference weights
global_gate_init -3.0 Initial value for global gate logit

Diffusion

Parameter Default Description
noise_schedule linear linear or cosine
num_diffusion_steps 100 Discrete timestep resolution
diffusion_steps_eval 10 Denoising iterations at inference
diffusion_steps_collect 5 Denoising iterations during DAgger collection
remask_strategy conf rescale, cap, or conf
eta 0.15 Remasking strength
temperature 0.5 Sampling temperature
top_k 4 Top-K filtering
replan_every 16 Env steps before replanning
loss_weight_clip 1000.0 SUBS importance weight clip bound
label_smoothing 0.0 Label smoothing for cross-entropy
use_importance_weighting false SUBS w(t) in loss (off = flat average)
physics_aware_sampling false Penalise hazardous actions at inference

Training budget (unified)

Offline BC, DAgger, and the SB3 baselines all share a single env-step budget expressed in total_timesteps (matching the SB3 convention). This is the only knob that should change to scale a run up or down.

Parameter Default Description
total_timesteps 2,000,000 Env-step budget shared across offline / DAgger / SB3
id_eval_every_timesteps 25,000 ID eval cadence (env-steps)
ood_eval_every_timesteps 25,000 OOD eval cadence (env-steps)
checkpoint_every_timesteps 125,000 Checkpoint cadence (env-steps)
  • Offline BC: each dataset sample is one env.step() equivalent, so total gradient steps = total_timesteps // offline_batch_size. The cosine LR schedule's T_max derives from the same quantity, so runs of different lengths still decay to the 10% floor at their end.
  • DAgger: the training loop tracks cumulative env.step() calls (model + oracle rollouts combined) and halts when the running total reaches total_timesteps. episodes_per_iteration and grad_steps_per_iteration control the collect/train ratio but must not scale with the budget.
  • Fairness caveat (ema_decay): this is an absolute-update-count constant (half-life ~ 1 / (1 - decay) updates). If total_timesteps shifts by more than ~2x from the default, the fraction of training covered by the EMA window changes. For very short or very long runs, consider setting a matching decay manually.
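The half-life arithmetic behind that caveat is a one-liner; a hedged sketch of how one might pick a matching decay for a rescaled run (the helper names are ours, not config keys):

```python
import math

def ema_halflife(decay: float) -> float:
    """Updates needed for the EMA's weight on old state to fall to one half."""
    return math.log(0.5) / math.log(decay)

def matching_decay(target_halflife: float) -> float:
    """Decay whose half-life is target_halflife updates."""
    return 0.5 ** (1.0 / target_halflife)

base = ema_halflife(0.999)
print(round(base))                         # 693
# Candidate decay preserving the EMA-window fraction for a 4x longer run:
print(round(matching_decay(4 * base), 5))  # 0.99975
```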

Training

Parameter Default Description
offline_lr 0.0003 BC learning rate (cosine-decayed to 10% over total_grad_steps)
dagger_lr 0.00003 DAgger learning rate (constant)
offline_batch_size 3584 Offline BC batch size
dagger_batch_size 3584 DAgger batch size
offline_grad_clip 1.0 Gradient norm clip (offline)
dagger_grad_clip 1.0 Gradient norm clip (DAgger)
weight_decay 0.0001 AdamW weight decay (both optimizers)
grad_steps_per_iteration 100 Gradient steps per DAgger iteration
episodes_per_iteration 30 Episodes collected per DAgger iteration
aux_loss_weight 0.5 Weight for auxiliary goal loss
buffer_capacity 10000 Replay buffer size (windows)
efficiency_multiplier 1.5 DAgger efficiency filter threshold
curriculum_preseed true Pre-seed curriculum with 50/50 prior
curriculum_queue_size 100 Curriculum window size per environment

Data Collection

Parameter Default Description
collect_episodes_per_env 5000 Oracle episodes per ID environment
collect_num_workers 8 Parallel process workers for collection
collect_output data/dataset.pt Output path for collected dataset

Evaluation

Parameter Default Description
eval_episodes_per_env 50 Episodes per environment at eval time
checkpoint_eval_episodes 50 Episodes per env at checkpoint eval

(Eval and checkpoint cadences are expressed in env-steps under Training budget (unified) above.)

Performance

Parameter Default Description
use_amp false Mixed-precision (FP16) training via torch.amp
torch_compile false torch.compile the model for fused kernels
num_collection_workers 8 Parallel workers for DAgger episode collection

Logging

Parameter Default Description
use_wandb true Enable W&B logging
wandb_project remdm-minihack W&B project name
wandb_resume_id null W&B run ID for resumption
offline_log_every 10 Stdout/W&B log frequency (offline steps)
seed null RNG seed (null = random)

Config presets

File Purpose
configs/defaults.yaml Base defaults for all modes
configs/smoke.yaml Fast smoke test (total_timesteps=5000, small buffer, W&B off)
configs/ucl_gpu_bigger_model.yaml UCL GPU exploration with a larger model (384D, 6 heads)
configs/ucl_gpu_learning_behaviour.yaml UCL GPU learning-behaviour study (eta=0.18, B=6144)
configs/final_qmul_gpu.yaml Paper run, QMUL H200. Drives both --mode dagger (reproduces the iter600 checkpoint) and --mode offline (compute-matched fair BC baseline: 60k grad steps x B=2048). AMP + torch.compile + 32 collection workers.
configs/final_ucl_gpu.yaml Paper run, UCL 3090 Ti 24 GB. Identical training hyperparams to the QMUL config for cross-cluster fairness; only num_collection_workers (8 instead of 32) and output paths differ.

DAgger Training Loop

Each DAgger iteration:

  1. Curriculum sampling: Select an environment weighted by difficulty (low win-rate environments sampled more).
  2. Model rollout: Generate plans with the EMA model using greedy sampling; execute with replanning every 16 steps. Collects episodes_per_iteration (default 30) episodes per iteration.
  3. Oracle rollout: Run the BFS oracle on the same seed for comparison.
  4. Efficiency filter: Add the oracle trajectory to the buffer if the model failed or took >1.5x the oracle's steps.
  5. Budget accounting: Advance env_steps_total += model_steps + oracle_steps. The training loop halts when the running total reaches total_timesteps.
  6. Training: Sample from the replay buffer; run grad_steps_per_iteration gradient steps, updating EMA weights after each gradient step.
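Steps 4 and 5 can be sketched concisely; the per-iteration step counts below are hypothetical, chosen only to show when the loop halts:

```python
def efficiency_filter(model_won, model_steps, oracle_steps, multiplier=1.5):
    """Add the oracle trajectory if the model failed or took >1.5x its steps."""
    return (not model_won) or (model_steps > multiplier * oracle_steps)

print(efficiency_filter(model_won=False, model_steps=10, oracle_steps=10))  # True
print(efficiency_filter(model_won=True, model_steps=31, oracle_steps=20))   # True
print(efficiency_filter(model_won=True, model_steps=25, oracle_steps=20))   # False

# Budget accounting: halt once cumulative env.step() calls reach the budget.
env_steps_total, iterations, total_timesteps = 0, 0, 1000
while env_steps_total < total_timesteps:
    env_steps_total += 64 + 40  # hypothetical model + oracle rollout costs
    iterations += 1
print(iterations, env_steps_total)  # 10 1040
```

The loop may overshoot the budget by up to one iteration's worth of steps, since the check happens between iterations.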

Collection uses GPU-batched rollouts when on CUDA with episodes_per_iteration > 1, falling back to threaded CPU collection or sequential collection as appropriate.

The BFS oracle uses a 5-tier priority: (1) kick adjacent doors, (2) BFS to staircase, (3) BFS to frontier, (4) BFS to farthest tile, (5) random cardinal.
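The tier structure is a simple first-match fallback; a sketch with hypothetical stub proposers standing in for the real door-kick and BFS routines:

```python
def first_available(*tiers):
    """Return the first tier's proposed action that is not None."""
    for propose in tiers:
        action = propose()
        if action is not None:
            return action

# Hypothetical scenario: no adjacent door, staircase not yet visible,
# but an unexplored frontier exists to the north.
print(first_available(
    lambda: None,  # 1. kick adjacent door
    lambda: None,  # 2. BFS to staircase
    lambda: "N",   # 3. BFS to frontier
    lambda: "E",   # 4. BFS to farthest tile
    lambda: "S",   # 5. random cardinal
))  # N
```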


Reward Shaping

The environment wrapper applies shaped rewards to guide learning:

Component Value Condition
Win bonus +20.0 Episode won
BFS progress +0.5 * (prev_dist - curr_dist) Closer to staircase
Exploration +0.05 New tile visited
Step penalty -0.01 Every step
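The table composes into a single per-step function; a minimal sketch (the wrapper's actual signature may differ):

```python
def shaped_reward(won, prev_dist, curr_dist, new_tile):
    """Shaped per-step reward matching the table above."""
    r = -0.01                           # step penalty, every step
    r += 0.5 * (prev_dist - curr_dist)  # BFS progress toward the staircase
    if new_tile:
        r += 0.05                       # exploration bonus
    if won:
        r += 20.0                       # win bonus
    return r

# One tile closer to the staircase, onto a freshly explored tile:
print(round(shaped_reward(False, prev_dist=5, curr_dist=4, new_tile=True), 2))  # 0.54
```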

Project Structure

minihack-ReMDM-planner/
β”œβ”€β”€ configs/
β”‚   β”œβ”€β”€ defaults.yaml                   Base hyperparameters
β”‚   β”œβ”€β”€ smoke.yaml                      Smoke test overrides
β”‚   β”œβ”€β”€ ucl_gpu_bigger_model.yaml       UCL GPU (larger model: 384D, 6 heads)
β”‚   β”œβ”€β”€ ucl_gpu_learning_behaviour.yaml UCL GPU learning-behaviour study
β”‚   β”œβ”€β”€ final_qmul_gpu.yaml             Paper run: DAgger + fair offline BC (QMUL H200)
β”‚   └── final_ucl_gpu.yaml              Paper run: DAgger + fair offline BC (UCL 3090 Ti)
β”œβ”€β”€ environments/                      Custom .des scenario files
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ config.py                      YAML config loader with CLI overrides
β”‚   β”œβ”€β”€ buffer.py                      ReplayBuffer with offline-protected FIFO
β”‚   β”œβ”€β”€ curriculum.py                  DynamicCurriculum + efficiency_filter
β”‚   β”œβ”€β”€ diffusion/
β”‚   β”‚   β”œβ”€β”€ schedules.py               Linear and cosine noise schedules
β”‚   β”‚   β”œβ”€β”€ forward.py                 Forward masking process q(z_t | x_0)
β”‚   β”‚   β”œβ”€β”€ loss.py                    MDLM ELBO + auxiliary goal loss
β”‚   β”‚   └── sampling.py                ReMDM reverse sampling with remasking
β”‚   β”œβ”€β”€ models/
β”‚   β”‚   └── denoiser.py                LocalDiffusionPlannerWithGlobal + ModelEMA
β”‚   β”œβ”€β”€ envs/
β”‚   β”‚   β”œβ”€β”€ minihack_env.py            AdvancedObservationEnv + BFS oracle
β”‚   β”‚   └── discovery.py               Env registry scanner + inference benchmark
β”‚   └── planners/
β”‚       β”œβ”€β”€ collect.py                 run_model_episode + DataCollector
β”‚       β”œβ”€β”€ collect_oracle.py          Standalone oracle data collection
β”‚       β”œβ”€β”€ offline.py                 Offline BC trainer
β”‚       β”œβ”€β”€ online.py                  DAgger Trainer + checkpointing
β”‚       β”œβ”€β”€ inference.py               Evaluator + result formatting
β”‚       β”œβ”€β”€ baselines.py               SB3 + Decision Transformer baselines
β”‚       β”œβ”€β”€ smoke.py                   Smoke-test runner
β”‚       └── logging.py                 Centralised W&B + stdout logging
β”œβ”€β”€ experiments/
β”‚   └── rl_finetuning/                 RL fine-tuning ablation suite
β”‚       β”œβ”€β”€ run_ablations.py           CLI entry point
β”‚       β”œβ”€β”€ configs/                   Ablation config files
β”‚       β”œβ”€β”€ ablations/                 Loss, optimizer, registry, training
β”‚       β”œβ”€β”€ diagnostics/               Gradient, representation, timestep metrics
β”‚       └── analysis/                  Plots, tables, reports
β”œβ”€β”€ scripts/
β”‚   β”œβ”€β”€ hf_upload.py                   HuggingFace Hub upload utility
β”‚   └── profile_dagger.py             DAgger iteration profiler
β”œβ”€β”€ main.py                            CLI entry point (smoke/collect/offline/dagger/inference/baselines)
β”œβ”€β”€ pyproject.toml                     PEP 621 project metadata + dependencies
β”œβ”€β”€ uv.lock                            Deterministic lockfile
└── README.md

W&B Metric Namespaces

Namespace Contents
diffusion/ loss, loss_diff, loss_aux
train/ buffer_size, buffer_online_frac, model_won, added_to_buffer, episodes_collected, model_steps, oracle_steps, efficiency_ratio, lr, grad_norm, global_gate, env_steps, progress
speed/ iter_time_sec, collect_time_sec, train_step_time_sec, samples_per_sec, env_steps_per_sec, gpu_memory_mb
perf/ iter_time_s, collect_time_s, train_time_s, grad_steps_per_sec (legacy compat)
model/ param_norm, param_drift_from_init, ema_gate_value (every 10 iters)
eval_id/{env}/ Per-environment win rate, avg steps, avg reward (in-distribution)
eval_ood/{env}/ Per-environment win rate, avg steps, avg reward (out-of-distribution)
eval_id/ mean_win_rate
eval_ood/ mean_win_rate
curriculum/{env}/ win_rate per training environment
ckpt_eval_id/, ckpt_eval_ood/ Per-env metrics at checkpoint time
ckpt_eval/ id_winrate, ood_winrate
offline/ final_loss, total_steps, total_timesteps (summary only)

Both DAgger and offline BC emit to eval_id/ and eval_ood/ namespaces. Offline mode reuses the same Evaluator and EMA-weight evaluation path as DAgger, so curves are directly comparable across modes.


Checkpoint Format

DAgger checkpoint:

{
    "model_state_dict":     ...,
    "ema_state_dict":       ...,
    "optimizer_state_dict": ...,
    "scheduler_state_dict": ...,
    "curriculum_state":     {...},
    "iteration":            int,
    "env_steps":            int,   # cumulative env.step() calls so far
    "wandb_run_id":         str | None,
    "rng_states":           {"torch", "numpy", "python"},
}

Offline BC checkpoint (step-level, file offline_step{N}.pth, saved when checkpoint_every_timesteps > 0):

{
    "model_state_dict":     ...,
    "ema_state_dict":       ...,
    "optimizer_state_dict": ...,
    "scheduler_state_dict": ...,
    "step":                 int,
    "env_steps":            int,   # step * offline_batch_size
    "wandb_run_id":         str | None,
}

Offline final checkpoint (saved at the end of offline training):

{
    "model_state_dict":     ...,
    "ema_state_dict":       ...,
    "wandb_run_id":         str | None,
}

Inference uses EMA weights by default. Pass --no-ema to use training weights.

W&B Artifacts

Checkpoints are automatically uploaded as versioned W&B artifacts (type "model") at each checkpoint save. Each artifact contains the .pth weights and a config.yaml snapshot of all hyperparameters used.

To resume from an artifact:

# DAgger resume
python main.py --mode dagger \
    --wandb-artifact entity/project/checkpoint-iter3000:latest

# Inference
python main.py --mode inference \
    --wandb-artifact entity/project/checkpoint-iter8000:v2

The artifact reference format is entity/project/artifact-name:version where version is latest, v0, v1, etc.

W&B Run Resumption

All training loops save the W&B run ID in their checkpoints. When resuming from a checkpoint, the run ID is automatically extracted and passed to wandb.init(resume="must"), so metrics continue on the same W&B curves with no gaps.

# DAgger: automatic -- run ID is read from the checkpoint
python main.py --mode dagger --checkpoint checkpoints/iter2000.pth

# Offline BC: automatic
python main.py --mode offline --data dataset.pt \
    --checkpoint checkpoints/offline_step2000.pth

# Manual override (e.g. checkpoint saved before this feature was added):
python main.py --mode dagger --checkpoint old_checkpoint.pth \
    wandb_resume_id=abc123xyz

# Ablation suite:
python experiments/rl_finetuning/run_ablations.py \
    --checkpoint path/to/ckpt.pth --all --use_wandb \
    --wandb_resume_id abc123xyz

The run ID is visible in the W&B dashboard URL: wandb.ai/.../runs/<run-id>.


Performance Tuning

Three config keys control performance optimisations. Defaults are set for GPU training; override for CPU or different hardware.

Mixed precision (use_amp: true)

Wraps training forward/backward in torch.amp.autocast("cuda") with GradScaler. Active in both offline BC and DAgger training.

  • Measured speedup: 2.2x on gradient steps, 1.7x on full smoke test wall-clock
  • Memory: peak GPU stays ~16 GB at B=3584 (same as FP32 due to embedding-heavy model)
  • Correctness: loss trajectory and win rates statistically equivalent to FP32
  • When to use: always on GPU. No effect on CPU (autocast is a no-op)
  • Default: false in defaults.yaml; enabled in GPU-specific configs

torch.compile (torch_compile: true)

Applies torch.compile(model, mode="default") before training. Falls back gracefully if no C compiler is found (common on managed GPU nodes).

  • Measured speedup: none beyond AMP alone. Not recommended for primary training.
  • Default: false in defaults.yaml; opt in via the final_*_gpu.yaml configs.
  • When to use: experimental only. May help on future PyTorch versions with better dynamic shape support.

Parallel collection (num_collection_workers: N)

DAgger episode collection supports three strategies (auto-selected):

  1. GPU-batched (default on CUDA with episodes_per_iteration > 1): all envs in lockstep
  2. Threaded CPU (fallback when num_collection_workers > 0): ThreadPoolExecutor with CPU model copies
  3. Sequential (reference behaviour): one episode at a time
  • Default: 8 workers in defaults.yaml
  • When to use: GPU-batched is preferred; workers primarily affect the CPU fallback path

Profiling

Run python scripts/profile_dagger.py [key=value ...] to profile DAgger iteration components. Supports all config overrides (e.g., use_amp=true).


Implementation Notes

  • MDLM loss returns 0.0 (not NaN) when no masked positions exist in the batch. Uses global averaging by default; SUBS importance weighting is opt-in via use_importance_weighting: true.
  • PAD tokens are never masked during the forward process and are excluded from the loss.
  • Sampling paths: Evaluation uses stochastic ReMDM sampling (temperature, top-K, remasking) with diffusion_steps_eval (default 10) steps. DAgger collection uses greedy argmax sampling (deterministic, no remasking) with diffusion_steps_collect (default 5) steps for faster rollouts.
  • remdm_sample guarantees a fully committed output (no MASK tokens) via a final-step commit and an assertion check. A min-keep 10% safety net prevents degenerate all-masked states.
  • EMA shadow weights are updated after every gradient step (not per iteration). The DataCollector syncs the latest EMA weights before each rollout.
  • Curriculum initialises with a 50/50 prior per environment (configurable via curriculum_preseed) and uses bucket-based weights over the rolling win-rate: low [0, 0.15) → 0.2, medium [0.15, 0.85) → 1.0, high [0.85, 1.0] → 0.1.
  • Replay buffer pins offline data at the front; only online samples are FIFO-evicted. Returns None on empty buffer (callers handle gracefully).
  • Global gate initialises at sigmoid(-3.0) ~ 0.047, starting nearly closed to prevent the global stream from destabilising early training.
  • Dropout is set to 0.0 by default. The discrete diffusion forward masking already regularises; dropout on top is redundant.
  • DAgger warm-start: On iteration 0, the buffer is seeded with 3 oracle trajectories per ID environment (12 total), giving the curriculum and training loop data to work with immediately.
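The curriculum bucket weights described above are a three-way threshold; a minimal sketch:

```python
def curriculum_weight(win_rate: float) -> float:
    """Bucketed sampling weight over an environment's rolling win-rate."""
    if win_rate < 0.15:
        return 0.2  # low: too hard to learn from yet
    if win_rate < 0.85:
        return 1.0  # medium: most informative, sample most
    return 0.1      # high: nearly solved, sample rarely

print([curriculum_weight(w) for w in (0.05, 0.5, 0.9)])  # [0.2, 1.0, 0.1]
```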