It works!

#1
by ktsaou - opened

@Firworks the key problem with this is that it has 103 experts, but the expert count must be a multiple of 32 for NVFP4 (CUTLASS) to work.
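For reference, the padding target just follows from rounding 103 up to the next multiple of 32. A minimal sketch of that arithmetic (same rounding the wrapper below uses):

```python
# NVFP4 CUTLASS grouped-GEMM kernels expect the expert dimension to be a
# multiple of 32; this REAP-pruned model ships with 103 experts.
def align_to_32(n: int) -> int:
    """Round n up to the next multiple of 32."""
    return ((n + 31) // 32) * 32

print(align_to_32(103))  # 128 -> the gate must be padded with 25 dummy experts
print(align_to_32(128))  # 128 -> already aligned, unchanged
```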

I managed to run it (with the help of Claude Code), under vllm 0.14.0rc1.dev140+g87f1b8ca2 with this:

```bash
#!/bin/bash

# Activate vLLM venv
source /opt/vllm/bin/activate

# Enable CUDA toolkit (nvcc required for FlashInfer detection)
export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

# Set HuggingFace cache to use existing models
export HF_HOME=/opt/models/huggingface

export VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1

export VLLM_USE_FLASHINFER_MOE_FP4=1
export ENABLE_NVFP4_SM120=1
export VLLM_WORKER_MULTIPROC_METHOD=fork  # mandatory for our patch to work
export SAFETENSORS_FAST_GPU=1

# Using 103-expert fix wrapper (pads gate layer 103→128 for CUTLASS FP4 alignment)
exec env CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} \
  python /opt/vllm/scripts/Firworks--Qwen3-Coder-REAP-25B-A3B-nvfp4-wrapper.py \
  serve Firworks/Qwen3-Coder-REAP-25B-A3B-nvfp4 \
  --host 0.0.0.0 \
  --port 8350 \
  --served-model-name qwen3-coder-25b-a3b-instruct \
  --gpu-memory-utilization 0.765 \
  --tool-call-parser qwen3_coder \
  --enable-auto-tool-choice \
  --max-model-len 196608 \
  --kv-cache-dtype fp8_e4m3 \
  --attention-backend FLASHINFER \
  --dtype auto \
  --trust-remote-code \
  --max-num-seqs 6 \
  --max-num-batched-tokens 8192 \
  --enable-prefix-caching \
  --enable-chunked-prefill \
  --enable-expert-parallel
```
Note that I didn't run vllm directly. Instead, I called a custom wrapper that expands the 103 experts to 128 (with zero weights) 😀

```python
#!/usr/bin/env python3
"""
vLLM wrapper for Firworks/Qwen3-Coder-REAP-25B-A3B-nvfp4
Pads the gate layer from 103 to 128 experts for CUTLASS FP4 alignment.
"""
import functools

import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models import qwen3_moe


def align(n: int) -> int:
    """Round n up to the next multiple of 32."""
    return ((n + 31) // 32) * 32


# Re-create the gate with a padded output dimension whenever the expert
# count is not a multiple of 32.
orig_init = qwen3_moe.Qwen3MoeSparseMoeBlock.__init__

@functools.wraps(orig_init)
def patched_init(self, vllm_config: VllmConfig, prefix: str = ""):
    config = vllm_config.model_config.hf_text_config
    self._original_num_experts = config.num_experts
    self._needs_gate_fix = config.num_experts % 32 != 0
    orig_init(self, vllm_config, prefix)
    if self._needs_gate_fix:
        self.gate = ReplicatedLinear(
            config.hidden_size, align(config.num_experts), bias=False,
            quant_config=vllm_config.quant_config, prefix=f"{prefix}.gate")
        self.gate._needs_weight_padding = True

qwen3_moe.Qwen3MoeSparseMoeBlock.__init__ = patched_init


# Mask out and drop the dummy experts' logits before routing.
orig_fwd = qwen3_moe.Qwen3MoeSparseMoeBlock.forward

@functools.wraps(orig_fwd)
def patched_forward(self, hidden_states):
    if not getattr(self, '_needs_gate_fix', False):
        return orig_fwd(self, hidden_states)
    is_1d = hidden_states.dim() == 1
    if is_1d:
        hidden_states = hidden_states.unsqueeze(0)
    hidden_states = hidden_states.view(-1, hidden_states.size(-1))
    logits, _ = self.gate(hidden_states)
    if logits.size(-1) > self._original_num_experts:
        logits = logits.clone()
        logits[:, self._original_num_experts:] = float('-inf')
        logits = logits[:, :self._original_num_experts]
    out = self.experts(hidden_states=hidden_states, router_logits=logits)
    return out.squeeze(0) if is_1d else out

qwen3_moe.Qwen3MoeSparseMoeBlock.forward = patched_forward


# Zero-pad the checkpoint's 103-row gate weight up to the padded shape.
orig_loader = ReplicatedLinear.weight_loader

def patched_loader(self, param, loaded_weight):
    if getattr(self, '_needs_weight_padding', False) and param.size() != loaded_weight.size():
        if param.dim() == 2 and loaded_weight.dim() == 2 and param.size(1) == loaded_weight.size(1):
            if param.size(0) == align(loaded_weight.size(0)):
                param.data.copy_(F.pad(loaded_weight, (0, 0, 0, param.size(0) - loaded_weight.size(0))))
                return
    return orig_loader(self, param, loaded_weight)

ReplicatedLinear.weight_loader = patched_loader

print("[103-expert-fix] Patched qwen3_moe (103→128)")

if __name__ == '__main__':
    from vllm.scripts import main
    main()
```
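A minimal illustration (plain Python, not vLLM code) of why the forward patch has to mask and slice the padded logits rather than just leaving them in: zero-padded gate rows score exactly 0.0 for every dummy expert, which a top-k router could still select over a real expert with a negative logit.

```python
# Toy gate: dot product of one hidden vector with each expert's gate row.
def gate_logits(hidden, gate_rows):
    return [sum(h * w for h, w in zip(hidden, row)) for row in gate_rows]

hidden = [0.5, -1.0, 2.0]
real_rows = [[-0.2, 0.1, 0.3], [0.4, 0.2, -0.1]]  # 2 "real" experts
padded_rows = real_rows + [[0.0, 0.0, 0.0]] * 2   # padded to 4 with zero rows

logits = gate_logits(hidden, padded_rows)
# Dummy experts score exactly 0.0, so they would outrank expert 1 (logit -0.2)
# in a top-k selection. The wrapper therefore slices back to the real experts:
masked = logits[:len(real_rows)]
```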

I tested the model for agentic use. It has a few problems:

  1. Sometimes tool calls are emitted as plain text in the model's output instead of being parsed as tool calls.
  2. Sometimes the model enters repetitive text degeneration.

Overall, very fast. It peaks at about 150 tokens per second (single request) on an RTX 5090.
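For anyone reproducing this, a sketch of a smoke-test request against the server started by the script above (host, port, and served model name are taken from the launch flags; the `frequency_penalty` value is only a guess at mitigating the repetition issue, not a tested fix):

```python
import json
import urllib.request

# Payload for the OpenAI-compatible endpoint vLLM serves.
payload = {
    "model": "qwen3-coder-25b-a3b-instruct",  # matches --served-model-name
    "messages": [{"role": "user", "content": "Write hello-world in C."}],
    "max_tokens": 256,
    "frequency_penalty": 0.3,  # assumption: may reduce degeneration loops
}

req = urllib.request.Request(
    "http://localhost:8350/v1/chat/completions",  # matches --port 8350
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
# Uncomment once the server is up:
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["choices"][0]["message"]["content"])
```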

Wow, nice! I wonder how much of those performance issues come from the base REAP model vs. the quantization. I'll update the model card to point to this discussion until vLLM is updated to support it without patching.
