# YatNMN-Softplus d=12 Chinchilla (261M) for PyTorch / HuggingFace Transformers
A 261M-parameter nanochat-architecture GPT with the YatNMN-Softplus MLP (per-neuron softplus-positive bias, learnable ε, learnable α). Trained in JAX/Flax on TPU v6e-8 to a Chinchilla-optimal token budget on C4, then ported to PyTorch for easy inference via the HuggingFace transformers API.
This is the best-performing 261M model in the ablation series:
| MLP variant | Final smooth loss | vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron) | 2.98 | −0.13 |
| YatNMN-Softplus + scalar_bias | 3.06 | −0.05 |
| GELU | 3.11 | baseline |
Weights are bit-exact with the Flax checkpoint (mlnomad/yatnmn-softplus-d12-chinchilla-261M); parity validated at max |Δ logits| = 1.5e-5 on CPU/fp32.
## Quick start

```bash
pip install torch transformers safetensors
```

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/yatnmn-softplus-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,
    dtype=torch.float32,
).eval()

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    out = model.generate(
        ids, max_new_tokens=50,
        do_sample=True, temperature=0.8, top_p=0.9,
        use_cache=True, pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```
Greedy completion samples:

- "The meaning of life is" → the same as life. The meaning of life is the same as life…
- "Once upon a time," → the world was a place where people could live and work. The world was a place where people could…
## YatNMN-Softplus MLP

Each MLP block uses the YatNMN nonlinearity from `nmn>=0.2.29`:

    y = α · (x · W + softplus(b))² / (‖x − W‖² + softplus(ε))

with a per-neuron bias b of shape (4·n_embd,) = (3072,), a scalar learnable epsilon of shape (1,), and a scalar learnable α of shape (1,). Both the bias and epsilon are passed through softplus to keep them strictly positive. The MLP then applies c_proj (Linear → 768) on top of YatNMN's output.
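For reference, a minimal PyTorch sketch of this nonlinearity, reading the distance term per output neuron (‖x − wᵢ‖² for each weight row wᵢ). This is a hypothetical re-implementation of the formula above; the names and initialisation in the repo's `yatnmn_gpt.py` may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class YatNMNSoftplus(nn.Module):
    """Sketch of the YatNMN-Softplus nonlinearity (hypothetical names)."""
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in) * 0.02)
        self.bias = nn.Parameter(torch.zeros(d_out))   # per-neuron bias b
        self.epsilon = nn.Parameter(torch.zeros(1))    # scalar learnable epsilon
        self.alpha = nn.Parameter(torch.ones(1))       # scalar learnable alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dot = x @ self.weight.t()                      # x · W, shape (..., d_out)
        # ||x - w_i||^2 expanded as ||x||^2 - 2 x·w_i + ||w_i||^2,
        # avoiding materialising the (..., d_out, d_in) difference tensor
        x_sq = (x * x).sum(-1, keepdim=True)
        w_sq = (self.weight * self.weight).sum(-1)
        dist_sq = x_sq - 2 * dot + w_sq
        num = (dot + F.softplus(self.bias)) ** 2
        return self.alpha * num / (dist_sq + F.softplus(self.epsilon))
```

Softplus of the zero-initialised ε keeps the denominator bounded away from zero (softplus(0) ≈ 0.693), so no extra clamping is needed.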
## Model details

| | |
|---|---|
| Parameters | 261,133,226 |
| Architecture | Nanochat-style GPT with YatNMN-Softplus MLP (ported from JAX/Flax NNX) |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 5.22B tokens (Chinchilla 20×) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.03, warmup-cosine schedule |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Final loss (smooth) | 2.98 |
## Architecture features

The full nanochat stack, faithfully ported to PyTorch:

- YatNMN-Softplus MLP (per-neuron bias, softplus-positive, learnable α and ε)
- RoPE (base 100,000), split-half layout
- MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
- QK-norm with 1.2× scaling (after RoPE)
- Parameterless RMSNorm (no learnable gain) post-embedding and per block
- Sliding-window attention with "SSSL" pattern
- Tied embeddings (`lm_head = wte.T`)
- Value embeddings on alternating layers (ResFormer-style)
- Per-layer learnable residual scalars (`resid_lambdas`, `x0_lambdas`)
- Smear: a learnable gate on the first 24 dims of the token embedding mixes in the previous token
- Backout: the mid-layer residual is subtracted from late layers
- Logit soft-cap: `15 · tanh(logits / 15)`
- No biases in any Linear
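The logit soft-cap is a one-liner; a generic sketch of the stated formula (not the repo's code):

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Smoothly bounds logits to [-cap, cap]: near-identity for |logits| << cap,
    # saturating for large values, with a nonzero gradient everywhere.
    return cap * torch.tanh(logits / cap)
```

Unlike a hard clamp, the tanh form keeps gradients flowing through extreme logits during training.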
## KV cache

The YatGPTForCausalLM class implements a smear-aware KV cache for fast autoregressive generation. KV-cache parity vs the full forward pass is validated at max |Δ| < 3e-5. Pass `use_cache=True` (the default for `.generate()`).
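This kind of parity check can be reproduced for any HF causal LM by decoding one token at a time against a single full forward pass. A generic sketch, not the repo's validation script (`kv_cache_parity` is a hypothetical helper):

```python
import torch

@torch.no_grad()
def kv_cache_parity(model, ids: torch.Tensor, atol: float = 3e-5) -> bool:
    """Compare incremental decoding (one token at a time, reusing the KV
    cache) against a single full forward pass over the same ids."""
    full = model(ids, use_cache=False).logits
    past, steps = None, []
    for t in range(ids.shape[1]):
        out = model(ids[:, t : t + 1], past_key_values=past, use_cache=True)
        past = out.past_key_values
        steps.append(out.logits)
    inc = torch.cat(steps, dim=1)
    return (full - inc).abs().max().item() <= atol
```

In fp32 the two paths should agree to within a few ulps of accumulated rounding; a tolerance around 3e-5 is comfortable for a model of this size.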
## Files in this repo

```
.
├── config.json                  # HF config with auto_map → the classes below
├── generation_config.json
├── model.safetensors            # ~1.04 GB, fp32 weights + persistent RoPE buffers
├── yatnmn_gpt.py                # pure PyTorch Yat_GPT module + YatNMN layer
├── torch_gpt.py                 # shared building blocks (RMSNorm, RoPE, attention)
├── configuration_yatnmn_gpt.py  # PretrainedConfig subclass
├── modeling_yatnmn_gpt.py       # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md
```
## Related

- mlnomad/yatnmn-softplus-d12-chinchilla-261M: original JAX/Flax Orbax checkpoint (model + AdamW optimizer state, resumable)
- mlnomad/gelu-d12-chinchilla-261M-pytorch: GELU baseline at identical compute, smooth loss 3.11
- flaxchat: JAX/Flax training harness
- nmn: the YatNMN layer (used at training time; not required for inference here, since the nonlinearity is reimplemented in pure PyTorch)
## Wikitext-103 evaluation
| Metric | Value |
|---|---|
| Wikitext-103 test loss | 3.693 |
| Wikitext-103 test PPL | 40.15 |
Evaluated on ~330K tokens from the wikitext-103 test set. The model was trained on C4 only, so this is a zero-shot transfer metric.
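An evaluation like this amounts to averaging next-token cross-entropy over non-overlapping windows of the test stream. A generic sketch; the exact chunking behind the numbers above may differ slightly:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def eval_loss(model, token_ids: torch.Tensor, seq_len: int = 1024) -> float:
    """Mean next-token cross-entropy over non-overlapping windows of a 1-D
    token stream; perplexity is exp of the returned loss."""
    total_nll, total_tok = 0.0, 0
    for start in range(0, token_ids.numel() - 1, seq_len):
        # seq_len inputs plus one extra token so every position has a target
        chunk = token_ids[start : start + seq_len + 1].unsqueeze(0)
        if chunk.shape[1] < 2:
            break
        logits = model(chunk[:, :-1]).logits
        nll = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                              chunk[:, 1:].reshape(-1), reduction="sum")
        total_nll += nll.item()
        total_tok += chunk.shape[1] - 1
    return total_nll / total_tok
```

Stepping by `seq_len` while reading `seq_len + 1` tokens means each token is predicted exactly once, which is what a clean perplexity number requires.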
## License
Apache 2.0.