This tiny model is intended for debugging. It is randomly initialized with a configuration adapted from google/diffusiongemma-26B-A4B-it.

Requires transformers>=5.12.0, where DiffusionGemma is available.

| File path | Size |
| --- | --- |
| model.safetensors | 5.6MB |

Example usage:

def disable_broken_torchaudio_probe():
    """Make ``importlib.util.find_spec`` report ``torchaudio`` as not installed.

    Replaces ``importlib.util.find_spec`` with a wrapper that returns ``None``
    for ``torchaudio`` and any of its submodules, and defers to the original
    function for every other name. Code that looks up packages through
    ``importlib.util.find_spec`` at call time will therefore see
    ``torchaudio`` as missing.

    The patch is process-wide and is never undone. Calling this function
    again is a no-op, so re-running the snippet in the same process (e.g. in
    a notebook) does not stack wrappers on top of each other.
    """
    import importlib.util as importlib_util

    marker = "_torchaudio_probe_disabled"
    if getattr(importlib_util.find_spec, marker, False):
        # Already patched by an earlier call; wrapping again would only
        # lengthen the call chain.
        return

    original_find_spec = importlib_util.find_spec

    def find_spec(name, package=None):
        if name == "torchaudio" or name.startswith("torchaudio."):
            return None
        return original_find_spec(name, package)

    setattr(find_spec, marker, True)
    importlib_util.find_spec = find_spec


# Applied before the imports below. NOTE(review): presumably so that the
# optional-dependency checks done while importing transformers already see
# the patched find_spec -- confirm.
disable_broken_torchaudio_probe()

import torch
from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion

# Checkpoint to load. Run in bfloat16 on GPU when CUDA is available,
# otherwise fall back to float32 on CPU.
model_id = "tiny-random/diffusiongemma-26B-A4B-it"
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = DiffusionGemmaForBlockDiffusion.from_pretrained(model_id, dtype=dtype)
model = model.to(device)

# Single-turn chat prompt, rendered through the tokenizer's chat template.
messages = [{"role": "user", "content": "Why is the sky blue?"}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
inputs = inputs.to(device)
input_ids = inputs["input_ids"]

# Small generation budget: this is only a smoke test of a random model.
outputs = model.generate(input_ids, max_new_tokens=4, max_denoising_steps=2)
decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
print(decoded)
print("tokens_per_forward:", outputs.tokens_per_forward)

Code to create this repo:

Click to expand
import json
import shutil
from pathlib import Path

def disable_broken_torchaudio_probe():
    """Make ``importlib.util.find_spec`` report ``torchaudio`` as not installed.

    Replaces ``importlib.util.find_spec`` with a wrapper that returns ``None``
    for ``torchaudio`` and any of its submodules, and defers to the original
    function for every other name. Code that looks up packages through
    ``importlib.util.find_spec`` at call time will therefore see
    ``torchaudio`` as missing.

    The patch is process-wide and is never undone. Calling this function
    again is a no-op, so re-running the script in the same process (e.g. in
    a notebook) does not stack wrappers on top of each other.
    """
    import importlib.util as importlib_util

    marker = "_torchaudio_probe_disabled"
    if getattr(importlib_util.find_spec, marker, False):
        # Already patched by an earlier call; wrapping again would only
        # lengthen the call chain.
        return

    original_find_spec = importlib_util.find_spec

    def find_spec(name, package=None):
        if name == "torchaudio" or name.startswith("torchaudio."):
            return None
        return original_find_spec(name, package)

    setattr(find_spec, marker, True)
    importlib_util.find_spec = find_spec


# Applied before the imports below. NOTE(review): presumably so that the
# optional-dependency checks done while importing transformers already see
# the patched find_spec -- confirm.
disable_broken_torchaudio_probe()

import torch
from huggingface_hub import file_exists, hf_hub_download
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DiffusionGemmaForBlockDiffusion,
    DiffusionGemmaGenerationConfig,
    set_seed,
)

source_model_id = "google/diffusiongemma-26B-A4B-it"
save_folder = "/tmp/tiny-random/diffusiongemma-26B-A4B-it"

