# YatNMN-Softplus d=12 Chinchilla (261M params)
A 261M-parameter nanochat-architecture GPT with the MLP feedforward swapped for YatNMN-Softplus — the YatNMN (mlnomadpy/nmn) Neural Matter Network layer with `softplus_bias=True` and `learnable_epsilon=True`.
Trained on English C4 to a Chinchilla-optimal token budget (20× params ≈ 5.22 B tokens) on a single TPU v6e-8 in europe-west4-a.
## Final result
| Metric | Value |
|---|---|
| Final loss | 2.9802 |
| Smooth final loss | 3.0588 |
| Tokens | 5.22 B |
| Steps | 19,922 |
| Wall time | 2.2 h |
| Throughput | ~600 K tok/s |
## Comparison
Same architecture, same optimizer, same Chinchilla budget, same data — only the MLP activation/bias/epsilon treatment differs:
| MLP variant | Final loss | Δ vs. best |
|---|---|---|
| Stock ReLU² (nanochat default, README scaling-law baseline) | 3.42 | +0.44 |
| Plain YatNMN (use_bias=True, learnable alpha, fixed ε=1e-3) | 3.13 | +0.15 |
| YatNMN-Softplus (softplus_bias=True, learnable_epsilon=True, learnable alpha) | 2.98 | — |
Softplus parameterization of the bias and epsilon gives a consistent 0.15 loss reduction vs. plain YatNMN, and 0.44 vs. stock ReLU², at identical compute.
## Architecture
Full nanochat stack with YatNMN-Softplus only in the MLP:
- RoPE (base 100K), no learned positional embeddings
- MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable, but all d=12 models use full MHA)
- QK norm with 1.2× scaling for sharper attention
- Sliding-window attention with pattern "SSSL" (short/short/short/long layers)
- Tied embeddings (wte ↔ lm_head)
- Parameterless RMSNorm post-embedding and per block
- Value embeddings (ResFormer-style, alternating layers)
- Per-layer learnable residual scalars (`resid_lambdas`, `x0_lambdas`)
- Smear — learnable bigram gate on the first 24 dims of the token embedding
- Backout — subtract the mid-layer (layer 6) residual to strip low-level features
- Logit soft-capping via `tanh(x/15)·15`
- No biases in Linear layers (attn Q/K/V/proj, MLP out-projection)
- YatNMN-Softplus FFN:

```
y = α · (x·W + softplus(b_raw))² / (‖x − W‖² + softplus(ε_raw))
```

where:

- `α`: learnable scalar, init = 1.0
- `b_raw`: learnable vector, init = 0, so the effective bias at init is softplus(0) = ln 2 ≈ 0.693
- `ε_raw`: learnable scalar, initialized so that softplus(ε_raw) = 1e-3
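The FFN formula can be sketched numerically. Below is a toy NumPy version for illustration only; the function name and dense shapes are mine, and the real layer is `nmn`'s `YatNMN`, which may differ in details:

```python
import numpy as np

def softplus(x):
    # softplus(x) = log(1 + e^x); keeps bias and epsilon strictly positive
    return np.log1p(np.exp(x))

def yat_softplus_ffn(x, W, b_raw, eps_raw, alpha=1.0):
    """Toy dense YatNMN-Softplus forward pass (illustrative, not the nmn API)."""
    # Numerator: squared response of the affine map with a softplus-positive bias
    num = (x @ W + softplus(b_raw)) ** 2
    # Denominator: squared distance between the input and each weight column,
    # regularized by a softplus-parameterized (always-positive) epsilon
    dist_sq = ((x[:, :, None] - W[None, :, :]) ** 2).sum(axis=1)
    return alpha * num / (dist_sq + softplus(eps_raw))

# At init: b_raw = 0 gives an effective bias of softplus(0) = ln 2 ≈ 0.693,
# and eps_raw = log(expm1(1e-3)) gives softplus(eps_raw) = 1e-3
x = np.array([[1.0, 0.0]])    # batch of 1, n_embd = 2
W = np.array([[1.0], [0.0]])  # ff = 1
y = yat_softplus_ffn(x, W, b_raw=np.zeros(1), eps_raw=np.log(np.expm1(1e-3)))
```

Note the response grows sharply as the input approaches a weight column (the distance term in the denominator shrinks toward ε), which is the distinctive behavior of the Yat family.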
See config.json for the exact values.
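Two of the smaller architecture components listed above are compact enough to sketch. A NumPy illustration (helper names are mine, not flaxchat's):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """Parameterless RMSNorm: divide by the RMS over the last axis, no learned gain."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def soft_cap(logits, cap=15.0):
    """Logit soft-capping: tanh(x/15)·15 squashes logits smoothly into (-15, 15)."""
    return np.tanh(logits / cap) * cap
```

Soft-capping is nearly the identity for small logits but bounds extreme ones, which keeps the cross-entropy loss well-behaved.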
## Training
- Data: `allenai/c4` (English split, streamed)
- Tokenizer: `mistralai/Mistral-7B-v0.1` (vocab 32,768)
- Sequence length: 1024
- Batch: 32/device × 8 devices = 256 global (262 K tokens/step)
- Optimizer: plain AdamW, β = (0.9, 0.999), wd = 0.01, global-norm grad clip at 1.0
- LR: warmup-cosine-decay, peak = 0.03, warmup = 500 steps, decay_end = 5% of peak
- Seed: 0
- Budget: Chinchilla 20× params ≈ 5.22 B tokens, 19,922 steps
- Hardware: TPU v6e-8 (TRC), europe-west4-a
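The batch and budget numbers above are internally consistent; a quick arithmetic check:

```python
# Sanity-check the token accounting from the training list
tokens_per_step = 32 * 8 * 1024          # batch/device × devices × seq len
total_tokens = 19_922 * tokens_per_step  # steps × tokens/step
implied_params = total_tokens / 20       # Chinchilla: tokens ≈ 20 × params

print(tokens_per_step)                   # 262144  (≈ 262 K tokens/step)
print(round(total_tokens / 1e9, 2))      # 5.22    (B tokens)
print(round(implied_params / 1e6))       # 261     (M params)
```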
The LR = 0.03 comes from a d=12 LR sweep on the plain YatNMN variant; we carried it over to the softplus variant since the two are near-identical at initialization (the softplus bias's effective value at init is ~0.693, and softplus-epsilon is configured to initialize at 1e-3).
## Contents
```
.
├── 19922/                 — FINAL checkpoint, end of training
│   ├── _CHECKPOINT_METADATA
│   ├── metadata/          — JSON: {"final_loss": 2.9802, "smooth": 3.0588}
│   ├── model/             — nnx.Param state (all weights)
│   └── optimizer/         — optax AdamW state (Adam m/v + step count)
├── config.json            — full architecture + training config
└── README.md
```
The `optimizer/` subdirectory makes this checkpoint resumable — `flaxchat.checkpoint.restore_model_from_checkpoint` can load both model and optimizer state for exact training continuation (Adam moments preserved, step count preserved, LR schedule resumes at the right position).
## Loading
This model depends on the flaxchat codebase — the architecture (`GPT` class, `Block`, attention, value embeddings, smear/backout/soft-cap, RMSNorm), the checkpoint I/O, and the YatNMN-Softplus MLP patch all live in flaxchat + nmn. You need to clone flaxchat first, then use its code to load this checkpoint.
### Step 1 — Clone flaxchat

```bash
git clone https://github.com/mlnomadpy/flaxchat
cd flaxchat
```
### Step 2 — Install dependencies

With pixi (recommended; matches the training environment):

```bash
pixi install
```

Or with plain pip:

```bash
pip install "jax[tpu]" "flax>=0.11" "optax>=0.2.4" \
    "orbax-checkpoint>=0.11" "nmn>=0.2.28" \
    transformers datasets wandb
pip install -e .
```
### Step 3 — Download the checkpoint

```python
from huggingface_hub import snapshot_download

repo_dir = snapshot_download("mlnomad/yatnmn-softplus-d12-chinchilla-261M")
```

Or via CLI:

```bash
huggingface-cli download mlnomad/yatnmn-softplus-d12-chinchilla-261M --local-dir ./hf_model
```
### Step 4 — Load the model

Save this as `load_model.py` inside your flaxchat clone (or grab the bundled copy from `code/load_model.py` in this HF repo):
```python
# load_model.py — inside your flaxchat clone
import sys

import jax.numpy as jnp
from flax import nnx
from nmn.nnx.layers import YatNMN

from flaxchat.gpt import GPT, GPTConfig, Block
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch Block to use the YatNMN-Softplus MLP (must match training)
_orig_block = Block.__init__

def _block_init(self, config, layer_idx, *, rngs, use_remat=False):
    _orig_block(self, config, layer_idx, rngs=rngs, use_remat=use_remat)

    class YatFFN(nnx.Module):
        def __init__(self, n, ff, *, rngs):
            self.c_fc = YatNMN(
                n, ff,
                use_bias=True,
                softplus_bias=True,
                learnable_epsilon=True,
                epsilon=1e-3,
                rngs=rngs,
            )
            self.c_proj = nnx.Linear(ff, n, use_bias=False, rngs=rngs)

        def __call__(self, x):
            return self.c_proj(self.c_fc(x))

    self.mlp = YatFFN(config.n_embd, 4 * config.n_embd, rngs=rngs)

Block.__init__ = _block_init

# Zero-init c_proj (matches training-time init for a stable start)
_orig_init = GPT._init_weights

def _patched_init(self):
    _orig_init(self)
    for b in self.blocks:
        b.mlp.c_proj.kernel[...] = jnp.zeros_like(b.mlp.c_proj.kernel[...])

GPT._init_weights = _patched_init

# Build the model with a matching config
config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))

# Load weights
ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "./hf_model/19922"
meta = restore_model_from_checkpoint(model, ckpt_path)
print(f"Loaded — saved final loss: {meta['final_loss']:.4f}")

# Optional: quick generation
from transformers import AutoTokenizer
from flaxchat.engine import generate_fast

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
prompt = tok("The quick brown fox", return_tensors="np").input_ids[0]
out = generate_fast(model, prompt, max_tokens=64, temperature=0.8)
print(tok.decode(out))
```
Run it:

```bash
python load_model.py ./hf_model/19922
```

Expected output: `Loaded — saved final loss: 2.9802`.
## Reference code in this repo (`code/` directory)
For inspection / inline reading, this HF repo also contains:

- `code/gpt.py` — the GPT architecture (identical to flaxchat at the commit this was trained on)
- `code/checkpoint.py` — Orbax save/restore helpers (the `save_checkpoint` variant that saves both model + optimizer state, so the `19922/optimizer/` dir round-trips cleanly)
- `code/config.py` — `GPTConfig` + `TrainingConfig` + `FlaxChatConfig.from_depth`
- `code/train_d12_chinchilla.py` — the exact training script (reproduces this run: `--mlp yatnmn-softplus --lr 0.03 --seed 0 --save-every 2000 --chinchilla-mult 1`)
- `code/load_model.py` — the loader above, ready to drop into your flaxchat clone
For actually running the model, clone flaxchat and use its code — the `code/` files here are snapshots for reference; the flaxchat repo is the source of truth.
## Training run
- W&B: `irf-sic/flaxchat/runs/ljmvjqch`
- Run name: `yatnmn-softplus-d12-chinchilla-lr0.03-seed0`
## Citation
If you use this model or the YatNMN-Softplus variant:
- nmn — github.com/mlnomadpy/nmn (YatNMN + softplus variants)
- flaxchat — github.com/mlnomadpy/flaxchat (JAX/Flax NNX nanochat port)
## License
Apache 2.0.