1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions


@@ -0,0 +1,187 @@
# Dream
> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)

Resources and examples for training (finetuning and pretraining) and evaluating the **Dream** diffusion language models.
## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
>
## Files overview
```
# tools related to Dream
dllm/pipelines/dream
├── __init__.py # Package initialization
├── models/
│ ├── configuration_dream.py # Dream model configuration
│ ├── generation_utils.py # Diffusion-based generation logic
│ ├── modeling_dream.py # Core Dream model architecture
│ └── tokenization_dream.py # Tokenizer implementation for Dream
├── generator.py # Inference logic
├── trainer.py # Training logic (pretraining and SFT)
└── utils.py # Auxiliary utilities and helper functions
# example entry points for training / inference / evaluation
examples/dream
├── chat.py # Interactive inference example
├── eval.sh # Automatic evaluation script
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
└── sft.py # Supervised finetuning example
```
<!-- > [!NOTE]
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> We fixed bugs in `chat_template` and standardize `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
## Training
### Finetuning
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/sft.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
If you are using Slurm and want to train across multiple nodes, for example 2 nodes (16 GPUs total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
```shell
# Preprocess the SFT data (optional, but avoids redundant preprocessing in multi-node training)
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
# Train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "data/sft/dream/tulu-3-sft-mixture" \
--load_preprocessed_data True \
--output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
--max_length 2048 \
--truncation "right" \
--group_by_length True \
--num_train_epochs 5 \
--learning_rate 1e-5 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 2 \
--eval_on_start False \
--eval_steps 0.1 \
--save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
### Pretraining
For example, to pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) with FSDP, run:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/pt.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "mlfoundations/dclm-baseline-1.0" \
--output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
--max_length 1024 \
--max_steps 2000 \
--learning_rate 3e-4
```
## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
```shell
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
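For programmatic use, the snippet below is a condensed sketch of [`examples/dream/generate.py`](/examples/dream/generate.py); see that script for the full version, including batch infilling and terminal visualization. The generation arguments mirror its defaults.
```python
# Condensed sketch of examples/dream/generate.py; refer to that script for the
# complete version (batch generation, infilling, and visualization).
from dataclasses import dataclass

import transformers

import dllm
from dllm.pipelines import dream
from dllm.tools.chat import decode_trim


@dataclass
class ScriptArguments:
    model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
    seed: int = 42


script_args = ScriptArguments()
transformers.set_seed(script_args.seed)

# Load model & tokenizer, then wrap them in the diffusion generator
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)

gen_config = dream.DreamGeneratorConfig(
    steps=128, max_new_tokens=128, temperature=0.2, top_p=0.95, alg="entropy"
)

messages = [[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}]]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```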
We also support interactive multi-turn dialogue with visualization:
```shell
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
dllm/pipelines/dream/eval.py \
--tasks "mmlu_pro" \
--model "dream" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
```
To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
```shell
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
```
### Evaluation results
> Results marked (evaluated) were obtained with our evaluation framework, while results marked (reported) are taken from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Blank entries indicate results not yet evaluated; full results will be released soon.

| | MMLU | BBH | ARC&#8209;C | ARC&#8209;E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip&nbsp;planning |
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | | | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | | 35.5 | 45.8 | | 43.0 | | | |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;">
<code>Dream-7B-Base</code>
</a>.
</p>

| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) (reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) (evaluated) | | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | | 62.3 |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 2. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;">
<code>Dream-7B-Instruct</code>
</a>.
</p>


@@ -0,0 +1,75 @@
"""
Interactive chat / generation script for Dream models.
Examples
--------
# Chat mode (multi-turn, chat template)
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
# Raw single-turn generation
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
"""
import sys
from dataclasses import dataclass
import transformers
import dllm
from dllm.pipelines import dream
from dllm.tools.chat import multi_turn_chat, single_turn_generate
@dataclass
class ScriptArguments:
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
seed: int = 42
chat: bool = True
visualize: bool = True
def __post_init__(self):
# same base-path resolution logic as in generate.py
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
temperature: float = 0.2
top_p: float = 0.95
alg: str = "entropy"
alg_temp: float = 0.0
def main():
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
if script_args.chat:
multi_turn_chat(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
else:
print("\nSingle-turn generation (no chat template).")
single_turn_generate(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nInterrupted. Bye!")
sys.exit(0)

dllm/examples/dream/eval.sh Normal file

@@ -0,0 +1,139 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
# ===== Input Arguments =====
model_name_or_path="Dream-org/Dream-v0-Instruct-7B"
instruct=True
num_gpu=4
while [[ $# -gt 0 ]]; do
case "$1" in
--model_name_or_path)
model_name_or_path="$2"; shift 2 ;;
--instruct)
instruct="$2"; shift 2 ;;
--num_gpu)
num_gpu="$2"; shift 2 ;;
esac
done
# ===== Conditional Configurations =====
if [ "$instruct" = "True" ]; then
echo ">>> Running in INSTRUCT mode"
common_args="--model dream --apply_chat_template"
else
echo ">>> Running in BASE mode"
common_args="--model dream"
fi
# =======================
# Generation / Instruct Tasks
# =======================
if [ "$instruct" = "True" ]; then
# Instruct Tasks
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu_generative --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu_pro --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gsm8k_cot --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=256,max_length=256,steps=256,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks minerva_math --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.0,top_p=1.0,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks humaneval_instruct_dream --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=768,max_length=768,steps=768,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mbpp_instruct --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1024,max_length=1024,steps=1024,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks ifeval --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1280,max_length=1280,steps=1280,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
else
# Base Generation Tasks
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks humaneval --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gsm8k_cot --num_fewshot 8 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=256,steps=256,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mbpp --num_fewshot 3 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks minerva_math --num_fewshot 4 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks bbh --num_fewshot 3 ${common_args} \
--model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
fi
# =======================
# Likelihood Tasks (Base Only)
# =======================
if [ "$instruct" != "True" ]; then
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks mmlu --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks arc_easy --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks arc_challenge --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks hellaswag --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks piqa --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks winogrande --num_fewshot 5 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
--tasks race --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},add_bos_token=true"
fi


@@ -0,0 +1,117 @@
"""
python -u examples/dream/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""
from dataclasses import dataclass
import transformers
import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import dream
@dataclass
class ScriptArguments:
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
seed: int = 42
visualize: bool = True
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
temperature: float = 0.2
top_p: float = 0.95
alg: str = "entropy"
alg_temp: float = 0.0
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: dream.generate()".center(80))
print("=" * 80)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for idx, s in enumerate(sequences):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
print("-" * 80)
print(s.strip() if s.strip() else "<empty>")
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)
# --- Example 2: Batch fill-in-the-blanks ---
print("\n" + "=" * 80)
print("TEST: dream.infilling()".center(80))
print("=" * 80)
masked_messages = [
[
{"role": "user", "content": tokenizer.mask_token * 20},
{
"role": "assistant",
"content": "Sorry, I do not have answer to this question.",
},
],
[
{"role": "user", "content": "Please write an educational python function."},
{
"role": "assistant",
"content": "def hello_" + tokenizer.mask_token * 20 + " return",
},
],
]
inputs = tokenizer.apply_chat_template(
masked_messages,
add_generation_prompt=False,
tokenize=True,
)
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for idx, (i, s) in enumerate(zip(inputs, sequences)):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
print("-" * 80)
print("[Masked]:\n" + tokenizer.decode(i))
print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)

dllm/examples/dream/pt.py Normal file

@@ -0,0 +1,162 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/dream/pt.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/pt.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 24 Nodes, 192 GPUs (FSDP):
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import torch
import transformers
import accelerate
import dllm
from dllm.pipelines import dream
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
text_field: str = "text"
streaming: bool = True
drop_tail: bool = True
insert_eos: bool = field(
default=True,
metadata={
"help": "False when adjacent samples from the datasets are semantically coherent."
},
)
random_length_ratio: float = field(
default=0.01,
metadata={
"help": (
"The probability of randomly cut sequences during training. "
"See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md."
)
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = (
"models/Dream-7B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
learning_rate: float = 3e-4
max_steps: int = 2_000
per_device_train_batch_size: int = 4
gradient_accumulation_steps: int = 4
eval_steps: float = 0.05
save_steps: float = 0.05
# Dream PT specific args
    # Note: Since Dream's pretraining recipe is not public,
    # this is only a reference implementation following LLaDA's data processing approach.
loss_weight_type: str = field(
default="cart[geo_p:0.3]",
metadata={
"help": (
"The loss weight type. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
def train():
# ----- Parse & setup --------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary for streaming dataset
if data_args.streaming:
training_args.accelerator_config.dispatch_batches = False
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ---------------------------------------------------------------
# initialize model weights from scratch
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
with dllm.utils.init_device_context_manager():
model = transformers.AutoModel.from_config(config, dtype=torch.bfloat16)
# ----- Tokenizer -----------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Optional PEFT: LoRA -------------------------------------------------
model = dllm.utils.load_peft(model=model, model_args=model_args)
# ----- Dataset -------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_pt_dataset(
data_args.dataset_args,
streaming=data_args.streaming,
)
dataset = dataset.map(
functools.partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=data_args.text_field,
seq_length=data_args.max_length,
insert_eos=data_args.insert_eos,
drop_tail=data_args.drop_tail,
),
batched=True,
remove_columns=dataset["train"].column_names,
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
)
if data_args.streaming:
dataset = dataset.shuffle(seed=training_args.seed)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dream.DreamTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
loss_weight_type=training_args.loss_weight_type,
data_collator=dream.utils.DreamPTCollator(
tokenizer,
return_tensors="pt",
padding=True,
random_length_ratio=data_args.random_length_ratio,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

dllm/examples/dream/sft.py Normal file

@@ -0,0 +1,192 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/dream/sft.py \
--load_in_4bit True --lora True
- 8 GPUs (FSDP):
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/sft.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py"
- 2 Nodes, 16 GPUs (FSDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py"
"""
import os
from dataclasses import dataclass, field
from functools import partial
import transformers
import accelerate
import dllm
from dllm.pipelines import dream
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
load_preprocessed_data: bool = False
mask_prompt_loss: bool = field(
default=True,
metadata={"help": "Whether to mask the loss on the prompt tokens"},
)
# Dream SFT specific args
perbatch_cutoff: bool = field(
default=True,
metadata={
"help": (
"Randomly pick a response length from batch and trim other responses. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
resp_cutoff_ratio: float = field(
default=0.0,
metadata={
"help": (
"The probability of randomly cutting sequences during training. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/Dream-7B-SFT"
group_by_length: bool = True
# Dream SFT specific args
loss_weight_type: str = field(
default="cart[geo_p:0.3]",
metadata={
"help": (
"The loss weight type. "
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
)
},
)
# ------------------------------------------------------------------------------
# SFT mapping function
# ------------------------------------------------------------------------------
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict:
"""
Build Dream SFT features from a chat-format row.
Returns:
        dict with input_ids, labels, and prompt_len
"""
prompt_tokens = tokenizer.apply_chat_template(
row["messages"][:-1], tokenize=True, add_generation_prompt=True
)
prompt_response_tokens = tokenizer.apply_chat_template(
row["messages"], tokenize=True, add_generation_prompt=False
)
labels = prompt_response_tokens.copy()
if mask_prompt_loss:
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
else:
# When training on all tokens, prepend a BOS token (if missing)
# so the model can predict the first token.
if prompt_response_tokens[0] != tokenizer.bos_token_id:
bos = [tokenizer.bos_token_id]
prompt_response_tokens = bos + prompt_response_tokens
prompt_tokens = bos + prompt_tokens
labels = bos + labels
labels[0] = -100 # ignore loss on BOS
return {
"input_ids": prompt_response_tokens,
"labels": labels,
"prompt_len": len(prompt_tokens),
}
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary when batch contains customized fields
training_args.remove_unused_columns = False
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_sft_dataset(
data_args.dataset_args,
load_preprocessed_data=data_args.load_preprocessed_data,
)
if not data_args.load_preprocessed_data:
map_fn = partial(
sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
dataset = dataset.map(
map_fn,
num_proc=data_args.num_proc,
desc="Mapping dataset to SFT format",
)
# truncate / filter long sequences if needed
dataset = dllm.utils.post_process_dataset(dataset, data_args)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dream.DreamTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
loss_weight_type=training_args.loss_weight_type,
data_collator=dream.utils.DreamSFTCollator(
tokenizer,
return_tensors="pt",
padding=True,
perbatch_cutoff=data_args.perbatch_cutoff,
resp_cutoff_ratio=data_args.resp_cutoff_ratio,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()