# Copy the tokenizer, plus the chat template / processor config when the
# source repo ships them, into the output folder.
tokenizer = AutoTokenizer.from_pretrained(source_model_id)
tokenizer.save_pretrained(save_folder)
for filename in ("chat_template.jinja", "processor_config.json"):
    if file_exists(filename=filename, repo_id=source_model_id, repo_type="model"):
        src = hf_hub_download(source_model_id, filename=filename, repo_type="model")
        dst = Path(save_folder, filename)
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(src, dst)

# Start from the source model's config and shrink it.
source_config_path = hf_hub_download(
    source_model_id, filename="config.json", repo_type="model"
)
with open(source_config_path, "r", encoding="utf-8") as f:
    config_json = json.load(f)

config_json["canvas_length"] = 4
# Text tower: two layers (one sliding-attention, one full-attention),
# hidden size 8, 4 experts with 2 selected per token.
config_json["text_config"].update(
    {
        "global_head_dim": 32,
        "head_dim": 32,
        "hidden_size": 8,
        "intermediate_size": 64,
        "layer_types": [
            "sliding_attention",
            "full_attention",
        ],
        "moe_intermediate_size": 32,
        "num_attention_heads": 4,
        "num_experts": 4,
        "num_hidden_layers": 2,
        "num_key_value_heads": 4,
        "top_k_experts": 2,
    }
)
# Vision tower: two layers, hidden size 32.
config_json["vision_config"].update(
    {
        "global_head_dim": 8,
        "head_dim": 8,
        "hidden_size": 32,
        "intermediate_size": 64,
        "num_attention_heads": 4,
        "num_hidden_layers": 2,
        "num_key_value_heads": 4,
    }
)

config_path = Path(save_folder, "config.json")
# Do not rely on an earlier step having created the output folder.
config_path.parent.mkdir(parents=True, exist_ok=True)
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config_json, f, indent=2)

config = AutoConfig.from_pretrained(
    save_folder,
    trust_remote_code=True,
)
print(config)

# Build the model with bfloat16 as the default dtype, then put back whatever
# default was active before, even if construction fails.
previous_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = DiffusionGemmaForBlockDiffusion(config)
finally:
    torch.set_default_dtype(previous_default_dtype)

if file_exists(
    filename="generation_config.json", repo_id=source_model_id, repo_type="model"
):
    model.generation_config = DiffusionGemmaGenerationConfig.from_pretrained(
        source_model_id,
    )

# Seed before re-initialising so the random weights are reproducible; the
# parameters are visited in sorted-name order so the draw order is fixed.
set_seed(42)
model = model.cpu()
named_params = sorted(model.named_parameters())
all_numels = sum(p.numel() for _, p in named_params)
with torch.no_grad():
    for name, p in named_params:
        torch.nn.init.normal_(p, 0, 0.2)
        # Report each parameter's shape and its share of the total count.
        print(name, p.shape, f"{p.numel() / all_numels * 100: .4f}%")
model.save_pretrained(save_folder)

Printing the model:

