
# LLaDA
> 📄 Paper: [Large Language Diffusion Models](https://arxiv.org/abs/2502.09992) | 💻 Code: [github.com/ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA)
Resources and examples for training (finetuning & pretraining) and evaluating the diffusion language model **LLaDA**.
## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logps`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
>
> **MoE checkpoints:** For models like [`LLaDA-MoE-7B-A1B-Base`](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base), set `"model_type"` to `"lladamoe"` in the checkpoint's `config.json`:
> ```diff
> - "model_type": "llada",
> + "model_type": "lladamoe",
> ```
>
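If you have a local copy of the checkpoint, a one-line sketch of this edit (the checkpoint path below is illustrative):
```shell
# Illustrative: point this at your local copy of the MoE checkpoint.
sed -i 's/"model_type": "llada"/"model_type": "lladamoe"/' \
    path/to/LLaDA-MoE-7B-A1B-Base/config.json
```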
## Files overview
```
# tools relevant to LLaDA
dllm/pipelines/llada
├── __init__.py                    # Package initialization
├── models/
│   ├── configuration_lladamoe.py  # LLaDA-MoE model configuration
│   ├── configuration_llada.py     # LLaDA model configuration
│   ├── modeling_lladamoe.py       # LLaDA-MoE model architecture
│   └── modeling_llada.py          # LLaDA model architecture
├── generator.py                   # Inference logic
└── trainer.py                     # Training logic (pretraining and finetuning)

# example entry points for training / inference / evaluation
examples/llada
├── chat.py                        # Interactive inference example
├── eval.sh                        # Automatic evaluation script
├── generate.py                    # Inference example
├── pt.py                          # Pretraining example
├── README.md                      # Documentation (you are here)
└── sft.py                         # Supervised finetuning example
```
<!-- > [!NOTE] -->
<!-- > - We fixed attention mask bugs in [`modeling_lladamoe.py`](/dllm/pipelines/llada/models/modeling_lladamoe.py) and [`modeling_llada.py`](/dllm/pipelines/llada/models/modeling_llada.py). We recommend loading models with `dllm.utils.get_tokenizer`; otherwise, `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> - We fixed bugs in `chat_template` and assign `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, remember to set `chat_template` and `mask_token` appropriately yourself. -->
## Training
### Finetuning
For example, to SFT [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/llada/sft.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
If you are using Slurm and want to train across multiple nodes (for example, 2 nodes with 16 GPUs in total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
#### Reproducing [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)
Although LLaDA was trained on proprietary data, we did our best to reproduce [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) by finetuning [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) with our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
```shell
# Preprocess the SFT data (optional, but avoids redundant preprocessing in multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "data/sft/llada/tulu-3-sft-mixture" \
    --num_proc 64

# Train on 24*8 = 192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "data/sft/llada/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    --output_dir "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
    --max_length 2048 \
    --truncation "right" \
    --group_by_length True \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_on_start False \
    --eval_steps 0.1 \
    --save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
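Once training finishes, the reproduced checkpoint can be evaluated with the same script used in the [Evaluation](#evaluation) section. A sketch, assuming `eval.sh` accepts a local checkpoint directory in place of a Hugging Face model id:
```shell
# Sketch: evaluate the reproduced SFT checkpoint; the path matches the
# --output_dir used above.
bash examples/llada/eval.sh \
    --model_name_or_path "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
    --instruct True
```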
### Pretraining
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/pt.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --output_dir "models/LLaDA-8B-PT/dclm-baseline-1.0" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 3e-4
```
## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/llada/generate.py`](/examples/llada/generate.py) for a full example: -->
```shell
python examples/llada/generate.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```
We also support interactive multi-turn dialogue with visualization:
<!-- See [`examples/llada/chat.py`](/examples/llada/chat.py) for a full example. -->
```shell
python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```
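Both example scripts handle model loading for you. If you want to load a checkpoint in your own code, a minimal sketch (assuming the `dllm` package is importable, e.g. from the repo root, and that importing it first picks up the model classes this repo provides):
```shell
# Minimal loading sketch (assumptions: dllm is importable; Hub access available).
python - <<'PY'
import dllm  # assumed to register this repo's LLaDA model classes
from transformers import AutoModel, AutoTokenizer

name = "GSAI-ML/LLaDA-8B-Instruct"
model = AutoModel.from_pretrained(name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
print(type(model).__name__)
PY
```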
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [MMLU-Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```
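The `model_args` string controls decoding. For instance, lowering `steps` and `block_length` speeds up evaluation, typically at some cost in accuracy (the values below are illustrative, not tuned):
```shell
# Illustrative faster setting: fewer diffusion steps and shorter blocks.
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=128,block_length=128,cfg=0.0"
```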
To automatically evaluate [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) and [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on all benchmarks, run:
```shell
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False
```
### Evaluation results
> "Evaluated" results are obtained with our evaluation framework; "reported" results come from the original paper. All evaluation settings follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository, with minor adjustments. Entries marked "–" have not been evaluated yet; complete results will be released soon.
<!-- <div align="center" style="min-width:1300px;"> -->
| | MMLU | BBH | ARC&#8209;C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU |
|:----------------|:----:|:---:|:-----:|:-----------:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:------:|
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) (reported) | 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 |
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) (evaluated) | 65.8 | – | 45.7 | 69.3 | 45.6 | 70.7 | 70.6 | 70.4 | – | – | 32.3 | 38.8 | 70.2 | 69.9 |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Base" style="color: #808080; text-decoration: none;">
<code>LLaDA-8B-Base</code>
</a>.
</p>
| | MMLU | MMLU&#8209;Pro | ARC&#8209;C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP |
|:----------------|:----:|:---------:|:-----:|:-----------:|:-----:|:----:|:----:|:-----------:|:----:|
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) (reported) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 |
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) (evaluated) | 67.3 | 36.2 | 86.6 | 76.7 | 81.1 | – | – | 65.0 | 70.2 |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 2. Evaluation results of
<a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct" style="color: #808080; text-decoration: none;">
<code>LLaDA-8B-Instruct</code>
</a>.
</p>