# LLaDA

📄 Paper: Large Language Diffusion Models | 💻 Code: github.com/ML-GSAI/LLaDA

Resources and examples for training (finetuning & pretraining) and evaluating the diffusion language model LLaDA.
## Table of Contents

- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)

## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logps`: see *(optional) Slurm setup* for details.
>
> **MoE checkpoints:** For models like `LLaDA-MoE-7B-A1B-Base`, set `"model_type"` to `"lladamoe"` in the checkpoint's `config.json`:
>
> ```diff
> - "model_type": "llada",
> + "model_type": "lladamoe",
> ```
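If you prefer to patch the checkpoint programmatically, here is a minimal sketch (the checkpoint path below is illustrative, assuming a locally downloaded copy):

```python
import json
from pathlib import Path

# Illustrative path to a locally downloaded MoE checkpoint.
config_path = Path("models/LLaDA-MoE-7B-A1B-Base/config.json")

config = json.loads(config_path.read_text())
config["model_type"] = "lladamoe"  # was "llada"
config_path.write_text(json.dumps(config, indent=2))
```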
## Files overview

```
# tools relevant to LLaDA
dllm/pipelines/llada
├── __init__.py                   # Package initialization
├── models/
│   ├── configuration_lladamoe.py # LLaDA-MoE model configuration
│   ├── configuration_llada.py    # LLaDA model configuration
│   ├── modeling_lladamoe.py      # LLaDA-MoE model architecture
│   └── modeling_llada.py         # LLaDA model architecture
├── generator.py                  # Inference logic
└── trainer.py                    # Training logic (pretraining and finetuning)

# example entry points for training / inference / evaluation
examples/llada
├── chat.py                       # Interactive inference example
├── eval.sh                       # Automatic evaluation script
├── generate.py                   # Inference example
├── pt.py                         # Pretraining example
├── README.md                     # Documentation (you are here)
└── sft.py                        # Supervised finetuning example
```
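For a quick sanity check, the checkpoints can also be loaded directly with `transformers` (LLaDA ships custom modeling code on the Hub, hence `trust_remote_code=True`):

```python
from transformers import AutoModel, AutoTokenizer

# LLaDA checkpoints ship custom modeling code, hence trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained("GSAI-ML/LLaDA-8B-Base", trust_remote_code=True)
model = AutoModel.from_pretrained("GSAI-ML/LLaDA-8B-Base", trust_remote_code=True)
```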
## Training

### Finetuning

For example, to SFT LLaDA-8B-Base for instruction following on 8 GPUs, run:
```sh
accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/llada/sft.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
If you are using Slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
```sh
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
### Reproducing LLaDA-8B-Instruct

Though LLaDA was trained on proprietary data, we tried our best to reproduce LLaDA-8B-Instruct by finetuning LLaDA-8B-Base with our training pipeline on the public instruction-following dataset allenai/tulu-3-sft-mixture:
```sh
# preprocess SFT data (optional, but avoids redundant preprocessing for multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "data/sft/llada/tulu-3-sft-mixture" \
    --num_proc 64

# train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "data/sft/llada/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    --output_dir "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
    --max_length 2048 \
    --truncation "right" \
    --group_by_length True \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_on_start False \
    --eval_steps 0.1 \
    --save_steps 0.05
```
### Pretraining

Pretrain on mlfoundations/dclm-baseline-1.0 from scratch using 192 GPUs (24×8) and FSDP:
```sh
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/pt.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --output_dir "models/LLaDA-8B-PT/dclm-baseline-1.0" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 3e-4
```
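For intuition, the paper's pretraining objective is a masked-diffusion loss: sample a masking ratio t ~ U(0, 1], mask each token independently with probability t, and take the cross-entropy over the masked positions weighted by 1/t. A minimal sketch of that loss (illustrative only; the actual implementation lives in `dllm/pipelines/llada/trainer.py`):

```python
import torch
import torch.nn.functional as F

def llada_pretrain_loss(model, input_ids, mask_token_id):
    """Masked-diffusion loss: mask with prob t ~ U(0,1], score masked positions, weight by 1/t."""
    b, l = input_ids.shape
    t = torch.rand(b, 1, device=input_ids.device).clamp_min(1e-3)  # masking ratio per sequence
    is_masked = torch.rand(b, l, device=input_ids.device) < t      # mask each token with prob t
    noisy_ids = torch.where(is_masked, torch.full_like(input_ids, mask_token_id), input_ids)
    logits = model(input_ids=noisy_ids).logits                     # (b, l, vocab)
    ce = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction="none")
    return (ce * is_masked / t).sum() / (b * l)                    # 1/t weighting, length-normalized
```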
## Inference

We support batch inference for standard generation and infilling:

```sh
python examples/llada/generate.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```

We also support interactive multi-turn dialogue with visualization:

```sh
python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```
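Under the hood, LLaDA generates by iterative unmasking: the response starts fully masked, the model predicts every position in parallel at each step, and only the most confident predictions are committed while the rest stay masked for the next step. A minimal sketch of such a loop (illustrative only; the real logic in `dllm/pipelines/llada/generator.py` additionally handles block-wise decoding, remasking strategies, and classifier-free guidance):

```python
import torch

@torch.no_grad()
def unmask_generate(model, prompt_ids, mask_token_id, gen_length=128, steps=128):
    """Iterative unmasking: commit the most confident predictions each step, keep the rest masked."""
    x = torch.cat(
        [prompt_ids, torch.full((1, gen_length), mask_token_id, device=prompt_ids.device)], dim=1
    )
    per_step = max(1, gen_length // steps)
    while (still_masked := (x == mask_token_id)).any():
        logits = model(input_ids=x).logits
        conf, pred = logits.softmax(dim=-1).max(dim=-1)
        conf = conf.masked_fill(~still_masked, -1.0)      # only compete over masked positions
        k = min(per_step, int(still_masked.sum()))
        idx = conf.topk(k, dim=-1).indices                # most confident masked positions
        x.scatter_(1, idx, pred.gather(1, idx))           # commit those predictions
    return x
```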
## Evaluation

Read *(optional) Evaluation setup* before running evaluation.

For example, to evaluate LLaDA-8B-Instruct on MMLU-Pro using 4 GPUs, run:

```sh
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```

Here `steps` is the number of diffusion sampling steps, `block_length` the semi-autoregressive block size, and `cfg` the classifier-free guidance scale.
To automatically evaluate LLaDA-8B-Base and LLaDA-8B-Instruct on all benchmarks, run:

```sh
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False
```
### Evaluation results

Results marked (evaluated) were obtained with our framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the LLaDA repository, with minor adjustments. Placeholder entries ("–") indicate results not yet evaluated; full results will be released soon.
| | MMLU | BBH | ARC‑C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| LLaDA-8B-Base (reported) | 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 |
| LLaDA-8B-Base (evaluated) | 65.8 | – | 45.7 | 69.3 | 45.6 | 70.7 | 70.6 | 70.4 | – | – | 32.3 | 38.8 | 70.2 | 69.9 |

Table 1. Evaluation results of LLaDA-8B-Base.
| | MMLU | MMLU‑Pro | ARC‑C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP |
|---|---|---|---|---|---|---|---|---|---|
| LLaDA-8B-Instruct (reported) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 |
| LLaDA-8B-Instruct (evaluated) | 67.3 | 36.2 | 86.6 | 76.7 | 81.1 | – | – | 65.0 | 70.2 |

Table 2. Evaluation results of LLaDA-8B-Instruct.