SAE for all-MiniLM-L6-v2 (FineWeb + RedPajama + Pile, 150M)

Sparse Autoencoder trained on sentence embeddings from all-MiniLM-L6-v2, decomposing 384-dimensional dense embeddings into sparse, interpretable features.

Available Models

| Subfolder | k | Expansion | Features | Active | FVU | Dead % | Best for |
|---|---|---|---|---|---|---|---|
| 128_4 | 128 | 4x | 1,536 | 26.2% | 0.097 | 73.8% | Best fine-grained accuracy, most distinct features |
| 128_8 | 128 | 8x | 3,072 | 23.4% | 0.069 | 76.6% | Best reconstruction, retrieval |
| 64_8 | 64 | 8x | 3,072 | 98.8% | 0.156 | 1.2% | Maximum feature coverage |

Recommended: 128_4 — only 402 active features, yet it achieves the best accuracy on hard tasks (CLINC150 79.6%, BANKING77 86.5%), has the most distinct features (lowest MMCS, 0.193), and uses half the parameters of the 8x variants. Best balance of quality and efficiency.

Quick Start

```python
from latentsae import Sae
from sentence_transformers import SentenceTransformer
import torch

# Load SAE (choose a subfolder: 128_4, 128_8, or 64_8)
sae = Sae.load_from_hub("enjalot/sae-all-MiniLM-L6-v2-FineWeb-RedPajama-Pile-150M", "64_8")

# Embed text
emb_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = emb_model.encode(["Your text here"], normalize_embeddings=True)

# Extract sparse features
features = sae.encode(torch.tensor(embeddings))
print(f"Top feature indices: {features.top_indices}")
print(f"Top feature activations: {features.top_acts}")
```

Training Details

  • Embedding model: sentence-transformers/all-MiniLM-L6-v2 (384D)
  • Training data: 150M embeddings (50M from each source):
    • FineWeb-edu 10BT sample (120-token chunks)
    • RedPajama-Data-V2 10B sample (120-token chunks)
    • Pile uncopyrighted (120-token chunks)
  • Architecture: TopK SAE (the 64_8 variant: k=64, 8x expansion, 3,072 features, 2.4M parameters; other variants per the table above)
  • Training: auxk_alpha=1/32, dead_feature_threshold=50K, cosine LR schedule
  • Hardware: A10G on Modal, 54 minutes, ~$1
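The TopK architecture keeps only the k largest encoder pre-activations per example and zeroes the rest, so sparsity is enforced structurally rather than via an L1 penalty. A minimal PyTorch sketch under the 64_8 configuration (illustrative only — not the actual latentsae implementation; layer names and the pre-encoder bias are assumptions):

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """Minimal TopK sparse autoencoder sketch (illustrative, not the trained model)."""
    def __init__(self, d_in=384, expansion=8, k=64):
        super().__init__()
        d_hidden = d_in * expansion              # 3,072 features at 8x expansion
        self.k = k
        self.encoder = nn.Linear(d_in, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_in)
        self.b_pre = nn.Parameter(torch.zeros(d_in))  # assumed pre-encoder bias

    def forward(self, x):
        acts = self.encoder(x - self.b_pre)
        # Keep only the k largest activations per example; zero out the rest
        topk = torch.topk(acts, self.k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, topk.indices, topk.values)
        recon = self.decoder(sparse) + self.b_pre
        return recon, topk.indices, topk.values

sae = TopKSAE()
recon, idx, vals = sae(torch.randn(2, 384))
```

With these dimensions the parameter count (two 384×3,072 weight matrices plus biases) comes to roughly 2.4M, matching the figure above.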

Evaluation (Probe Accuracy)

Linear probes on SAE sparse features vs raw embeddings:

| Task | Raw | 128_4 | Gap | 128_8 | Gap | 64_8 | Gap |
|---|---|---|---|---|---|---|---|
| AG News (4-class) | 89.9% | 88.1% | -1.8% | 89.3% | -0.6% | 89.0% | -0.9% |
| SST-2 (2-class) | 80.6% | 78.9% | -1.7% | 80.0% | -0.6% | 80.6% | 0.0% |
| BANKING77 (77-class) | 87.7% | 86.5% | -1.2% | 85.7% | -2.0% | 85.2% | -2.4% |
| CLINC150 (150-class) | 84.1% | 79.6% | -4.5% | 76.4% | -7.7% | 64.6% | -19.6% |
| STS-B (Spearman) | 0.881 | 0.866 | -0.015 | 0.871 | -0.010 | 0.860 | -0.021 |
| SciFact (nDCG@10) | 0.645 | 0.621 | -0.024 | 0.626 | -0.019 | 0.584 | -0.061 |
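A linear probe of this kind fits a logistic regression on the sparse feature vectors and measures held-out accuracy. A hedged sketch using synthetic stand-in data (the real benchmarks use `sae.encode()` outputs on actual task datasets; the data-generation scheme here is invented for illustration):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Synthetic stand-in: 2,000 sparse vectors (3,072 dims, 64 active each)
# over 4 classes, mimicking an AG News-style probe.
n, d, k, n_classes = 2000, 3072, 64, 4
X = np.zeros((n, d), dtype=np.float32)
y = rng.integers(0, n_classes, size=n)
for i in range(n):
    idx = rng.choice(d, size=k, replace=False)
    # Random activations plus a weak class-dependent boost (invented signal)
    X[i, idx] = rng.random(k) + 0.5 * (idx % n_classes == y[i])

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)
probe = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"Probe accuracy: {probe.score(X_te, y_te):.3f}")
```

The "Gap" columns above are simply the probe accuracy on SAE features minus the same probe's accuracy on the raw 384-dimensional embeddings.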

Feature Quality

| Metric | 128_4 | 128_8 | 64_8 |
|---|---|---|---|
| FVU | 0.097 | 0.069 | 0.156 |
| MMCS (redundancy) | 0.193 | 0.232 | 0.287 |
| Active features | 402/1,536 | 719/3,072 | 2,954/3,072 |
| Parameters | 1.2M | 2.4M | 2.4M |
| Normalized entropy | 0.808 | 0.788 | 0.835 |
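Both headline metrics can be computed from reconstructions and decoder weights. A sketch under assumed definitions (FVU as fraction of variance unexplained; MMCS here taken as mean max cosine similarity between decoder feature directions, so lower means less redundant features — the exact evaluation code may differ):

```python
import numpy as np

def fvu(x, x_hat):
    """Fraction of variance unexplained: ||x - x_hat||^2 / ||x - mean(x)||^2."""
    resid = ((x - x_hat) ** 2).sum()
    total = ((x - x.mean(axis=0)) ** 2).sum()
    return resid / total

def mmcs(decoder):
    """Mean max cosine similarity; decoder rows are feature directions (n_features, d)."""
    unit = decoder / np.linalg.norm(decoder, axis=1, keepdims=True)
    sims = unit @ unit.T
    np.fill_diagonal(sims, -np.inf)   # exclude each feature's self-similarity
    return sims.max(axis=1).mean()

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 384))
print(fvu(x, x))                          # perfect reconstruction -> 0.0
print(mmcs(rng.normal(size=(3072, 384)))) # redundancy of a random dictionary
```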

Part of the latent-* ecosystem
