1127 update to latest

dllm/examples/llada/README.md (new file, 206 lines)
@@ -0,0 +1,206 @@

# LLaDA

> 📄 Paper: [Large Language Diffusion Models](https://arxiv.org/abs/2502.09992) | 💻 Code: [github.com/ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA)

Resources and examples for training (finetuning & pretraining) and evaluating the diffusion language model **LLaDA**.

## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)

## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
>
> **MoE checkpoints:** For models like [`LLaDA-MoE-7B-A1B-Base`](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base), set `"model_type"` to `"lladamoe"` in the checkpoint’s `config.json`:
> ```diff
> - "model_type": "llada",
> + "model_type": "lladamoe",
> ```
>
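
If you prefer to patch the file programmatically, a minimal sketch follows; it assumes the checkpoint has already been downloaded to a local directory (the path below is only an example, not a path shipped with this repo):

```python
import json
import pathlib

# Hypothetical local path to the downloaded MoE checkpoint directory.
ckpt_dir = pathlib.Path("checkpoints/LLaDA-MoE-7B-A1B-Base")
cfg_path = ckpt_dir / "config.json"

cfg = json.loads(cfg_path.read_text())
cfg["model_type"] = "lladamoe"  # was "llada"
cfg_path.write_text(json.dumps(cfg, indent=2))
```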

## Files overview
```
# tools related to LLaDA
dllm/pipelines/llada
├── __init__.py                    # Package initialization
├── models/
│   ├── configuration_lladamoe.py  # LLaDA-MoE model configuration
│   ├── configuration_llada.py     # LLaDA model configuration
│   ├── modeling_lladamoe.py       # LLaDA-MoE model architecture
│   └── modeling_llada.py          # LLaDA model architecture
├── generator.py                   # Inference logic
└── trainer.py                     # Training logic (pretraining and finetuning)

# example entry points for training / inference / evaluation
examples/llada
├── chat.py                        # Interactive inference example
├── eval.sh                        # Automatic evaluation script
├── generate.py                    # Inference example
├── pt.py                          # Pretraining example
├── README.md                      # Documentation (you are here)
└── sft.py                         # Supervised finetuning example
```
<!-- > [!NOTE] -->
<!-- > - We fixed attention mask bugs in [`modeling_lladamoe.py`](/dllm/pipelines/llada/models/modeling_lladamoe.py) and [`modeling_llada.py`](/dllm/pipelines/llada/models/modeling_llada.py). We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> - We fixed bugs in `chat_template` and assign `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->

<!-- > [!WARNING]
> Before loading MoE checkpoints (e.g., [inclusionAI/LLaDA-MoE-7B-A1B-Base](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base)), first overwrite the `model_type` field from `inclusionAI/LLaDA-MoE-7B-A1B-Base/config.json`:
> ```diff
> - "model_type": "llada",
> + "model_type": "lladamoe",
> ``` -->

## Training
### Finetuning

For example, to SFT [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/llada/sft.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
If you are using Slurm and want to train across, for example, 2 nodes (16 GPUs in total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```

<!-- **Reproducing [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)**. Though LLaDA is trained on proprietary data, we tried our best to reproduce LLaDA-8B-Instruct by finetuning LLaDA-8B-Base using our training pipeline on public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->

#### Reproducing [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)
Though LLaDA is trained on proprietary data, we tried our best to reproduce [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) by finetuning [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) with our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):

```shell
# preprocess SFT data (optional, but avoids redundant preprocessing for multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "data/sft/llada/tulu-3-sft-mixture" \
    --num_proc 64

# train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/sft.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "data/sft/llada/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    --output_dir "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
    --max_length 2048 \
    --truncation "right" \
    --group_by_length True \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --eval_on_start False \
    --eval_steps 0.1 \
    --save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->

### Pretraining

Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/llada/pt.py" \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --output_dir "models/LLaDA-8B-PT/dclm-baseline-1.0" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 3e-4
```

## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/llada/generate.py`](/examples/llada/generate.py) for a full example: -->
```shell
python examples/llada/generate.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```
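
For programmatic use, the sketch below condenses the generation path of [`examples/llada/generate.py`](/examples/llada/generate.py) (the full script, included in this commit, also covers infilling and visualization). The minimal `ScriptArguments` here is an assumption: the real script defines extra fields (`seed`, `visualize`) and resolves the model path against `BASE_MODELS_DIR`, and `dllm.utils.get_model` may expect those fields as well.

```python
from dataclasses import dataclass

import dllm
from dllm.pipelines import llada
from dllm.tools.chat import decode_trim


@dataclass
class ScriptArguments:
    # Assumed to be sufficient for dllm.utils.get_model / get_tokenizer.
    model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"


@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
    # Same override pattern as generate.py in this commit.
    steps: int = 128
    max_new_tokens: int = 128
    block_length: int = 32


script_args = ScriptArguments()
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

# Batch of chat-formatted prompts, tokenized with the chat template.
messages = [[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}]]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

outputs = generator.generate(inputs, GeneratorConfig(), return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```
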
We also support interactive multi-turn dialogue with visualization:
<!-- See [`examples/llada/chat.py`](/examples/llada/chat.py) for a full example. -->
```shell
python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
```

## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [MMLU-Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/llada/eval.py \
    --tasks "mmlu_pro" \
    --model "llada" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```

To automatically evaluate [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) and [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on all benchmarks, run:
```shell
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False
```

### Evaluation results

<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results refer to those from the original paper.
> All evaluation parameters follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository.
> Since the original evaluations were conducted with OpenCompass, task settings were adjusted for compatibility with the LLaDA model under `lm-eval`.
> Complete evaluation results will be released soon. -->

> Evaluated results are obtained with our framework, while reported results come from the original paper. All evaluation settings follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.

<!-- <div align="center" style="min-width:1300px;"> -->

| | MMLU | BBH | ARC‑C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU |
|:----------------|:----:|:---:|:-----:|:---------:|:----------:|:----------:|:----:|:-----:|:----:|:----:|:---------:|:----:|:-----:|:-----:|
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) (reported) | 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 |
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) (evaluated) | 65.8 | – | 45.7 | 69.3 | 45.6 | 70.7 | 70.6 | 70.4 | – | – | 32.3 | 38.8 | 70.2 | 69.9 |

<p align="center" style="color: #808080; font-size: 0.9em;">
  Table 1. Evaluation results of
  <a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Base" style="color: #808080; text-decoration: none;">
    <code>LLaDA-8B-Base</code>
  </a>.
</p>

| | MMLU | MMLU‑Pro | ARC‑C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP |
|:----------------|:----:|:--------:|:-----:|:---------:|:-----:|:----:|:----:|:---------:|:----:|
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) (reported) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 |
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) (evaluated) | 67.3 | 36.2 | 86.6 | 76.7 | 81.1 | – | – | 65.0 | 70.2 |

<p align="center" style="color: #808080; font-size: 0.9em;">
  Table 2. Evaluation results of
  <a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct" style="color: #808080; text-decoration: none;">
    <code>LLaDA-8B-Instruct</code>
  </a>.
</p>

dllm/examples/llada/chat.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
|
||||
Interactive chat / generation script for LLaDA models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Chat mode (multi-turn, chat template)
|
||||
python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
|
||||
# Raw single-turn generation
|
||||
python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import llada
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||

dllm/examples/llada/eval.sh (new file, 151 lines)
@@ -0,0 +1,151 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1                # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True  # For cmmlu dataset


# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0                  # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1         # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn                     # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL      # Provide detailed logging for PyTorch distributed debugging


# ===== Input Arguments =====
model_name_or_path="GSAI-ML/LLaDA-8B-Instruct"
instruct=True
num_gpu=4
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model_name_or_path)
            model_name_or_path="$2"; shift 2 ;;
        --instruct)
            instruct="$2"; shift 2 ;;
        --num_gpu)
            num_gpu="$2"; shift 2 ;;
    esac
done

# ===== Conditional Configurations =====
if [ "$instruct" = "True" ]; then
    echo ">>> Running in INSTRUCT mode"
    common_args="--model llada --apply_chat_template"
else
    echo ">>> Running in BASE mode"
    common_args="--model llada"
fi


# =======================
# Generation Tasks
# =======================

if [ "$instruct" = "True" ]; then
    # Instruct Generation Tasks
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks gsm8k_cot --num_fewshot 8 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks bbh --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks minerva_math --num_fewshot 4 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks humaneval_instruct --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks mbpp_llada_instruct --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

else
    # Base Generation Tasks
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks gsm8k --num_fewshot 8 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks bbh --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks minerva_math --num_fewshot 4 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks humaneval --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks mbpp --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
fi


# =======================
# Likelihood Tasks
# =======================

if [ "$instruct" = "True" ]; then
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks mmlu_generative --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks mmlu_pro --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks hellaswag_gen --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=5,steps=5,block_length=5,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks gpqa_n_shot_gen --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=32,steps=32,block_length=32,cfg=0.0"

else
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks truthfulqa_mc2 --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=2.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks arc_challenge --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks hellaswag --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks winogrande --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks piqa --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks mmlu --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks cmmlu --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
        --tasks ceval-valid --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"
fi

dllm/examples/llada/generate.py (new file, 116 lines)
@@ -0,0 +1,116 @@
"""
|
||||
python -u examples/llada/generate.py --model_name_or_path "YOUR_MODEL_PATH"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.tools.chat import decode_trim
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
seed: int = 42
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
# Load model & tokenizer
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# --- Example 1: Batch generation ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: llada.generate()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, s in enumerate(sequences):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print(s.strip() if s.strip() else "<empty>")
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
|
||||
# --- Example 2: Batch fill-in-the-blanks ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: llada.infilling()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
masked_messages = [
|
||||
[
|
||||
{"role": "user", "content": tokenizer.mask_token * 20},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Sorry, I do not have answer to this question.",
|
||||
},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Please write an educational python function."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "def hello_" + tokenizer.mask_token * 20 + " return",
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
masked_messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, (i, s) in enumerate(zip(inputs, sequences)):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print("[Masked]:\n" + tokenizer.decode(i))
|
||||
print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||

dllm/examples/llada/pt.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/llada/pt.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/llada/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/pt.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
# Uses only the configuration from model_name_or_path to initialize the model from scratch
|
||||
model_name_or_path: str = (
|
||||
"GSAI-ML/LLaDA-8B-Base" # "inclusionAI/LLaDA-MoE-7B-A1B-Base"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
text_field: str = "text"
|
||||
streaming: bool = True
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
random_length_ratio: float = field(
|
||||
default=0.01,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cut sequences during training. "
|
||||
"See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training for reference."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/LLaDA-8B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
learning_rate: float = 3e-4
|
||||
max_steps: int = 2_000
|
||||
per_device_train_batch_size: int = 4
|
||||
gradient_accumulation_steps: int = 4
|
||||
eval_steps: float = 0.05
|
||||
save_steps: float = 0.05
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary for streaming dataset
|
||||
if data_args.streaming:
|
||||
training_args.accelerator_config.dispatch_batches = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
# initialize model weights from scratch
|
||||
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(
|
||||
config, dtype=torch.bfloat16, init_params=True
|
||||
)
|
||||
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optional PEFT: LoRA ----------------------------------------------------
|
||||
model = dllm.utils.load_peft(model=model, model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
@dataclass
|
||||
class LLaDAPTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
# Reference: https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training
|
||||
# By default, 1% of the pre-training data are truncated to a random length
|
||||
random_length_ratio: float = 0.01
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
if torch.rand(1) < self.random_length_ratio:
|
||||
random_length = torch.randint(
|
||||
1, outputs["input_ids"].shape[1] + 1, (1,)
|
||||
)
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in outputs:
|
||||
outputs[key] = outputs[key][:, :random_length]
|
||||
# Check if attention_mask is all ones and set it to None
|
||||
if torch.all(outputs["attention_mask"] == 1):
|
||||
outputs.pop("attention_mask")
|
||||
return outputs
|
||||
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=LLaDAPTCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
random_length_ratio=data_args.random_length_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||

dllm/examples/llada/sft.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/llada/sft.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/llada/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/LLaDA-8B-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
group_by_length: bool = True
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
dllm.utils.default_sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=dllm.utils.NoAttentionMaskCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||