Click to expand
DiffusionGemmaForBlockDiffusion(
  (model): DiffusionGemmaModel(
    (encoder): DiffusionGemmaEncoderModel(
      (language_model): DiffusionGemmaEncoderTextModel(
        (embed_tokens): DiffusionGemmaTextScaledWordEmbedding(262144, 8, padding_idx=0)
        (layers): ModuleList(
          (0): DiffusionGemmaEncoderTextLayer(
            (self_attn): DiffusionGemmaEncoderTextAttention(
              (q_proj): Linear(in_features=8, out_features=128, bias=False)
              (k_proj): Linear(in_features=8, out_features=128, bias=False)
              (v_proj): Linear(in_features=8, out_features=128, bias=False)
              (o_proj): Linear(in_features=128, out_features=8, bias=False)
              (q_norm): DiffusionGemmaRMSNorm()
              (k_norm): DiffusionGemmaRMSNorm()
              (v_norm): DiffusionGemmaRMSNorm()
            )
            (mlp): DiffusionGemmaText4MLP(
              (gate_proj): Linear(in_features=8, out_features=64, bias=False)
              (up_proj): Linear(in_features=8, out_features=64, bias=False)
              (down_proj): Linear(in_features=64, out_features=8, bias=False)
              (act_fn): GELUTanh()
            )
            (input_layernorm): DiffusionGemmaRMSNorm()
            (post_attention_layernorm): DiffusionGemmaRMSNorm()
            (pre_feedforward_layernorm): DiffusionGemmaRMSNorm()
            (post_feedforward_layernorm): DiffusionGemmaRMSNorm()
            (router): DiffusionGemmaTextRouter(
              (norm): DiffusionGemmaRMSNorm()
              (proj): Linear(in_features=8, out_features=4, bias=False)
            )
            (experts): DiffusionGemmaTextExperts(
              (act_fn): GELUTanh()
            )
            (post_feedforward_layernorm_1): DiffusionGemmaRMSNorm()
            (post_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
            (pre_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
          )
          (1): DiffusionGemmaEncoderTextLayer(
            (self_attn): DiffusionGemmaEncoderTextAttention(
              (q_proj): Linear(in_features=8, out_features=128, bias=False)
              (k_proj): Linear(in_features=8, out_features=64, bias=False)
              (o_proj): Linear(in_features=128, out_features=8, bias=False)
              (q_norm): DiffusionGemmaRMSNorm()
              (k_norm): DiffusionGemmaRMSNorm()
              (v_norm): DiffusionGemmaRMSNorm()
            )
            (mlp): DiffusionGemmaText4MLP(
              (gate_proj): Linear(in_features=8, out_features=64, bias=False)
              (up_proj): Linear(in_features=8, out_features=64, bias=False)
              (down_proj): Linear(in_features=64, out_features=8, bias=False)
              (act_fn): GELUTanh()
            )
            (input_layernorm): DiffusionGemmaRMSNorm()
            (post_attention_layernorm): DiffusionGemmaRMSNorm()
            (pre_feedforward_layernorm): DiffusionGemmaRMSNorm()
            (post_feedforward_layernorm): DiffusionGemmaRMSNorm()
            (router): DiffusionGemmaTextRouter(
              (norm): DiffusionGemmaRMSNorm()
              (proj): Linear(in_features=8, out_features=4, bias=False)
            )
            (experts): DiffusionGemmaTextExperts(
              (act_fn): GELUTanh()
            )
            (post_feedforward_layernorm_1): DiffusionGemmaRMSNorm()
            (post_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
            (pre_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
          )
        )
        (norm): DiffusionGemmaRMSNorm()
        (rotary_emb): DiffusionGemmaTextRotaryEmbedding()
      )
      (vision_tower): Gemma4VisionModel(
        (patch_embedder): Gemma4VisionPatchEmbedder(
          (input_proj): Linear(in_features=768, out_features=32, bias=False)
        )
        (encoder): Gemma4VisionEncoder(
          (rotary_emb): Gemma4VisionRotaryEmbedding()
          (layers): ModuleList(
            (0-1): 2 x Gemma4VisionEncoderLayer(
              (self_attn): Gemma4VisionAttention(
                (q_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=32, bias=False)
                )
                (k_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=32, bias=False)
                )
                (v_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=32, bias=False)
                )
                (o_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=32, bias=False)
                )
                (q_norm): Gemma4RMSNorm()
                (k_norm): Gemma4RMSNorm()
                (v_norm): Gemma4RMSNorm()
              )
              (mlp): Gemma4VisionMLP(
                (gate_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=64, bias=False)
                )
                (up_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=32, out_features=64, bias=False)
                )
                (down_proj): Gemma4ClippableLinear(
                  (linear): Linear(in_features=64, out_features=32, bias=False)
                )
                (act_fn): GELUTanh()
              )
              (input_layernorm): Gemma4RMSNorm()
              (post_attention_layernorm): Gemma4RMSNorm()
              (pre_feedforward_layernorm): Gemma4RMSNorm()
              (post_feedforward_layernorm): Gemma4RMSNorm()
            )
          )
        )
        (pooler): Gemma4VisionPooler()
      )
      (embed_vision): DiffusionGemmaMultimodalEmbedder(
        (embedding_projection): Linear(in_features=32, out_features=8, bias=False)
        (embedding_pre_projection_norm): DiffusionGemmaRMSNorm()
      )
    )
    (decoder): DiffusionGemmaDecoderModel(
      (embed_tokens): DiffusionGemmaTextScaledWordEmbedding(262144, 8, padding_idx=0)
      (layers): ModuleList(
        (0): DiffusionGemmaDecoderTextLayer(
          (self_attn): DiffusionGemmaDecoderTextAttention(
            (q_proj): Linear(in_features=8, out_features=128, bias=False)
            (k_proj): Linear(in_features=8, out_features=128, bias=False)
            (v_proj): Linear(in_features=8, out_features=128, bias=False)
            (o_proj): Linear(in_features=128, out_features=8, bias=False)
            (q_norm): DiffusionGemmaRMSNorm()
            (k_norm): DiffusionGemmaRMSNorm()
            (v_norm): DiffusionGemmaRMSNorm()
          )
          (mlp): DiffusionGemmaText4MLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): DiffusionGemmaRMSNorm()
          (post_attention_layernorm): DiffusionGemmaRMSNorm()
          (pre_feedforward_layernorm): DiffusionGemmaRMSNorm()
          (post_feedforward_layernorm): DiffusionGemmaRMSNorm()
          (router): DiffusionGemmaTextRouter(
            (norm): DiffusionGemmaRMSNorm()
            (proj): Linear(in_features=8, out_features=4, bias=False)
          )
          (experts): DiffusionGemmaTextExperts(
            (act_fn): GELUTanh()
          )
          (post_feedforward_layernorm_1): DiffusionGemmaRMSNorm()
          (post_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
          (pre_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
        )
        (1): DiffusionGemmaDecoderTextLayer(
          (self_attn): DiffusionGemmaDecoderTextAttention(
            (q_proj): Linear(in_features=8, out_features=128, bias=False)
            (k_proj): Linear(in_features=8, out_features=64, bias=False)
            (o_proj): Linear(in_features=128, out_features=8, bias=False)
            (q_norm): DiffusionGemmaRMSNorm()
            (k_norm): DiffusionGemmaRMSNorm()
            (v_norm): DiffusionGemmaRMSNorm()
          )
          (mlp): DiffusionGemmaText4MLP(
            (gate_proj): Linear(in_features=8, out_features=64, bias=False)
            (up_proj): Linear(in_features=8, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=8, bias=False)
            (act_fn): GELUTanh()
          )
          (input_layernorm): DiffusionGemmaRMSNorm()
          (post_attention_layernorm): DiffusionGemmaRMSNorm()
          (pre_feedforward_layernorm): DiffusionGemmaRMSNorm()
          (post_feedforward_layernorm): DiffusionGemmaRMSNorm()
          (router): DiffusionGemmaTextRouter(
            (norm): DiffusionGemmaRMSNorm()
            (proj): Linear(in_features=8, out_features=4, bias=False)
          )
          (experts): DiffusionGemmaTextExperts(
            (act_fn): GELUTanh()
          )
          (post_feedforward_layernorm_1): DiffusionGemmaRMSNorm()
          (post_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
          (pre_feedforward_layernorm_2): DiffusionGemmaRMSNorm()
        )
      )
      (norm): DiffusionGemmaRMSNorm()
      (rotary_emb): DiffusionGemmaTextRotaryEmbedding()
      (self_conditioning): DiffusionGemmaSelfConditioning(
        (pre_norm): DiffusionGemmaRMSNorm()
        (post_norm): DiffusionGemmaRMSNorm()
        (gate_proj): Linear(in_features=8, out_features=64, bias=False)
        (up_proj): Linear(in_features=8, out_features=64, bias=False)
        (down_proj): Linear(in_features=64, out_features=8, bias=False)
        (act_fn): GELUTanh()
      )
    )
  )
  (lm_head): Linear(in_features=8, out_features=262144, bias=False)
)

Test environment:

  • huggingface_hub: 1.19.0
  • torch: 2.10.0+cu128
  • transformers: 5.12.0
Downloads last month
-
Safetensors
Model size
2.82M params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for tiny-random/diffusiongemma-26B-A4B-it

Finetuned
(11)
this model