YatNMN-Softplus d=12 Chinchilla (261M params)

A 261M-parameter nanochat-architecture GPT with the MLP feedforward swapped for YatNMN-Softplus β€” the YatNMN (mlnomadpy/nmn) Neural Matter Network layer with softplus_bias=True and learnable_epsilon=True.

Trained on English C4 to Chinchilla-optimal token budget (20Γ— params β‰ˆ 5.22 B tokens) on a single TPU v6e-8 in europe-west4-a.
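A quick back-of-the-envelope check of that budget (plain Python; the figures are the ones quoted on this card):

params = 261e6                     # ~261 M parameters
tokens_per_step = 256 * 1024       # global batch of 256 sequences x 1024 tokens = 262,144
target_tokens = 20 * params        # Chinchilla 20x params -> ~5.22 B tokens
print(round(target_tokens / tokens_per_step))
# -> 19914 with the rounded parameter count; the run's 19,922 steps
#    correspond to the exact (unrounded) count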

Final result

| Metric | Value |
|---|---|
| Final loss | 2.9802 🏆 |
| Smooth final loss | 3.0588 |
| Tokens | 5.22 B |
| Steps | 19,922 |
| Wall time | 2.2 h |
| Throughput | ~600 K tok/s |

Comparison

Same architecture, same optimizer, same Chinchilla budget, same data β€” only the MLP activation/bias/epsilon treatment differs:

| MLP variant | Final loss | Δ vs. best |
|---|---|---|
| Stock ReLU² (nanochat default, README scaling-law baseline) | 3.42 | +0.44 |
| Plain YatNMN (use_bias=True, learnable alpha, fixed ε=1e-3) | 3.13 | +0.15 |
| YatNMN-Softplus (softplus_bias=True, learnable_epsilon=True, learnable alpha) | 2.98 | — |

Applying softplus to the bias and epsilon gives a consistent 0.15 loss reduction over plain YatNMN, and 0.44 over stock ReLU², at identical compute.
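The only mechanical change between the two YatNMN rows is how the bias and epsilon are parameterized; a minimal illustration in plain JAX (helper names are ours, not the nmn package's):

import jax.numpy as jnp
from jax.nn import softplus

# Plain YatNMN: unconstrained learnable bias, fixed denominator offset.
def plain_bias_eps(b_raw):
    return b_raw, 1e-3

# YatNMN-Softplus: both quantities pass through softplus, so they stay strictly
# positive for any raw value and the epsilon becomes learnable as well.
def softplus_bias_eps(b_raw, eps_raw):
    return softplus(b_raw), softplus(eps_raw)

print(softplus_bias_eps(jnp.array(0.0), jnp.array(-6.907)))
# -> (~0.693, ~1e-3), the initial effective values listed under Architecture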

Architecture

Full nanochat stack with YatNMN-Softplus only in the MLP:

  • RoPE (base 100K), no learned positional embeddings
  • MHA (n_head = n_kv_head = 12, head_dim = 64; GQA-capable but all d=12 models use full MHA)
  • QK norm with 1.2Γ— scaling for sharper attention
  • Sliding window attention with pattern "SSSL" (short/short/short/long layers)
  • Tied embeddings (wte ↔ lm_head)
  • Parameterless RMSNorm post-embedding and per block
  • Value embeddings (ResFormer-style, alternating layers)
  • Per-layer learnable residual scalars (resid_lambdas, x0_lambdas)
  • Smear β€” learnable bigram gate on first 24 dims of the token embedding
  • Backout β€” subtract mid-layer (layer 6) residual to strip low-level features
  • Logit soft-capping via tanh(x/15)Β·15
  • No biases in Linear layers (attn Q/K/V/proj, MLP out-projection)
  • YatNMN-Softplus FFN:
    y = Ξ± Β· (xΒ·W + softplus(b_raw))Β² / (β€–x βˆ’ Wβ€–Β² + softplus(Ξ΅_raw))
    
    • Ξ±: learnable scalar, init = 1.0
    • b_raw: learnable vector, init = 0 β†’ effective bias = softplus(0) = ln(2) β‰ˆ 0.693
    • Ξ΅_raw: learnable scalar, init so softplus(Ξ΅_raw) = 1e-3

See config.json for the exact values.

Training

  • Data: allenai/c4 (English split, streamed)
  • Tokenizer: mistralai/Mistral-7B-v0.1 (vocab 32,768)
  • Sequence length: 1024
  • Batch: 32/device Γ— 8 devices = 256 global (262 K tokens/step)
  • Optimizer: plain AdamW, Ξ² = (0.9, 0.999), wd = 0.01, global-norm grad clip at 1.0
  • LR: warmup-cosine-decay, peak = 0.03, warmup = 500, decay_end = 5% of peak
  • Seed: 0
  • Budget: Chinchilla 20Γ— params β‰ˆ 5.22 B tokens, 19,922 steps
  • Hardware: TPU v6e-8 (TRC), europe-west4-a

The peak LR of 0.03 comes from a d=12 LR sweep on the plain YatNMN variant; we carried it over to the softplus variant since the two are near-identical at initialization (the softplus bias starts at softplus(0) ≈ 0.69 and the softplus epsilon is configured to initialize at 1e-3).
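For reference, the optimizer and schedule above map onto optax roughly as follows (a sketch assembled from the hyperparameters listed here, not the actual training script):

import optax

peak_lr, total_steps = 0.03, 19_922
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=500,
    decay_steps=total_steps,
    end_value=0.05 * peak_lr,        # decay_end = 5% of peak
)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # global-norm grad clip at 1.0
    optax.adamw(schedule, b1=0.9, b2=0.999, weight_decay=0.01),
)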

Contents

.
β”œβ”€β”€ 19922/                          ← FINAL checkpoint, end of training
β”‚   β”œβ”€β”€ _CHECKPOINT_METADATA
β”‚   β”œβ”€β”€ metadata/                   ← JSON: {"final_loss": 2.9802, "smooth": 3.0588}
β”‚   β”œβ”€β”€ model/                      ← nnx.Param state (all weights)
β”‚   └── optimizer/                  ← optax AdamW state (Adam m/v + step count)
β”œβ”€β”€ config.json                     ← full architecture + training config
└── README.md

The optimizer/ subdirectory makes this checkpoint resumable β€” flaxchat.checkpoint.restore_model_from_checkpoint can load both model and optimizer state for exact training continuation (Adam moments preserved, step count preserved, LR schedule resumes at the right position).

Loading

This model depends on the flaxchat codebase β€” the architecture (GPT class, Block, attention, value embeddings, smear/backout/soft-cap, RMSNorm), the checkpoint I/O, and the YatNMN-Softplus MLP patch all live in flaxchat + nmn. You need to clone flaxchat first, then use its code to load this checkpoint.

Step 1 β€” Clone flaxchat

git clone https://github.com/mlnomadpy/flaxchat
cd flaxchat

Step 2 β€” Install dependencies

With pixi (recommended, matches the training environment):

pixi install

Or with plain pip:

pip install "jax[tpu]" "flax>=0.11" "optax>=0.2.4" \
            "orbax-checkpoint>=0.11" "nmn>=0.2.28" \
            transformers datasets wandb
pip install -e .

Step 3 β€” Download the checkpoint

from huggingface_hub import snapshot_download
repo_dir = snapshot_download("mlnomad/yatnmn-softplus-d12-chinchilla-261M")

Or via CLI:

huggingface-cli download mlnomad/yatnmn-softplus-d12-chinchilla-261M --local-dir ./hf_model

Step 4 β€” Load the model

Save this as load_model.py inside your flaxchat clone (or grab the bundled copy from code/load_model.py in this HF repo):

# load_model.py β€” inside your flaxchat clone
from flax import nnx
import jax, jax.numpy as jnp
from nmn.nnx.layers import YatNMN
from flaxchat.gpt import GPT, GPTConfig, Block
from flaxchat.checkpoint import restore_model_from_checkpoint

# Patch Block to use YatNMN-Softplus MLP (must match training)
_orig_block = Block.__init__
def _block_init(self, config, layer_idx, *, rngs, use_remat=False):
    _orig_block(self, config, layer_idx, rngs=rngs, use_remat=use_remat)
    class YatFFN(nnx.Module):
        def __init__(self, n, ff, *, rngs):
            self.c_fc = YatNMN(
                n, ff,
                use_bias=True,
                softplus_bias=True,
                learnable_epsilon=True,
                epsilon=1e-3,
                rngs=rngs,
            )
            self.c_proj = nnx.Linear(ff, n, use_bias=False, rngs=rngs)
        def __call__(self, x):
            return self.c_proj(self.c_fc(x))
    self.mlp = YatFFN(config.n_embd, 4 * config.n_embd, rngs=rngs)
Block.__init__ = _block_init

# Zero-init c_proj (matches training-time init for stable start)
_orig_init = GPT._init_weights
def _patched_init(self):
    _orig_init(self)
    for b in self.blocks:
        b.mlp.c_proj.kernel[...] = jnp.zeros_like(b.mlp.c_proj.kernel[...])
GPT._init_weights = _patched_init

# Build the model with matching config
config = GPTConfig(
    sequence_len=1024, vocab_size=32768,
    n_layer=12, n_head=12, n_kv_head=12, n_embd=768,
    window_pattern="SSSL", tie_embeddings=True,
)
model = GPT(config, rngs=nnx.Rngs(0))

# Load weights
import sys
ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "./hf_model/19922"
meta = restore_model_from_checkpoint(model, ckpt_path)
print(f"Loaded β€” saved final loss: {meta['final_loss']:.4f}")

# Optional: quick generation
from transformers import AutoTokenizer
from flaxchat.engine import generate_fast
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
prompt = tok("The quick brown fox", return_tensors="np").input_ids[0]
out = generate_fast(model, prompt, max_tokens=64, temperature=0.8)
print(tok.decode(out))

Run it:

python load_model.py ./hf_model/19922

Expected output: Loaded β€” saved final loss: 2.9802.

Reference code in this repo (code/ directory)

For inspection / inline reading, this HF repo also contains:

  • code/gpt.py β€” the GPT architecture (identical to flaxchat at the commit this was trained on)
  • code/checkpoint.py β€” Orbax save/restore helpers (the save_checkpoint variant that saves both model + optimizer state, so the 19922/optimizer/ dir round-trips cleanly)
  • code/config.py β€” GPTConfig + TrainingConfig + FlaxChatConfig.from_depth
  • code/train_d12_chinchilla.py β€” the exact training script (reproduces this run: --mlp yatnmn-softplus --lr 0.03 --seed 0 --save-every 2000 --chinchilla-mult 1)
  • code/load_model.py β€” the loader above, ready to drop into your flaxchat clone

For actually running the model, clone flaxchat and use its code β€” the code/ files here are snapshots for reference; the flaxchat repo is the source of truth.

Training run

Citation

If you use this model or the YatNMN-Softplus variant:

License

Apache 2.0.
