Abnormally high gradient norms observed when training Gemma-4 series models with GRPO
I used verl to train a GEMMA-4 model with the FSDP and VLLM backends, and tested all instruction-based models using the GSM-8K dataset, with results matching expectations. However, when validating the image-to-text task using the Geo-3K dataset, the model training crashed. This occurred despite the fact that the same code worked normally when training the 31B model. Compared to models in the QWEN3 series, we observed abnormally high gradient norms in Gemma4, ranging from tens to hundreds during training. Upon inspecting the code, this appears to be caused by embed scaling. I have modified the dataset and rewards according to the Gemma4 thinking mode and attempted to adjust the learning rate and maximum gradient norm, but the training crashes persist. Below is my training script.
# --- Environment & Ray cluster setup ---
export PYTHONPATH="/root/verl:$PYTHONPATH"
export VLLM_USE_V1=1
export VLLM_ALLREDUCE_USE_SYMM_MEM=0
export SWANLAB_API_KEY=xxxxx
export SWANLAB_MODE=cloud

# Persist the node list and derive two files from entries shaped like
# "ip:slots[,ip:slots...]" (assumed format of $NODE_IP_LIST — confirm).
printf '%s\n' "$NODE_IP_LIST" > env.txt
# hostfile: "ip:8,ip2:8" -> one "ip slots=8" line per node.
sed -e 's/,/\n/g' -e 's/:/ slots=/g' env.txt > hostfile
# pssh.hosts: "ip:8,ip2:8" -> one bare "ip" line per node.
# BUGFIX: the original 's/:.//g' deleted only a single character after each
# colon, so any slot count >= 10 left trailing digits; strip the whole suffix.
sed -e 's/,/\n/g' -e 's/:[^,]*//g' env.txt > pssh.hosts

source ~/.bashrc
bash stop_ray.sh
bash start_ray.sh
########################### Quick Config ###########################
# Tunables — overridable from the calling environment.
GEN_TP=${GEN_TP:-8}              # tensor-parallel size for vLLM rollout
ALL_OFFLOAD=${ALL_OFFLOAD:-True} # FSDP param/optimizer/policy CPU offload
rollout_name="vllm"
project_name='verl_grpo_gemma_4b'
exp_name='gemma_a4b_fsdp_i2t_grpo'

# Output layout: ${OUTPUT_DIR}/<project>/<experiment>/{ckpts,rollout_data,validation_data,logs,swanlog}
OUTPUT_DIR=xxxx
RUN_DIR="${OUTPUT_DIR}/${project_name}/${exp_name}"
CKPTS_DIR="${RUN_DIR}/ckpts"
ROLLOUT_DATA_DIR="${RUN_DIR}/rollout_data"
VALIDATION_DATA_DIR="${RUN_DIR}/validation_data"
LOG_DIR="${RUN_DIR}/logs/$(date +%Y%m%d_%H%M%S)"
SWANLAB_LOG_DIR="${RUN_DIR}/swanlog"
# One quoted mkdir instead of five unquoted ones (paths may contain spaces).
mkdir -p -- "${CKPTS_DIR}" "${ROLLOUT_DATA_DIR}" "${VALIDATION_DATA_DIR}" \
  "${LOG_DIR}" "${SWANLAB_LOG_DIR}"

# Algorithm / batch shape.
adv_estimator=grpo
NNODES=4
max_prompt_length=2048
max_response_length=8192
train_prompt_bsz=64
train_prompt_mini_bsz=${train_prompt_bsz}   # mini-batch == full batch (single PPO update per step)

# Train rollout sampling params.
# NOTE(review): "rollot" is a typo for "rollout"; the names are kept because
# the ROLLOUT parameter array below references them.
train_temperature=1.2
train_rollot_n=8
train_top_p=1.0
train_top_k=-1
# Val rollout params: greedy decoding (do_sample=False, temperature=0), one sample.
val_temperature=0
val_rollot_n=1
val_top_p=1.0
val_top_k=-1
val_do_sample=False

