Core Idea: Binary Gating via Gumbel-Softmax
PonderTTT introduces Adaptive Test-Time Training with learned SKIP/UPDATE decisions. Instead of applying TTT updates uniformly to all input chunks, we learn when to update using a binary gating mechanism trained via Gumbel-Softmax.
| Feature | Fixed TTT | PonderTTT (Binary Gating) |
|---|---|---|
| Decision | Always UPDATE | SKIP or UPDATE per chunk |
| Training | N/A | Gumbel-Softmax (differentiable) |
| Inference | Fixed cost | True computational savings |
| Cost | 3.0x (UPDATE_1) | 2.67x (83% update rate) |
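The mechanism can be sketched as follows. This is a minimal illustration, not the repository's implementation: `sample_gate`, `apply_gate`, the fixed temperature `tau`, and the tensor shapes shown are assumptions. A straight-through Gumbel-Softmax turns per-chunk gate logits into a hard SKIP/UPDATE decision while gradients still flow through the relaxed probabilities.

```python
import jax
import jax.numpy as jnp


def sample_gate(logits, key, tau=1.0):
    """Straight-through Gumbel-Softmax over {SKIP, UPDATE}.

    logits: (num_chunks, 2) unnormalized scores for [SKIP, UPDATE].
    Returns hard one-hot decisions in the forward pass while gradients
    flow through the relaxed (soft) probabilities.
    """
    u = jax.random.uniform(key, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    gumbel = -jnp.log(-jnp.log(u))
    soft = jax.nn.softmax((logits + gumbel) / tau, axis=-1)  # differentiable relaxation
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), 2)      # discrete SKIP/UPDATE
    return soft + jax.lax.stop_gradient(hard - soft)         # straight-through estimator


def apply_gate(gate, updated_out, skipped_out):
    """Blend the TTT-updated path and the skip path by the per-chunk decision.

    gate:        (num_chunks, 2) decisions from sample_gate.
    updated_out: (num_chunks, chunk_len, d_model) outputs with the TTT update.
    skipped_out: (num_chunks, chunk_len, d_model) outputs without it.
    """
    update_prob = gate[:, 1][:, None, None]  # 1.0 = UPDATE, 0.0 = SKIP
    return update_prob * updated_out + (1.0 - update_prob) * skipped_out
```

The temperature is typically annealed toward a small value during training; at inference the hard decision alone determines whether the TTT update is executed at all, which is where the computational savings come from.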
Key Results (GPT-2 125M on Python)
- 4.5x perplexity improvement over non-adaptive baseline (26.36 → 5.85)
- Strong OOD generalization: perplexity improvements on JavaScript (2.5x), Java (6.2x), and Go (70x)
- Learned policy captures universal "when to adapt" patterns
Technical Architecture
This project is a pure JAX/Flax NNX rewrite of the official TTT-LM, enhanced with adaptive gating.
- Base Model: Pretrained GPT-2 (125M, 350M) with frozen backbone weights
- Fast-Weight Layer (`TTTLayer`): TTT-Linear with causal convolutions and dual-form updates
- Binary Gating Network: Lightweight MLP that makes SKIP/UPDATE decisions via Gumbel-Softmax
- Training Objective: Top-k Discriminative Gating. Instead of an explicit cost penalty, we use an implicit budget constraint by training the gate to identify the top-k% chunks with the highest TTT advantage.
- Loss Function: \( L_{total} = L_{CE} + \beta \cdot L_{TTT} + L_{gate} \)
- \( L_{CE} \): Main task cross-entropy (always computed with TTT updates)
- \( L_{TTT} \): TTT reconstruction loss (auxiliary self-supervised task)
- \( L_{gate} \): Binary Cross-Entropy (BCE) loss aligning the gate with the oracle Top-k ranking
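As a rough sketch of how these terms could be assembled (the helper name `total_loss`, the per-chunk `ttt_advantage` signal, and the defaults for `target_rate` and `beta` are illustrative assumptions, not taken from the code): the top-k% chunks by TTT advantage become positive labels, and a BCE loss trains the gate to reproduce that ranking.

```python
import jax.numpy as jnp
import optax


def total_loss(ce_loss, ttt_loss, gate_logits, ttt_advantage,
               target_rate=0.5, beta=0.1):
    """Sketch of L_total = L_CE + beta * L_TTT + L_gate.

    gate_logits:   (num_chunks,) gate score for UPDATE on each chunk.
    ttt_advantage: (num_chunks,) oracle benefit of updating a chunk,
                   e.g. loss_when_skipped - loss_when_updated.
    """
    # Implicit budget: label the top-k% highest-advantage chunks as UPDATE.
    num_chunks = ttt_advantage.shape[0]          # static shape, so k is a Python int
    k = max(1, int(round(target_rate * num_chunks)))
    threshold = jnp.sort(ttt_advantage)[-k]      # k-th largest advantage
    labels = (ttt_advantage >= threshold).astype(jnp.float32)

    # BCE aligns the gate with the oracle top-k ranking.
    gate_loss = optax.sigmoid_binary_cross_entropy(gate_logits, labels).mean()

    return ce_loss + beta * ttt_loss + gate_loss
```

Because the budget enters only through the top-k labels, no explicit cost penalty term is needed in the loss.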
Roadmap & Status
The project is currently in active development. Phase 1 is complete with a preprint available.
Phase 1: Complete (Preprint)
- Pure NNX GPT-2, TTT Layer with Binary Gating
- Gumbel-Softmax training for SKIP/UPDATE decisions
- End-to-End differentiable training with Top-k Discriminative Gating
- Results on GPT-2 (125M, 350M) with OOD evaluation
Phase 2: Planned (Conference Submission)
See PLAN.md for the detailed roadmap.
- Scale to Gemma 3 (4B, 12B): Validate on modern, production-relevant architectures.
- LoRA-TTT for Efficiency: Replace full TTT updates with Low-Rank Adaptation.
- Reasoning Benchmarks: MATH500, GSM8K, LiveCodeBench, GPQA-Diamond.
- Advanced Gating Features: Entropy, VOG, Attention Dispersion.
Quick Start
Installation
```bash
# Install uv if you do not have it yet
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install the project in editable mode
uv pip install -e . --group gpu # or tpu/cpu
```
Reproduce Paper Results (Recommended)
Run the full suite of experiments (Training, OOD Evaluation, Latency, Ablations) with a single script:
```bash
chmod +x scripts/run_all_experiments.sh
./scripts/run_all_experiments.sh
```
Manual Training
```bash
python -m ponderttt.experiments.train_hard_skip \
  --model_scale 125m \
  --target_update_rate 0.5 \
  --num_iterations 10000 \
  --output_dir outputs/hard_skip
```
Citation
```bibtex
@article{sim2025ponderttt,
  title={Learning to Ponder: Adaptive Compute Allocation via Test-Time Training},
  author={Sim, Gihyeon},
  year={2025}
}
```