YatNMN-Softplus d=12 Chinchilla (261M) β€” PyTorch / HuggingFace Transformers

A 261M-parameter nanochat-architecture GPT with the YatNMN-Softplus MLP (per-neuron softplus-positive bias, learnable epsilon, learnable Ξ±). Trained in JAX/Flax on a TPU v6e-8 to a Chinchilla-optimal token budget on C4, then ported to PyTorch for easy inference via the HuggingFace transformers API.

This is the best-performing 261M model in the ablation series:

| MLP variant | Final smooth loss | vs GELU |
|---|---|---|
| YatNMN-Softplus (per-neuron) | 2.98 | βˆ’0.13 |
| YatNMN-Softplus + scalar_bias | 3.06 | βˆ’0.05 |
| GELU | 3.11 | baseline |

The weights are ported bit-exactly from the Flax checkpoint (mlnomad/yatnmn-softplus-d12-chinchilla-261M); forward-pass parity was validated at max |Ξ” logits| = 1.5e-5 on CPU/fp32.

Quick start

```shell
pip install torch transformers safetensors
```

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mlnomad/yatnmn-softplus-d12-chinchilla-261M-pytorch",
    trust_remote_code=True,
    dtype=torch.float32,
).eval()

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "The meaning of life is"
ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    out = model.generate(
        ids, max_new_tokens=50,
        do_sample=True, temperature=0.8, top_p=0.9,
        use_cache=True, pad_token_id=tokenizer.eos_token_id or 0,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Greedy completion samples:

"The meaning of life is" β†’ the same as life. The meaning of life is the same as life…

"Once upon a time," β†’ the world was a place where people could live and work. The world was a place where people could…

YatNMN-Softplus MLP

Each MLP block uses the YatNMN nonlinearity from nmn>=0.2.29:

y = Ξ± Β· (x Β· W + softplus(b))Β² / (||x βˆ’ W||Β² + softplus(Ξ΅))

with per-neuron bias b of shape (4Β·n_embd,) = (3072,), scalar learnable epsilon of shape (1,), and scalar learnable Ξ± of shape (1,). Both b and Ξ΅ are passed through softplus so the numerator offset and the denominator floor stay strictly positive. The MLP then applies c_proj (a 3072 β†’ 768 Linear) to YatNMN's output.
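The formula above can be sketched in a few lines of PyTorch. This is an illustrative re-implementation, not the actual layer from the nmn package or this repo (class and attribute names here are hypothetical); it expands ||x βˆ’ W||Β² per output neuron as ||x||Β² βˆ’ 2xΒ·w_j + ||w_j||Β²:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class YatNMNSoftplus(nn.Module):
    """Sketch of y = alpha * (x@W + softplus(b))^2 / (||x - W||^2 + softplus(eps)).

    Shapes follow the card (n_embd=768, hidden=4*n_embd=3072).
    """

    def __init__(self, d_in: int = 768, d_out: int = 3072):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_in, d_out) * d_in ** -0.5)
        self.bias = nn.Parameter(torch.zeros(d_out))   # per-neuron bias b
        self.eps = nn.Parameter(torch.zeros(1))        # scalar learnable epsilon
        self.alpha = nn.Parameter(torch.ones(1))       # scalar learnable alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dot = x @ self.weight                          # (..., d_out)
        num = (dot + F.softplus(self.bias)) ** 2
        # ||x - w_j||^2 = ||x||^2 - 2 x.w_j + ||w_j||^2, per output neuron j
        dist_sq = (x ** 2).sum(-1, keepdim=True) - 2 * dot \
            + (self.weight ** 2).sum(0)
        return self.alpha * num / (dist_sq + F.softplus(self.eps))
```

Note that even at initialization (Ξ΅ = 0) the denominator is bounded below by softplus(0) β‰ˆ 0.693, so the division is always well defined.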

Model details

| Detail | Value |
|---|---|
| Parameters | 261,133,226 |
| Architecture | Nanochat-style GPT with YatNMN-Softplus MLP (ported from JAX/Flax NNX) |
| Config | d=12, n_embd=768, n_head=12, n_kv_head=12, seq_len=1024, tied embeddings, SSSL sliding window |
| Training data | allenai/c4 (English split), 5.22 B tokens (Chinchilla 20Γ—) |
| Tokenizer | mistralai/Mistral-7B-v0.1 (vocab 32,768) |
| Optimizer | plain AdamW, peak LR 0.03, warmup-cosine |
| Hardware | TPU v6e-8 (TRC), europe-west4-a |
| Final loss (smooth) | 2.98 |

Architecture features

Full nanochat stack, faithfully ported to PyTorch:

  • YatNMN-Softplus MLP (per-neuron bias, softplus-positive, learnable Ξ± and Ξ΅)
  • RoPE (base 100,000), split-half layout
  • MHA (n_head = n_kv_head = 12; the code supports GQA via n_kv_head < n_head, but all d=12 models use full MHA)
  • QK-norm with 1.2Γ— scaling (after RoPE)
  • Parameterless RMSNorm (no learnable gain) post-embedding and per block
  • Sliding-window attention with "SSSL" pattern
  • Tied embeddings (lm_head = wte.T)
  • Value embeddings on alternating layers (ResFormer-style)
  • Per-layer learnable residual scalars (resid_lambdas, x0_lambdas)
  • Smear β€” learnable gate on first 24 dims of token embedding mixes in prev token
  • Backout β€” subtract mid-layer residual from late layers
  • Logit soft-cap: 15 Β· tanh(logits / 15)
  • No biases in any Linear

KV cache

The YatGPTForCausalLM class implements a smear-aware KV cache for fast autoregressive generation. KV-cache parity vs full forward is validated at max |Ξ”| < 3e-5. Pass use_cache=True (the default for .generate()).

Files in this repo

```
.
β”œβ”€β”€ config.json                       # HF config with auto_map β†’ the classes below
β”œβ”€β”€ generation_config.json
β”œβ”€β”€ model.safetensors                 # ~1.04 GB, fp32 weights + persistent RoPE buffers
β”œβ”€β”€ yatnmn_gpt.py                     # pure PyTorch Yat_GPT module + YatNMN layer
β”œβ”€β”€ torch_gpt.py                      # shared building blocks (RMSNorm, RoPE, attention)
β”œβ”€β”€ configuration_yatnmn_gpt.py       # PretrainedConfig subclass
β”œβ”€β”€ modeling_yatnmn_gpt.py            # PreTrainedModel + GenerationMixin wrapper with KV cache
└── README.md
```

Wikitext-103 evaluation

| Metric | Value |
|---|---|
| Wikitext-103 test loss | 3.693 |
| Wikitext-103 test PPL | 40.15 |

Evaluated on ~330K tokens from wikitext-103 test set (model trained on C4 only β€” this is a zero-shot transfer metric).
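The two numbers in the table are related by PPL = exp(loss), i.e. perplexity is the exponential of the mean per-token negative log-likelihood; a quick arithmetic check:

```python
import math

loss = 3.693                 # Wikitext-103 test loss from the table
ppl = math.exp(loss)         # perplexity = exp(mean token NLL)
print(ppl)                   # β‰ˆ 40.2, matching the table's 40.15
                             # up to rounding of the reported loss
```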

License

Apache 2.0.