# Model and dataset locations — overridable from the environment.
HF_MODEL_PATH=${HF_MODEL_PATH:-/model/google/gemma-4-26B-A4B-it/hf}
DATASET_DIR=${DATASET_DIR:-/data}
train_path=${train_path:-${DATASET_DIR}/rl/virl39k_gemma4/train.parquet}
test_path=${test_path:-${DATASET_DIR}/rl/virl39k_gemma4/test.parquet}
########################### Parameter Arrays ###########################
# Dataset configuration (Hydra overrides). Expansions are quoted so paths
# containing spaces survive as a single override token.
DATA=(
  "data.train_files=${train_path}"
  "data.val_files=${test_path}"
  "data.train_batch_size=${train_prompt_bsz}"
  "data.max_prompt_length=${max_prompt_length}"
  "data.max_response_length=${max_response_length}"
  data.truncation='error'
  data.filter_overlong_prompts=True
  data.filter_overlong_prompts_workers=8
  # '+' prefix appends a key that is absent from the base config.
  +data.apply_chat_template_kwargs.enable_thinking=True
)
# Model configuration.
MODEL=(
  "actor_rollout_ref.model.path=${HF_MODEL_PATH}"
  actor_rollout_ref.model.trust_remote_code=True
  actor_rollout_ref.model.use_remove_padding=False
  # Force the SDPA attention implementation via an override-config key.
  +actor_rollout_ref.model.override_config.attn_implementation=sdpa
)
# Actor (policy) optimization settings.
ACTOR=(
  actor_rollout_ref.actor.strategy=fsdp
  actor_rollout_ref.actor.optim.lr=1e-6
  "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30720
  actor_rollout_ref.actor.use_dynamic_bsz=False
  # KL applied as a loss term (reward-side KL is disabled in ALGORITHM).
  actor_rollout_ref.actor.use_kl_loss=True
  actor_rollout_ref.actor.kl_loss_coef=0.01
  actor_rollout_ref.actor.kl_loss_type=low_var_kl
  actor_rollout_ref.actor.entropy_coeff=0
  actor_rollout_ref.actor.grad_clip=1.0
  # CPU offload toggles share a single flag.
  "actor_rollout_ref.actor.fsdp_config.param_offload=${ALL_OFFLOAD}"
  "actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ALL_OFFLOAD}"
  "actor_rollout_ref.actor.fsdp_config.offload_policy=${ALL_OFFLOAD}"
  actor_rollout_ref.actor.fsdp_config.reshard_after_forward=True
  actor_rollout_ref.actor.checkpoint.save_contents='["model","optimizer","extra","hf_model"]'
)
# Rollout (generation) engine settings.
ROLLOUT=(
  "actor_rollout_ref.rollout.name=${rollout_name}"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=${GEN_TP}"
  actor_rollout_ref.rollout.gpu_memory_utilization=0.5
  "actor_rollout_ref.rollout.n=${train_rollot_n}"
  actor_rollout_ref.rollout.enforce_eager=False
  "actor_rollout_ref.rollout.temperature=${train_temperature}"
  "actor_rollout_ref.rollout.top_p=${train_top_p}"
  "actor_rollout_ref.rollout.top_k=${train_top_k}"
  actor_rollout_ref.rollout.dtype=bfloat16
  # Validation-time sampling overrides (greedy, single sample).
  "actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature}"
  "actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}"
  "actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k}"
  "actor_rollout_ref.rollout.val_kwargs.do_sample=${val_do_sample}"
  "actor_rollout_ref.rollout.val_kwargs.n=${val_rollot_n}"
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2
  actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=30720
  actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False
  actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=6000
)
# Frozen reference policy (used for KL) settings.
REF=(
  actor_rollout_ref.ref.strategy=fsdp
  "actor_rollout_ref.ref.fsdp_config.param_offload=${ALL_OFFLOAD}"
  "actor_rollout_ref.ref.fsdp_config.offload_policy=${ALL_OFFLOAD}"
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2
  actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=30720
  actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False
)
# Advantage estimation (GRPO); KL handled as a loss term, not in the reward.
ALGORITHM=(
  "algorithm.adv_estimator=${adv_estimator}"
  algorithm.use_kl_in_reward=False
)
# Custom reward function (rule-based scorer for the Geo-3K task).
REWARD=(
  reward.custom_reward_function.path=/root/verl/gemma4_A4B_it/reward_fn.py
  reward.custom_reward_function.name=compute_score
)
# Trainer bookkeeping: logging, checkpointing cadence, cluster shape.
TRAINER=(
  trainer.critic_warmup=0
  trainer.logger='["console","swanlab"]'
  "trainer.project_name=${project_name}"
  "trainer.experiment_name=${exp_name}"
  "trainer.default_local_dir=${CKPTS_DIR}"
  trainer.n_gpus_per_node=8
  "trainer.nnodes=${NNODES}"
  trainer.save_freq=50
  trainer.test_freq=5
  trainer.total_epochs=15
)
# Environment variables propagated to every Ray worker via runtime_env.
RAY_KWARGS=(
+ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_LOG_DIR="${SWANLAB_LOG_DIR}"
+ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_MODE="${SWANLAB_MODE}"
+ray_kwargs.ray_init.runtime_env.env_vars.SWANLAB_API_KEY="${SWANLAB_API_KEY}"
# The nested "'1'" quoting is deliberate: the outer double quotes are removed
# by the shell, so the value reaches Hydra as '1' and is parsed as the string
# "1" rather than the integer 1 (env vars must be strings).
+ray_kwargs.ray_init.runtime_env.env_vars.TORCH_NCCL_ASYNC_ERROR_HANDLING="'1'"
+ray_kwargs.ray_init.runtime_env.env_vars.TORCH_NCCL_BLOCKING_WAIT="'1'"
+ray_kwargs.ray_init.runtime_env.env_vars.TORCH_NCCL_TIMEOUT="'1800'"
+ray_kwargs.ray_init.runtime_env.env_vars.NCCL_ASYNC_ERROR_HANDLING="'1'"
)
########################### Launch ###########################
# Without pipefail the pipeline's status is tee's, so a trainer crash would
# look like success to any caller; enable it and propagate the real status.
set -o pipefail
python3 -m verl.trainer.main_ppo \
  --config-path=config \
  --config-name='ppo_trainer.yaml' \
  "${DATA[@]}" \
  "${ALGORITHM[@]}" \
  "${REWARD[@]}" \
  "${MODEL[@]}" \
  "${ROLLOUT[@]}" \
  "${ACTOR[@]}" \
  "${REF[@]}" \
  "${TRAINER[@]}" \
  "${RAY_KWARGS[@]}" \
  "$@" 2>&1 | tee "${LOG_DIR}/train.log"
train_status=$?
# Always tear down the Ray cluster, then exit with the trainer's status.
bash stop_ray.sh
exit "${train_status}"
You can access the relevant experiment monitoring at https://swanlab.cn/@allenzpma/verl_grpo_gemma_4b/runs
If additional task information is required, please let me know and I will provide it.