1127 update to latest
dllm/examples/bert/README.md (new file)
@@ -0,0 +1,190 @@
# Generative BERT

[Models](https://huggingface.co/collections/dllm-collection/bert-chat)
[W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg)

This directory provides two key sets of resources:

1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).

<p align="center" style="margin-top: 15px;">
  <img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
</p>
<p align="center">
  <em>
    Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
  </em>
</p>

## Files overview
```
# example entry points for training / inference / evaluation
examples/bert
├── chat.py       # Interactive inference example
├── eval.sh       # Automatic evaluation script
├── generate.py   # Inference example
├── pt.py         # Pretraining example
├── README.md     # Documentation (you are here)
└── sft.py        # Supervised finetuning example
```

## Warmup

In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
You can use any other BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.

### Pretrain

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
    examples/bert/pt.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "Trelis/tiny-shakespeare" \
    --text_field "Text" \
    --insert_eos False \
    --max_length 128 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/tiny-shakespeare"
```

To run inference with the model:
```shell
# just press Enter (empty prompt) if you want the model to generate text from scratch
python -u examples/bert/chat.py \
    --model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
    --chat False --remasking "random" --steps 128 --max_new_tokens 128
```

### SFT

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "tatsu-lab/alpaca" \
    --max_length 512 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/alpaca"
```

To chat with the model:
```shell
python -u examples/bert/chat.py \
    --model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
```
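
For scripted, non-interactive use, this commit also ships `examples/bert/generate.py`; the snippet below is a trimmed-down sketch of that script, not a separate API. The checkpoint path and prompt are placeholders, and the minimal `Args` dataclass stands in for the script's fuller `ScriptArguments`.

```python
# Condensed from examples/bert/generate.py in this commit (paths and prompt are illustrative).
from dataclasses import dataclass

import transformers
import dllm
from dllm.pipelines import llada
from dllm.tools.chat import decode_trim


@dataclass
class Args:  # minimal stand-in for the script's ScriptArguments
    model_name_or_path: str = "models/ModernBERT-large/alpaca/checkpoint-final"


args = Args()
transformers.set_seed(42)

# Load model, tokenizer, and the LLaDA-style generator used for BERT chat models
model = dllm.utils.get_model(model_args=args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
gen_config = llada.LLaDAGeneratorConfig(
    steps=128, max_new_tokens=128, block_length=32,
    temperature=0.0, remasking="low_confidence",
)

messages = [[{"role": "user", "content": "Give me three tips for writing clear docs."}]]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```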

## BERT Chat

Here we show the exact commands we use to train and interact with the BERT Chat models:
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
For training curves and other details, please see the [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).

### Training

To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-base" \
    --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
    --max_length 1024 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 48 \
    --per_device_eval_batch_size 48 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```

To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
    --max_length 1024 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 48 \
    --per_device_eval_batch_size 48 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```

### Inference

To chat with the model:
```shell
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
```

## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/bert/eval.py \
    --tasks "mmlu_pro" \
    --model "bert" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
```

To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
```shell
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
```
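
`eval.sh` runs on 4 GPUs by default; the script also parses a `--num_gpu` flag (see `examples/bert/eval.sh` in this commit), so you can override that, for example:
```shell
bash examples/bert/eval.sh \
    --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" \
    --num_gpu 8
```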

### Evaluation results

<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->

<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | – | | | | | | | |
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | – | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
<!-- <div align="center" style="min-width:1500px;"> -->

| | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 | 53.1 | 39.4 | 0.3 |

<p align="left" style="color: #808080; font-size: 0.9em;">
  Table 1. Evaluation results of
  <a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;"><code>ModernBERT-base-chat-v0</code></a>,
  <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;"><code>ModernBERT-large-chat-v0</code></a>,
  <a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;"><code>Qwen1.5-0.5B</code></a>,
  <a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;"><code>Qwen1.5-0.5B-Chat</code></a>,
  <a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;"><code>gpt2</code></a>, and
  <a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;"><code>gpt2-medium</code></a>.
  <ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
</p>

dllm/examples/bert/assets/chat.gif (new binary file, 7.1 MiB; binary content not shown)

dllm/examples/bert/chat.py (new file)
@@ -0,0 +1,71 @@
"""
Interactive chat / generation script for BERT models.

Examples
--------
# Multi-turn chat (default)
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
"""

import sys
from dataclasses import dataclass

import transformers

import dllm
from dllm.pipelines import llada
from dllm.tools.chat import multi_turn_chat, single_turn_generate


@dataclass
class ScriptArguments:
    model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
    seed: int = 42
    chat: bool = True
    visualize: bool = True

    def __post_init__(self):
        # same base-path resolution logic as in generate.py
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    block_length: int = 32
    temperature: float = 0.0
    remasking: str = "low_confidence"


def main():
    parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
    script_args, gen_config = parser.parse_args_into_dataclasses()
    transformers.set_seed(script_args.seed)

    model = dllm.utils.get_model(model_args=script_args).eval()
    tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
    generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

    if script_args.chat:
        multi_turn_chat(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )
    else:
        print("\nSingle-turn generation (no chat template).")
        single_turn_generate(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted. Bye!")
        sys.exit(0)

dllm/examples/bert/eval.sh (new file)
@@ -0,0 +1,50 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1                 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True   # For cmmlu dataset

# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0                   # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1          # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn                      # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL       # Provide detailed logging for PyTorch distributed debugging

# ===== Basic Settings =====
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
num_gpu=4
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model_name_or_path)
            model_name_or_path="$2"; shift 2 ;;
        --num_gpu)
            num_gpu="$2"; shift 2 ;;
    esac
done

# ===== Common arguments =====
common_args="--model bert --apply_chat_template"  # BERT chat models use the chat template by default

# =======================
# BERT Instruct (Chat) Tasks
# =======================

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks hellaswag_gen --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks mmlu_generative --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks mmlu_pro --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks winogrande --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

dllm/examples/bert/generate.py (new file)
@@ -0,0 +1,73 @@
"""
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""

from dataclasses import dataclass

import transformers

import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import llada


@dataclass
class ScriptArguments:
    model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
    seed: int = 42
    visualize: bool = True

    def __post_init__(self):
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    block_length: int = 64
    temperature: float = 0.0
    remasking: str = "low_confidence"


parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)

# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
    tokenizer=tokenizer
)

# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: bert.generate()".center(80))
print("=" * 80)

messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)

for idx, s in enumerate(sequences):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
    print("-" * 80)
    print(s.strip() if s.strip() else "<empty>")
    print("\n" + "=" * 80 + "\n")

if script_args.visualize:
    terminal_visualizer.visualize(outputs.histories, rich=True)

dllm/examples/bert/pt.py (new file)
@@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/bert/pt.py

- 8 GPUs (DDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml \
        examples/bert/pt.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 8 GPUs (DDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/pt.py"
"""

import os
import functools
from dataclasses import dataclass, field

import transformers
import accelerate

import dllm

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "Trelis/tiny-shakespeare"
    text_field: str = "Text"
    max_length: int = 128
    streaming: bool = False
    drop_tail: bool = True
    insert_eos: bool = field(
        default=True,
        metadata={
            "help": "False when adjacent samples from the datasets are semantically coherent."
        },
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
    num_train_epochs: int = 20
    learning_rate: float = 1e-4
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_pt_dataset(
            data_args.dataset_args,
            streaming=data_args.streaming,
        )
        dataset = dataset.map(
            functools.partial(
                dllm.utils.tokenize_and_group,
                tokenizer=tokenizer,
                text_field=data_args.text_field,
                seq_length=data_args.max_length,
                insert_eos=data_args.insert_eos,
                drop_tail=data_args.drop_tail,
            ),
            batched=True,
            remove_columns=dataset["train"].column_names,
            **({} if data_args.streaming else {"num_proc": data_args.num_proc}),
            **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
        )
        if data_args.streaming:
            dataset = dataset.shuffle(seed=training_args.seed)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()

dllm/examples/bert/sft.py (new file)
@@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/bert/sft.py

- 8 GPUs (DDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml \
        examples/bert/sft.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (DDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/sft.py"

- 2 Nodes, 16 GPUs (DDP):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/sft.py"
"""

import os
from dataclasses import dataclass, field
from functools import partial

import transformers
import accelerate

import dllm

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "tatsu-lab/alpaca"
    max_length: int = 512
    load_preprocessed_data: bool = False
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/ModernBERT-large/alpaca"
    group_by_length: bool = True
    learning_rate: float = 1e-4
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_sft_dataset(
            data_args.dataset_args,
            load_preprocessed_data=data_args.load_preprocessed_data,
        )
        if not data_args.load_preprocessed_data:
            map_fn = partial(
                dllm.utils.default_sft_map_fn,
                tokenizer=tokenizer,
                mask_prompt_loss=data_args.mask_prompt_loss,
            )
            dataset = dataset.map(
                map_fn,
                num_proc=data_args.num_proc,
                desc="Mapping dataset to SFT format",
            )
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=dllm.utils.NoAttentionMaskCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,  # finetune on padding <eos_token>
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()

dllm/examples/dream/README.md (new file)
@@ -0,0 +1,187 @@
# Dream

> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) | 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)

Resources and examples for training (finetuning & pretraining) and evaluating the diffusion language model **Dream**.

## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)

## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logs`; see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.

## Files overview
```
# tools relevant to Dream
dllm/pipelines/dream
├── __init__.py                   # Package initialization
├── models/
│   ├── configuration_dream.py    # Dream model configuration
│   ├── generation_utils.py       # Diffusion-based generation logic
│   ├── modeling_dream.py         # Core Dream model architecture
│   └── tokenization_dream.py     # Tokenizer implementation for Dream
├── generator.py                  # Inference logic
├── trainer.py                    # Training logic (pretraining and SFT)
└── utils.py                      # Auxiliary utilities and helper functions

# example entry points for training / inference / evaluation
examples/dream
├── chat.py       # Interactive inference example
├── eval.sh       # Automatic evaluation script
├── generate.py   # Inference example
├── pt.py         # Pretraining example
├── README.md     # Documentation (you are here)
└── sft.py        # Supervised finetuning example
```
<!-- > [!NOTE]
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> We fixed bugs in `chat_template` and standardize `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
## Training

### Finetuning
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/dream/sft.py \
    --model_name_or_path "Dream-org/Dream-v0-Base-7B" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```
If you are using Slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/dream/sft.py" \
    --model_name_or_path "Dream-org/Dream-v0-Base-7B" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
    --max_length 1024 \
    --num_train_epochs 4 \
    --learning_rate 2e-5
```

<!-- **Reproducing [Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Base-7B)**. We tried our best to reproduce Dream-v0-Instruct-7B by finetuning Dream-v0-Base-7B using our training pipeline on the public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):

```shell
# preprocess the SFT data (optional, but avoids redundant preprocessing for multi-node training)
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "Dream-org/Dream-v0-Base-7B" \
    --sft_map_fn_path "examples.dream.sft.sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "data/sft/dream/tulu-3-sft-mixture" \
    --num_proc 64

# train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/dream/sft.py" \
    --model_name_or_path "Dream-org/Dream-v0-Base-7B" \
    --dataset_args "data/sft/dream/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    --output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
    --max_length 2048 \
    --truncation "right" \
    --group_by_length True \
    --num_train_epochs 5 \
    --learning_rate 1e-5 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --per_device_eval_batch_size 2 \
    --eval_on_start False \
    --eval_steps 0.1 \
    --save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
### Pretraining

Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
    --accelerate_config "fsdp" \
    --script_path "examples/dream/pt.py" \
    --model_name_or_path "Dream-org/Dream-v0-Base-7B" \
    --dataset_args "mlfoundations/dclm-baseline-1.0" \
    --output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
    --max_length 1024 \
    --max_steps 2000 \
    --learning_rate 3e-4
```

## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
```shell
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
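
For the infilling path specifically, `examples/dream/generate.py` in this commit builds chat-templated messages whose assistant turn contains `tokenizer.mask_token` placeholders and calls `generator.infill`; below is a condensed sketch of that part of the script (the prompt content is illustrative, and the minimal `Args` dataclass stands in for the script's fuller `ScriptArguments`).

```python
# Condensed from examples/dream/generate.py in this commit: fill masked spans inside an assistant turn.
from dataclasses import dataclass

import dllm
from dllm.pipelines import dream
from dllm.tools.chat import decode_trim


@dataclass
class Args:  # minimal stand-in for the script's ScriptArguments
    model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"


args = Args()
model = dllm.utils.get_model(model_args=args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
gen_config = dream.DreamGeneratorConfig(
    steps=128, max_new_tokens=128, temperature=0.2, top_p=0.95,
)

# The assistant turn contains mask tokens; infill() replaces them in place.
masked_messages = [[
    {"role": "user", "content": "Please write an educational python function."},
    {"role": "assistant", "content": "def hello_" + tokenizer.mask_token * 20 + " return"},
]]
inputs = tokenizer.apply_chat_template(masked_messages, add_generation_prompt=False, tokenize=True)
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```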
We also support interactive multi-turn dialogue with visualization:
```shell
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```

## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/dream/eval.py \
    --tasks "mmlu_pro" \
    --model "dream" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
```

To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
```shell
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
```

### Evaluation results

> Results marked (evaluated) were obtained with our framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.

| | MMLU | BBH | ARC‑C | ARC‑E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip planning |
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | – | – | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | – | 35.5 | 45.8 | – | 43.0 | – | – | – |

<p align="center" style="color: #808080; font-size: 0.9em;">
  Table 1. Evaluation results of
  <a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;"><code>Dream-v0-Base-7B</code></a>.
</p>

| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(evaluated) | – | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | – | 62.3 |

<p align="center" style="color: #808080; font-size: 0.9em;">
  Table 2. Evaluation results of
  <a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;"><code>Dream-v0-Instruct-7B</code></a>.
</p>

dllm/examples/dream/chat.py (new file)
@@ -0,0 +1,75 @@
"""
Interactive chat / generation script for Dream models.

Examples
--------
# Chat mode (multi-turn, chat template)
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True

# Raw single-turn generation
python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
"""

import sys
from dataclasses import dataclass

import transformers

import dllm
from dllm.pipelines import dream
from dllm.tools.chat import multi_turn_chat, single_turn_generate


@dataclass
class ScriptArguments:
    model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
    seed: int = 42
    chat: bool = True
    visualize: bool = True

    def __post_init__(self):
        # same base-path resolution logic as in generate.py
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    temperature: float = 0.2
    top_p: float = 0.95
    alg: str = "entropy"
    alg_temp: float = 0.0


def main():
    parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
    script_args, gen_config = parser.parse_args_into_dataclasses()
    transformers.set_seed(script_args.seed)

    model = dllm.utils.get_model(model_args=script_args).eval()
    tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
    generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)

    if script_args.chat:
        multi_turn_chat(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )
    else:
        print("\nSingle-turn generation (no chat template).")
        single_turn_generate(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted. Bye!")
        sys.exit(0)

dllm/examples/dream/eval.sh (new file)
@@ -0,0 +1,139 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1                 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True   # For cmmlu dataset


# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0                   # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1          # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn                      # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL       # Provide detailed logging for PyTorch distributed debugging


# ===== Input Arguments =====
model_name_or_path="Dream-org/Dream-v0-Instruct-7B"
instruct=True
num_gpu=4
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model_name_or_path)
            model_name_or_path="$2"; shift 2 ;;
        --instruct)
            instruct="$2"; shift 2 ;;
        --num_gpu)
            num_gpu="$2"; shift 2 ;;
    esac
done


# ===== Conditional Configurations =====
if [ "$instruct" = "True" ]; then
    echo ">>> Running in INSTRUCT mode"
    common_args="--model dream --apply_chat_template"
else
    echo ">>> Running in BASE mode"
    common_args="--model dream"
fi


# =======================
# Generation / Instruct Tasks
# =======================

if [ "$instruct" = "True" ]; then
    # Instruct Tasks
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks mmlu_generative --num_fewshot 4 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks mmlu_pro --num_fewshot 4 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks gsm8k_cot --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=256,max_length=256,steps=256,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks minerva_math --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.0,top_p=1.0,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks humaneval_instruct_dream --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=768,max_length=768,steps=768,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks mbpp_instruct --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1024,max_length=1024,steps=1024,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks ifeval --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1280,max_length=1280,steps=1280,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"

else
    # Base Generation Tasks
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks humaneval --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks gsm8k_cot --num_fewshot 8 ${common_args} \
        --model_args "pretrained=${model_name_or_path},max_new_tokens=256,steps=256,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks mbpp --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks minerva_math --num_fewshot 4 ${common_args} \
        --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks bbh --num_fewshot 3 ${common_args} \
        --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true"
fi


# =======================
# Likelihood Tasks (Base Only)
# =======================

if [ "$instruct" != "True" ]; then
    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks mmlu --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks arc_easy --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks arc_challenge --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks hellaswag --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks piqa --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks winogrande --num_fewshot 5 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"

    accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \
        --tasks race --num_fewshot 0 ${common_args} \
        --model_args "pretrained=${model_name_or_path},add_bos_token=true"
fi

dllm/examples/dream/generate.py (new file)
@@ -0,0 +1,117 @@
"""
python -u examples/dream/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""

from dataclasses import dataclass

import transformers

import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import dream


@dataclass
class ScriptArguments:
    model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
    seed: int = 42
    visualize: bool = True

    def __post_init__(self):
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(dream.DreamGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    temperature: float = 0.2
    top_p: float = 0.95
    alg: str = "entropy"
    alg_temp: float = 0.0


parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)

# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = dream.DreamGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
    tokenizer=tokenizer
)

# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: dream.generate()".center(80))
print("=" * 80)

messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)

outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)

for idx, s in enumerate(sequences):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
    print("-" * 80)
    print(s.strip() if s.strip() else "<empty>")
    print("\n" + "=" * 80 + "\n")

if script_args.visualize:
    terminal_visualizer.visualize(outputs.histories, rich=True)

# --- Example 2: Batch fill-in-the-blanks ---
print("\n" + "=" * 80)
print("TEST: dream.infilling()".center(80))
print("=" * 80)

masked_messages = [
    [
        {"role": "user", "content": tokenizer.mask_token * 20},
        {
            "role": "assistant",
            "content": "Sorry, I do not have answer to this question.",
        },
    ],
    [
        {"role": "user", "content": "Please write an educational python function."},
        {
            "role": "assistant",
            "content": "def hello_" + tokenizer.mask_token * 20 + " return",
        },
    ],
]

inputs = tokenizer.apply_chat_template(
    masked_messages,
    add_generation_prompt=False,
    tokenize=True,
)

outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)

for idx, (i, s) in enumerate(zip(inputs, sequences)):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
    print("-" * 80)
    print("[Masked]:\n" + tokenizer.decode(i))
    print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
    print("\n" + "=" * 80 + "\n")

if script_args.visualize:
    terminal_visualizer.visualize(outputs.histories, rich=True)

dllm/examples/dream/pt.py (new file)
@@ -0,0 +1,162 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/dream/pt.py \
        --load_in_4bit True --lora True

- 8 GPUs (FSDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/fsdp.yaml \
        examples/dream/pt.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 24 Nodes, 192 GPUs (FSDP):
    sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/dream/pt.py"
"""

import os
import functools
from dataclasses import dataclass, field

import torch
import transformers
import accelerate

import dllm
from dllm.pipelines import dream

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
    text_field: str = "text"
    streaming: bool = True
    drop_tail: bool = True
    insert_eos: bool = field(
        default=True,
        metadata={
            "help": "False when adjacent samples from the datasets are semantically coherent."
        },
    )
    random_length_ratio: float = field(
        default=0.01,
        metadata={
            "help": (
                "The probability of randomly cutting sequences during training. "
                "See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md."
            )
        },
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = (
        "models/Dream-7B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
    )
    learning_rate: float = 3e-4
    max_steps: int = 2_000
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    eval_steps: float = 0.05
    save_steps: float = 0.05
    # Dream PT specific args
    # Note: Since Dream's pretraining recipe is not public,
    # this is only a reference implementation following LLaDA's data processing approach.
    loss_weight_type: str = field(
        default="cart[geo_p:0.3]",
        metadata={
            "help": (
                "The loss weight type. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )


def train():
    # ----- Parse & setup --------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # necessary for streaming dataset
    if data_args.streaming:
        training_args.accelerator_config.dispatch_batches = False
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ---------------------------------------------------------------
    # initialize model weights from scratch
    config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
    with dllm.utils.init_device_context_manager():
        model = transformers.AutoModel.from_config(config, dtype=torch.bfloat16)

    # ----- Tokenizer -----------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
    # ----- Optional PEFT: LoRA -------------------------------------------------
    model = dllm.utils.load_peft(model=model, model_args=model_args)

    # ----- Dataset -------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_pt_dataset(
            data_args.dataset_args,
            streaming=data_args.streaming,
        )
        dataset = dataset.map(
            functools.partial(
                dllm.utils.tokenize_and_group,
                tokenizer=tokenizer,
                text_field=data_args.text_field,
                seq_length=data_args.max_length,
                insert_eos=data_args.insert_eos,
                drop_tail=data_args.drop_tail,
            ),
            batched=True,
            remove_columns=dataset["train"].column_names,
            **({} if data_args.streaming else {"num_proc": data_args.num_proc}),
            **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
        )
        if data_args.streaming:
            dataset = dataset.shuffle(seed=training_args.seed)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dream.DreamTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        loss_weight_type=training_args.loss_weight_type,
        data_collator=dream.utils.DreamPTCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            random_length_ratio=data_args.random_length_ratio,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()
192
dllm/examples/dream/sft.py
Normal file
192
dllm/examples/dream/sft.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/dream/sft.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/dream/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import dream
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
# Dream SFT specific args
|
||||
perbatch_cutoff: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": (
|
||||
"Randomly pick a response length from batch and trim other responses. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
resp_cutoff_ratio: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cutting sequences during training. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/Dream-7B-SFT"
|
||||
group_by_length: bool = True
|
||||
# Dream SFT specific args
|
||||
loss_weight_type: str = field(
|
||||
default="cart[geo_p:0.3]",
|
||||
metadata={
|
||||
"help": (
|
||||
"The loss weight type. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# SFT mapping function
|
||||
# ------------------------------------------------------------------------------
|
||||
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict:
|
||||
"""
|
||||
Build Dream SFT features from a chat-format row.
|
||||
|
||||
Returns:
|
||||
dict with input_ids, labels, prompt_len
|
||||
"""
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
|
||||
if mask_prompt_loss:
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can predict the first token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
bos = [tokenizer.bos_token_id]
|
||||
prompt_response_tokens = bos + prompt_response_tokens
|
||||
prompt_tokens = bos + prompt_tokens
|
||||
labels = bos + labels
|
||||
labels[0] = -100 # ignore loss on BOS
|
||||
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary when batch contains customized fields
|
||||
training_args.remove_unused_columns = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dream.DreamTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
loss_weight_type=training_args.loss_weight_type,
|
||||
data_collator=dream.utils.DreamSFTCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
perbatch_cutoff=data_args.perbatch_cutoff,
|
||||
resp_cutoff_ratio=data_args.resp_cutoff_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
3
dllm/examples/editflow/README.md
Normal file
3
dllm/examples/editflow/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
Work in progress.
|
||||
|
||||
Please see [`examples/editflow/bert/README.md`](/examples/editflow/bert/README.md) for examples of finetuning BERT with EditFlow.
|
||||
162
dllm/examples/editflow/_README.md
Normal file
162
dllm/examples/editflow/_README.md
Normal file
@ -0,0 +1,162 @@
|
||||
# Edit Flows
|
||||
|
||||
> **Reference**
|
||||
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
|
||||
|
||||
This directory provides an educational reference for training EditFlow models. It demonstrates how to adapt open-weight DLLMs—such as [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)—to support *insertion* and *deletion* beyond the standard *substitution* (`mask` -> `tokens`) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.
|
||||
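To make the three edit operations concrete, here is a minimal, purely illustrative Python sketch of how a single edit step transforms a partially masked sequence. It does not use the `dllm` API; all names below are hypothetical, and the real logic lives in the trainer and in [`generate.py`](/examples/editflow/generate.py):

```python
# Illustrative only: one "edit step" applying substitution, insertion, and deletion
# to a token sequence that still contains <mask> placeholders.
MASK = "<mask>"

def apply_edits(tokens, substitutions, insertions, deletions):
    """substitutions: {pos: token} replaces tokens[pos]; insertions: {pos: token} adds to the RIGHT of pos;
    deletions: positions to drop. All positions refer to the input sequence."""
    out = []
    for i, tok in enumerate(tokens):
        if i in deletions:                      # deletion: drop the token (e.g. a redundant <mask>)
            continue
        out.append(substitutions.get(i, tok))   # substitution: e.g. <mask> -> a real token
        if i in insertions:                     # insertion: a new token appears right of position i
            out.append(insertions[i])
    return out

x = ["write", "a", "story", MASK, MASK, MASK]
print(apply_edits(x, substitutions={3: "about"}, insertions={4: "two"}, deletions={5}))
# ['write', 'a', 'story', 'about', '<mask>', 'two']
```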
|
||||
> [!NOTE]
|
||||
> - Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations ([`adapt_llada.py`](/examples/editflow/adapt_llada.py)) and reusing its architecture for training from scratch ([`pt_llada.py`](/examples/editflow/pt_llada.py) -> [`sft_llada.py`](/examples/editflow/sft_llada.py)).
|
||||
> - While `EditFlowCollator` supports custom `x0`, this README uses fixed-length (128) masks as `x0`. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default `x0` distribution (e.g., empty sequences for [OneFlow](https://arxiv.org/abs/2510.03506)-like insertion-only generation), pass `--x0_sampler "empty"`.
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Adapting LLaDA-8B-Instruct to support insertion and deletion](#adapting-llada-8b-instruct-to-support-insertion-and-deletion)
|
||||
- [Pretraining & Finetuning from scratch](#pretraining--finetuning-from-scratch)
|
||||
- [Sampling](#sampling)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
|
||||
## Setup
|
||||
> [!IMPORTANT]
|
||||
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
|
||||
|
||||
## Files overview
|
||||
```
|
||||
dllm/pipelines/editflow
|
||||
├── __init__.py # Package initialization
|
||||
├── models
|
||||
│ ├── dream
|
||||
│ │ └── modelling_dream.py # EditFlowDream: architecture based on Dream
|
||||
│ └── llada
|
||||
│ └── modelling_llada.py # EditFlowLLaDA: architecture based on LLaDA
|
||||
├── trainer.py
|
||||
└── utils.py
|
||||
|
||||
# example entry point for training / sampling
|
||||
examples/editflow
|
||||
├── adapt_dream.py # Example of adapting Dream for EditFlow directly
|
||||
├── adapt_llada.py # Example of adapting LLaDA for EditFlow directly
|
||||
├── generate.py # Generation example
|
||||
├── pt_dream.py # EditFlowDream pretraining example
|
||||
├── pt_llada.py # EditFlowLLaDA pretraining example
|
||||
├── pt.py # Pretraining function
|
||||
├── README.md # Documentation (you are here)
|
||||
├── sft_dream.py # EditFlowDream SFT example
|
||||
├── sft_llada.py # EditFlowLLaDA SFT example
|
||||
└── sft.py # Supervised finetuning function
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Adapting [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) to support *insertion* and *deletion*
|
||||
|
||||
The original LLaDA model generates text by iteratively substituting the given `<mask>` tokens with real tokens.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://github.com/ML-GSAI/LLaDA/blob/main/imgs/example_gradio.gif" alt="LLaDA demo" width="80%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Example Gradio demo for LLaDA.</em></p>
|
||||
|
||||
However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively:
|
||||
|
||||
```shell
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/editflow/adapt_llada.py \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
|
||||
--lm_head_key "model.transformer.ff_out" \
|
||||
--init_editflow_from_src True \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
If you are using Slurm and want to train across, for example, four nodes (32 GPUs total), run:
|
||||
```shell
|
||||
sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/adapt_llada.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \
|
||||
--lm_head_key "model.transformer.ff_out" \
|
||||
--init_editflow_from_src True \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
After training, you can use the [generate.py](/examples/editflow/generate.py) script to produce a visualized decoding trace showing how the model performs *insertion* and *deletion* beyond regular mask *substitution*. See [Sampling](#sampling) for details.
|
||||
|
||||
|
||||
### Pretraining & Finetuning from scratch
|
||||
You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM.
|
||||
|
||||
Pretrain on a subset of [mlfoundations/dclm-baseline-1.0](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) using 192 GPUs (24x8) and FSDP:
|
||||
|
||||
```shell
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/pt_llada.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "mlfoundations/dclm-baseline-1.0" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--max_steps 2000 \
|
||||
--learning_rate 3e-4
|
||||
```
|
||||
|
||||
Finetune on a subset of [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) using 8 GPUs and FSDP for better instruction following:
|
||||
|
||||
```shell
|
||||
# you can also run locally with `accelerate ...`
|
||||
sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/sft_llada.py" \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
|
||||
--output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \
|
||||
--x0_sampler "masks[length:128]" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 5e-5
|
||||
```
|
||||
|
||||
## Sampling
|
||||
|
||||
After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with [generate.py](/examples/editflow/generate.py). Inserted tokens appear <span style="color:blue; font-weight:bold">blue</span>, tokens substituted from `<mask>` appear <span style="color:black; font-weight:bold">black</span>, and deleted tokens are shown with a strikethrough before they disappear.
|
||||
|
||||
```shell
|
||||
# Generate a long sequence to visualize insertions after 128 <mask> tokens
|
||||
python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
|
||||
--tau 0.02 --mask_length 128 --seed 7070 \
|
||||
--prompt "write a romantic story" --make_gif
|
||||
|
||||
# Generate a short sequence to visualize deletions after 128 <mask> tokens
|
||||
python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \
|
||||
--tau 0.02 --mask_length 128 --seed 7070 \
|
||||
--prompt "write a single-sentence romantic story" --make_gif
|
||||
```
|
||||
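If you prefer to drive the sampler from Python instead of the CLI, the sketch below mirrors what `generate.py` does internally. It is only a sketch under assumptions: it presumes the repo root is on `PYTHONPATH`, that `examples.editflow.generate` is importable as shown, and the checkpoint path is a placeholder.

```python
# Programmatic use of examples/editflow/generate.py (import layout and path are assumptions).
from dataclasses import dataclass

import torch
from transformers import AutoModel, AutoTokenizer

from examples.editflow.generate import GenCfg, generate_editflow_minimal

@dataclass
class Args:
    # minimal stand-in for the CLI arguments that generate_editflow_minimal reads
    prompt: str | None = "write a romantic story"
    mask_length: int = 128
    time_epsilon: float = 1e-3

path = "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final"  # example path
model = AutoModel.from_pretrained(path, dtype=torch.bfloat16, device_map="auto").eval()
tokenizer = AutoTokenizer.from_pretrained(path)

text, trace = generate_editflow_minimal(model, tokenizer, Args(), GenCfg(tau=0.02, seed=7070))
print(text)
# `trace` can be passed to examples.editflow.viz.render_consecutive_trace_gif to render a GIF.
```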
|
||||
<p align="center">
|
||||
<img src="/examples/editflow/assets/deletion.gif" alt="EditFlow deletion demo" width="95%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Deletion & Substitution trace</code></em></p>
|
||||
|
||||
<p align="center">
|
||||
<img src="/examples/editflow/assets/insertion.gif" alt="LLaDA demo" width="95%">
|
||||
</p>
|
||||
<p align="center"><em>Figure: Inserction & Substitution trace</em></p>
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo.
|
||||
BIN
dllm/examples/editflow/assets/all.gif
Normal file
BIN
dllm/examples/editflow/assets/all.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.3 MiB |
BIN
dllm/examples/editflow/assets/deletion.gif
Normal file
BIN
dllm/examples/editflow/assets/deletion.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.7 MiB |
BIN
dllm/examples/editflow/assets/insertion.gif
Normal file
BIN
dllm/examples/editflow/assets/insertion.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.4 MiB |
77
dllm/examples/editflow/bert/README.md
Normal file
77
dllm/examples/editflow/bert/README.md
Normal file
@ -0,0 +1,77 @@
|
||||
# Edit Flows - BERT
|
||||
|
||||
> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)
|
||||
|
||||
|
||||
## Warmup
|
||||
|
||||
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text with EditFlow.
|
||||
You can use any BERT model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
|
||||
|
||||
### Pretrain
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
|
||||
```shell
|
||||
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/bert/pt.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "Trelis/tiny-shakespeare" \
|
||||
--text_field "Text" \
|
||||
--insert_eos False \
|
||||
--max_length 128 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/tiny-shakespeare"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
PYTHONPATH=. python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
|
||||
--tau 0.01 --mask_length 64 --seed 42 --make_gif
|
||||
|
||||
# see `decode_trace.gif`
|
||||
```
|
||||
|
||||
|
||||
### SFT
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
|
||||
```shell
|
||||
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/editflow/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "tatsu-lab/alpaca" \
|
||||
--max_length 512 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/alpaca"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
PYTHONPATH=. python examples/editflow/generate.py \
|
||||
--model_name_or_path "models/EditFlow/ModernBERT-large/alpaca/checkpoint-final" \
|
||||
--prompt "Could you please write a poem for me?" --tau 0.01 --mask_length 64 --seed 42 --make_gif
|
||||
|
||||
# see `decode_trace.gif`
|
||||
```
|
||||
|
||||
<!-- ```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/editflow/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--x0_sampler "masks[length:64]" \
|
||||
--output_dir "models/EditFlow/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
``` -->
|
||||
48
dllm/examples/editflow/bert/pt.py
Normal file
48
dllm/examples/editflow/bert/pt.py
Normal file
@ -0,0 +1,48 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
lm_head_key: str = "decoder"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "Trelis/tiny-shakespeare"
|
||||
text_field: str = "Text"
|
||||
max_length: int = 128
|
||||
streaming: bool = False
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = "models/EditFlow/ModernBERT-large/tiny-shakespeare"
|
||||
num_train_epochs: float = 20
|
||||
learning_rate: float = 3e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
x0_sampler: str = "masks[length:64]"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
|
||||
)
|
||||
44
dllm/examples/editflow/bert/sft.py
Normal file
44
dllm/examples/editflow/bert/sft.py
Normal file
@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
lm_head_key: str = "decoder"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "tatsu-lab/alpaca"
|
||||
max_length: int = 512
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = "models/EditFlow/ModernBERT-large/alpaca"
|
||||
num_train_epochs: float = 20
|
||||
learning_rate: float = 3e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
x0_sampler: str = "masks[length:64]"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
|
||||
)
|
||||
88
dllm/examples/editflow/dream/adapt.py
Normal file
88
dllm/examples/editflow/dream/adapt.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/adapt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/dream/adapt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/adapt.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/adapt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B"
|
||||
lm_head_key: str = "lm_head"
|
||||
init_editflow_from_src: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
# Create EditFlow model (bf16 init on CUDA)
|
||||
ef_cfg = dllm.pipelines.editflow.EditFlowDreamConfig.from_pretrained(
|
||||
model_args.model_name_or_path
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
if model_args.init_editflow_from_src:
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=torch.bfloat16
|
||||
)
|
||||
dllm.pipelines.editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
model=model,
|
||||
)
|
||||
67
dllm/examples/editflow/dream/pt.py
Normal file
67
dllm/examples/editflow/dream/pt.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/pt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/dream/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/pt.py"
|
||||
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/pt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"
|
||||
lm_head_key: str = "lm_head"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowDreamConfig,
|
||||
)
|
||||
66
dllm/examples/editflow/dream/sft.py
Normal file
66
dllm/examples/editflow/dream/sft.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/dream/sft.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/editflow/dream/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/dream/sft.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = (
|
||||
"models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-Dream-7B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
)
|
||||
418
dllm/examples/editflow/generate.py
Normal file
418
dllm/examples/editflow/generate.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""
|
||||
Minimal EditFlow τ-leap generator for EditBase-Dream with diffusion-style visualization.
|
||||
|
||||
Overview:
|
||||
- tau_leap_step_minimal returns (x_next, any_edit, step_trace, out_for_next), preserving all intermediates.
|
||||
- generate_editflow_minimal returns (final_text, trace).
|
||||
- render_consecutive_trace_gif(trace, tokenizer, ...) draws a GIF where each frame shows
|
||||
ONLY the current output (like the Gemini diffusion page shows progressive refinement):
|
||||
* SUB tokens in the current frame are orange
|
||||
* INS tokens in the current frame are blue
|
||||
* KEEP tokens are black
|
||||
* If any deletions happened in the step, the title shows ⌫N (red)
|
||||
"""
|
||||
|
||||
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --time=03:00:00 python examples/editflow/generate.py --model_name_or_path "models/EditFlow-Dream-Instruct-7B/tulu-3-sft-mixture/checkpoint-final" --tau 0.02 --mask_length 128 --seed 7070 --prompt "write a romantic story" --make_gif
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
import tyro
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from dllm.core.schedulers import BaseKappaScheduler, LinearKappaScheduler
|
||||
|
||||
|
||||
# ------------------------------- Small utilities --------------------------------
|
||||
|
||||
|
||||
def _bernoulli_from_rate(rate: torch.Tensor, tau: float) -> torch.Tensor:
|
||||
"""First-order τ-leap Bernoulli with p ≈ rate * τ (clamped)."""
|
||||
p = (rate.float() * float(tau)).clamp_(0.0, 1.0 - 1e-6)
|
||||
return torch.bernoulli(p)
|
||||
|
||||
|
||||
def _sample_from_logits(logits_row: torch.Tensor, temperature: float) -> int:
|
||||
"""Sample one token id from a 1D logits row with temperature.
|
||||
temperature <= 0 -> greedy (argmax).
|
||||
"""
|
||||
if temperature <= 0.0:
|
||||
return int(torch.argmax(logits_row).item())
|
||||
return int(
|
||||
torch.distributions.Categorical(logits=(logits_row / temperature))
|
||||
.sample()
|
||||
.item()
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenCfg:
|
||||
tau: float = 0.02 # τ step
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
seed: int = 1234
|
||||
edit_prompt: bool = False # allow editing inside prompt region?
|
||||
temperature: float = 0.7 # token sampling temperature (sub/ins)
|
||||
verbose: bool = True # whether to show intermediate decoding traces
|
||||
time_independent: bool = True
|
||||
|
||||
|
||||
# -------------------------------- τ-leap one step --------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def tau_leap_step_minimal(
|
||||
x: torch.Tensor, # [T]
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_len: int, # number of initial prompt tokens (including BOS)
|
||||
t: float,
|
||||
sched: BaseKappaScheduler,
|
||||
cfg: GenCfg,
|
||||
prev_out: dict | None = None, # <-- pass prior step's model outputs
|
||||
reuse_prev: bool = False, # <-- if True, reuse prev_out instead of forward()
|
||||
) -> tuple[torch.Tensor, bool, dict, dict]:
|
||||
"""
|
||||
Single τ-leap step with deletion/substitution conflict resolution
|
||||
and right-insert policy.
|
||||
|
||||
Reuse semantics:
|
||||
• If cfg.time_independent == True and reuse_prev == True and prev_out is not None,
|
||||
we reuse `prev_out` tensors instead of calling model() again.
|
||||
• Otherwise we run a fresh forward().
|
||||
|
||||
Viz-only convention:
|
||||
• Any local annotated as Annotated[*, "viz-only"] is used only for human-visible
|
||||
tracing / debugging (console logs, GIFs) and does not affect generation.
|
||||
• Such variables are also prefixed with '_' for quick visual scanning.
|
||||
|
||||
Returns:
|
||||
x_next, any_edit, _step_trace, out_for_next (the freshly used model outputs)
|
||||
"""
|
||||
device = x.device
|
||||
T = x.numel()
|
||||
|
||||
# Decide whether to reuse the previous forward results
|
||||
use_reuse = bool(cfg.time_independent and reuse_prev and (prev_out is not None))
|
||||
if use_reuse:
|
||||
out = prev_out
|
||||
else:
|
||||
attn = torch.ones(1, T, dtype=torch.long, device=device)
|
||||
t_tensor = torch.full((1, 1), float(t), device=device)
|
||||
out = model(input_ids=x.unsqueeze(0), attention_mask=attn, t=t_tensor)
|
||||
|
||||
del_rate_h = out["del_rate_hat"] # [1, T]
|
||||
sub_rate_h = out["sub_rate_hat"] # [1, T]
|
||||
ins_rate_h = out["ins_rate_hat"] # [1, T]
|
||||
sub_logits = out["sub_logits"] # [1, T, V]
|
||||
ins_logits = out["ins_logits"] # [1, T, V]
|
||||
|
||||
# Scale normalized rates to true rates
|
||||
tt = torch.tensor([[t]], device=device)
|
||||
w = sched.weight(tt)
|
||||
del_rate = del_rate_h * w
|
||||
sub_rate = sub_rate_h * w
|
||||
ins_rate = ins_rate_h * w
|
||||
|
||||
# Clamp prompt_len within current T (robustness)
|
||||
prompt_len_clamped = int(max(1, min(prompt_len, T)))
|
||||
|
||||
if not cfg.edit_prompt:
|
||||
# Protect the entire prompt span from del/sub
|
||||
del_rate[:, :prompt_len_clamped] = 0.0
|
||||
sub_rate[:, :prompt_len_clamped] = 0.0
|
||||
# Disallow insertions inside the prompt EXCEPT at the last prompt token
|
||||
if prompt_len_clamped >= 2:
|
||||
ins_rate[:, : prompt_len_clamped - 1] = 0.0
|
||||
|
||||
# Combined "edit" (delete or substitute) event
|
||||
comb_rate = (del_rate + sub_rate).squeeze(0) # [T]
|
||||
comb_fire = _bernoulli_from_rate(comb_rate, cfg.tau).bool() # [T]
|
||||
|
||||
# If an edit fires at i, choose deletion with prob λ_del/(λ_del+λ_sub)
|
||||
p_del = (del_rate.squeeze(0) / (comb_rate + 1e-8)).clamp(0, 1) # [T]
|
||||
choose_del = (torch.rand_like(p_del) < p_del) & comb_fire # [T]
|
||||
choose_sub = comb_fire & (~choose_del) # [T]
|
||||
|
||||
# Insertions (right of token i)
|
||||
ins_fire = _bernoulli_from_rate(ins_rate.squeeze(0), cfg.tau).bool() # [T]
|
||||
|
||||
# Token draws (algorithmic, not viz-only)
|
||||
sub_samples: list[int | None] = [
|
||||
(
|
||||
_sample_from_logits(sub_logits[0, i], cfg.temperature)
|
||||
if choose_sub[i]
|
||||
else None
|
||||
)
|
||||
for i in range(T)
|
||||
]
|
||||
ins_samples: list[int | None] = [
|
||||
_sample_from_logits(ins_logits[0, i], cfg.temperature) if ins_fire[i] else None
|
||||
for i in range(T)
|
||||
]
|
||||
|
||||
# Build new sequence left→right (apply insertions to the RIGHT)
|
||||
new_ids: list[int] = []
|
||||
|
||||
# --- viz-only per-position labels (for trace/GIF) ---
|
||||
_before_ops: Annotated[list[str], "viz-only"] = (
|
||||
[]
|
||||
) # per 'before' position: DEL/SUB/KEEP
|
||||
_after_ops: Annotated[list[str], "viz-only"] = (
|
||||
[]
|
||||
) # per 'after' token aligned to new_ids: INS/SUB/KEEP
|
||||
|
||||
for i in range(T):
|
||||
if choose_del[i]:
|
||||
_before_ops.append("DEL")
|
||||
# deletion -> no token appended
|
||||
elif choose_sub[i]:
|
||||
_before_ops.append("SUB")
|
||||
new_tok = sub_samples[i]
|
||||
new_ids.append(int(new_tok))
|
||||
_after_ops.append("SUB")
|
||||
else:
|
||||
_before_ops.append("KEEP")
|
||||
new_ids.append(int(x[i].item()))
|
||||
_after_ops.append("KEEP")
|
||||
|
||||
if ins_samples[i] is not None:
|
||||
new_ids.append(int(ins_samples[i]))
|
||||
_after_ops.append("INS")
|
||||
|
||||
x_next = torch.tensor(new_ids, dtype=torch.long, device=device)
|
||||
any_edit = bool(comb_fire.any().item() or ins_fire.any().item())
|
||||
# Provide the exact outputs we used this step for the caller to pass forward
|
||||
out_for_next = out
|
||||
|
||||
# --- (vis) used only for verbose console trace ---
|
||||
if cfg.verbose and (comb_fire.any() or ins_fire.any()):
|
||||
|
||||
def _tok_str(tok_id: int) -> str: # viz-only helper
|
||||
try:
|
||||
s = tokenizer.decode([int(tok_id)])
|
||||
return s if s.strip() else f"<{int(tok_id)}>"
|
||||
except Exception:
|
||||
return f"<{int(tok_id)}>"
|
||||
|
||||
_ops_strs: Annotated[list[str], "viz-only"] = []
|
||||
for i in range(T):
|
||||
if choose_del[i]:
|
||||
_ops_strs.append(f"DEL@{i}:{_tok_str(int(x[i]))}")
|
||||
elif choose_sub[i]:
|
||||
_ops_strs.append(
|
||||
f"SUB@{i}:{_tok_str(int(x[i]))}->{_tok_str(sub_samples[i])}"
|
||||
)
|
||||
if ins_samples[i] is not None:
|
||||
_ops_strs.append(f"INS@{i}->{i+1}:{_tok_str(ins_samples[i])}")
|
||||
print("[time]", f"{t:.4f}")
|
||||
print("[events]", "; ".join(_ops_strs))
|
||||
print("[decode]\n", tokenizer.decode(new_ids, skip_special_tokens=False))
|
||||
print()
|
||||
|
||||
# --- (vis) step trace payload (returned; used only for visualization downstream) ---
|
||||
_step_trace: Annotated[dict, "viz-only"] = {
|
||||
"t": float(t),
|
||||
"x_before_ids": [int(i) for i in x.tolist()],
|
||||
"x_after_ids": [int(i) for i in new_ids],
|
||||
"before_ops": _before_ops, # viz-only labels
|
||||
"after_ops": _after_ops, # viz-only labels
|
||||
# below are algorithmic signals copied for visualization/analysis
|
||||
"choose_del": [bool(v) for v in choose_del.tolist()],
|
||||
"choose_sub": [bool(v) for v in choose_sub.tolist()],
|
||||
"ins_fire": [bool(v) for v in ins_fire.tolist()],
|
||||
"sub_samples": [int(s) if s is not None else None for s in sub_samples],
|
||||
"ins_samples": [int(s) if s is not None else None for s in ins_samples],
|
||||
"prompt_len": prompt_len_clamped,
|
||||
"used_reuse": bool(use_reuse),
|
||||
}
|
||||
|
||||
return x_next, any_edit, _step_trace, out_for_next
|
||||
|
||||
|
||||
# -------------------------------- top-level generate -------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_editflow_minimal(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
args,
|
||||
cfg: GenCfg,
|
||||
) -> tuple[str, dict]:
|
||||
"""
|
||||
Returns:
|
||||
final_text, trace
|
||||
|
||||
Notes on annotations:
|
||||
• Any local annotated with Annotated[..., "viz-only"] is only used to build
|
||||
the decode trace for visualization (e.g., GIF rendering) and has no effect
|
||||
on the actual generation. Such variables are also prefixed with '_' to make
|
||||
this visually obvious in code.
|
||||
"""
|
||||
torch.manual_seed(cfg.seed)
|
||||
|
||||
# If prompt is None, start from BOS alone; otherwise build the prompt with the chat template
|
||||
bos = getattr(tokenizer, "bos_token_id", None)
|
||||
if bos is None:
|
||||
raise ValueError("Tokenizer must have a BOS token for this sampler.")
|
||||
|
||||
prompt = args.prompt
|
||||
if prompt is None:
|
||||
ids = [bos] # BOS alone
|
||||
else:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
# ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
# ids = [bos] + enc["input_ids"] # ALWAYS prefix BOS
|
||||
|
||||
prompt_len = len(ids)
|
||||
|
||||
if args.mask_length:
|
||||
if getattr(tokenizer, "mask_token_id", None) is None:
|
||||
raise ValueError(
|
||||
"Tokenizer must define mask_token_id when --mask_length > 0."
|
||||
)
|
||||
ids = ids + [tokenizer.mask_token_id] * args.mask_length
|
||||
|
||||
x = torch.tensor(ids, dtype=torch.long, device=model.device)
|
||||
|
||||
sched = LinearKappaScheduler()
|
||||
tau = cfg.tau
|
||||
steps = math.ceil(1.0 / max(tau, 1e-9))
|
||||
|
||||
_trace: Annotated[dict, "viz-only: full decode trace for GIF/inspection"] = {
|
||||
"steps": [],
|
||||
"init": {
|
||||
"t": 0.0,
|
||||
"x_ids": [int(i) for i in x.tolist()],
|
||||
"prompt_len": int(prompt_len),
|
||||
},
|
||||
"end_t": 0.0,
|
||||
}
|
||||
|
||||
# Local-only reuse: if previous iteration had no edits, reuse its forward.
|
||||
prev_out: dict | None = None
|
||||
prev_had_edits = True # first iteration must run a forward
|
||||
|
||||
t = 0.0
|
||||
for _ in range(steps):
|
||||
# We can reuse prev_out only if the model is declared time-independent
|
||||
# and the previous step had NO edits (sequence unchanged).
|
||||
reuse_prev = (
|
||||
cfg.time_independent and not prev_had_edits and (prev_out is not None)
|
||||
)
|
||||
|
||||
x, edited, _step_trace, prev_out = tau_leap_step_minimal(
|
||||
x=x,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt_len=prompt_len,
|
||||
t=t,
|
||||
sched=sched,
|
||||
cfg=cfg,
|
||||
prev_out=prev_out,
|
||||
reuse_prev=reuse_prev,
|
||||
)
|
||||
|
||||
_step_trace: Annotated[dict, "viz-only: per-step intermediates for trace"]
|
||||
_trace["steps"].append(_step_trace)
|
||||
|
||||
prev_had_edits = edited
|
||||
|
||||
t = min(1.0, t + tau)
|
||||
if t >= 1.0 - args.time_epsilon:
|
||||
break
|
||||
|
||||
_trace["end_t"] = float(t)
|
||||
|
||||
final_text = tokenizer.decode(x.tolist(), skip_special_tokens=False)
|
||||
print("[final]")
|
||||
return final_text, _trace
|
||||
|
||||
|
||||
# ---------------------------------------- CLI -------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
@dataclass
|
||||
class ScriptArgs:
|
||||
# Required (no default)
|
||||
model_name_or_path: Annotated[str, "Path or hub id for the model"]
|
||||
time_independent: Annotated[
|
||||
bool, "Whether model is conditioned on time step"
|
||||
] = True
|
||||
|
||||
prompt: Annotated[str | None, "Text prompt. If None, start from BOS alone."] = (
|
||||
None
|
||||
)
|
||||
# Boolean flag: tyro exposes --edit-prompt / --no-edit-prompt automatically for bools
|
||||
edit_prompt: Annotated[
|
||||
bool,
|
||||
"Allow delete/substitute and insertions in the prompt region (BOS+prompt).",
|
||||
] = False
|
||||
|
||||
# Generation-related args
|
||||
tau: Annotated[float, "τ-leap size"] = 0.01
|
||||
time_epsilon: Annotated[
|
||||
float, "Match this with the `time_epsilon` arg used in your EditFlowTrainer"
|
||||
] = 1e-3
|
||||
mask_length: Annotated[
|
||||
int,
|
||||
"Number of <mask> tokens appended after the prompt.\n"
|
||||
"EditFlow will iteratively substitute, insert, or delete masks to form the output.",
|
||||
] = 128
|
||||
temperature: Annotated[float, "Token sampling temperature; 0 for greedy."] = 0.7
|
||||
|
||||
seed: Annotated[int, "Random seed"] = 1234
|
||||
verbose: Annotated[bool, "Whether to show intermediate decoding traces"] = True
|
||||
|
||||
# Visualization
|
||||
make_gif: Annotated[bool, "Render a decoding trace GIF after generation."] = (
|
||||
False
|
||||
)
|
||||
gif_path: Annotated[
|
||||
str | None, "Output GIF path (default: decode_trace.gif)"
|
||||
] = None
|
||||
frame_ms: Annotated[int, "Per-frame duration in ms"] = 120
|
||||
|
||||
args = tyro.cli(ScriptArgs)
|
||||
|
||||
cfg = GenCfg(
|
||||
tau=args.tau,
|
||||
seed=args.seed,
|
||||
edit_prompt=args.edit_prompt,
|
||||
temperature=args.temperature,
|
||||
verbose=args.verbose,
|
||||
time_independent=args.time_independent,
|
||||
)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
).eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
|
||||
final_text, trace = generate_editflow_minimal(model, tokenizer, args, cfg)
|
||||
print(final_text)
|
||||
|
||||
if args.make_gif:
|
||||
from examples.editflow.viz import render_consecutive_trace_gif
|
||||
|
||||
out = args.gif_path or "decode_trace.gif"
|
||||
path = render_consecutive_trace_gif(
|
||||
trace,
|
||||
tokenizer,
|
||||
out_path=out,
|
||||
frame_ms=args.frame_ms,
|
||||
)
|
||||
print(f"[gif saved] {path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
88
dllm/examples/editflow/llada/adapt.py
Normal file
88
dllm/examples/editflow/llada/adapt.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/llada/adapt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/llada/adapt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/adapt.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/adapt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
lm_head_key: str = "model.transformer.ff_out"
|
||||
init_editflow_from_src: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
# Create EditFlow model (bf16 init on CUDA)
|
||||
ef_cfg = dllm.pipelines.editflow.EditFlowLLaDAConfig.from_pretrained(
|
||||
model_args.model_name_or_path
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
if model_args.init_editflow_from_src:
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=torch.bfloat16
|
||||
)
|
||||
dllm.pipelines.editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
model=model,
|
||||
)
|
||||
67
dllm/examples/editflow/llada/pt.py
Normal file
67
dllm/examples/editflow/llada/pt.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/llada/pt.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/llada/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/pt.py"
|
||||
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/pt.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from examples.editflow import pt as editflow_pt
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_pt.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
lm_head_key: str = "model.transformer.ff_out"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_pt.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_pt.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_pt.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
ef_config_cls=dllm.pipelines.editflow.EditFlowLLaDAConfig,
|
||||
)
|
||||
66
dllm/examples/editflow/llada/sft.py
Normal file
66
dllm/examples/editflow/llada/sft.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/editflow/llada/sft.py \
|
||||
--lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/editflow/llada/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/editflow/llada/sft.py"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
from examples.editflow import sft as editflow_sft
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(editflow_sft.ModelArguments):
|
||||
model_name_or_path: str = (
|
||||
"models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(editflow_sft.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(editflow_sft.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/EditFlow-LLaDA-8B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
editflow_sft.train(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
)
|
||||
176
dllm/examples/editflow/pt.py
Normal file
176
dllm/examples/editflow/pt.py
Normal file
@ -0,0 +1,176 @@
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import editflow
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = None # overwrite this
|
||||
lm_head_key: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The key to the `lm_head` in the source model for initializing operation heads in the EditFlow model. "
|
||||
"Overwrite this when `init_editflow_from_src` = True"
|
||||
)
|
||||
},
|
||||
)
|
||||
init_editflow_from_src: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to initialize EditFlow model from the source model."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
text_field: str = "text"
|
||||
streaming: bool = False
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = None # overwrite this
|
||||
num_train_epochs: float = 20
|
||||
learning_rate: float = 3e-4
|
||||
# max_steps: int = 2_000
|
||||
per_device_train_batch_size: int = 3
|
||||
per_device_eval_batch_size: int = 3
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
# EditFlow specific args
|
||||
scheduler_cls: str = field(
|
||||
default="LinearKappaScheduler",
|
||||
metadata={
|
||||
"help": (
|
||||
"The scheduler class controlling κ(t). "
|
||||
"Available options: see `dllm/utils/schedulers/kappa.py`"
|
||||
)
|
||||
},
|
||||
)
|
||||
normalize_per_position: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to normalize the loss per position."},
|
||||
)
|
||||
max_w: float = field(
|
||||
default=20.0,
|
||||
metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."},
|
||||
)
|
||||
x0_sampler: str = field(
|
||||
default="masks[length:128]",
|
||||
metadata={
|
||||
"help": (
|
||||
"Choose the x0 sampler. "
|
||||
"Available options: see `dllm/pipelines/editflow/utils.py`"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def train(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: TrainingArguments,
|
||||
ef_config_cls: type[transformers.PretrainedConfig],
|
||||
):
|
||||
# necessary when batch does not contain "labels" field
|
||||
training_args.label_names = []
|
||||
# necessary when batch contains customized fields
|
||||
training_args.remove_unused_columns = False
|
||||
# necessary for streaming dataset
|
||||
training_args.accelerator_config.dispatch_batches = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Load base Model and initialize EditFlow Model ---------------------------
|
||||
# Create EditFlow model (bf16 init on CUDA)
|
||||
ef_cfg = ef_config_cls.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
dtype=model_args.dtype,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg)
|
||||
if model_args.init_editflow_from_src:
|
||||
# Load src model config & weights (bf16 on CUDA) for intializing EditFlow model
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=model_args.dtype
|
||||
)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
def _no_flops(*args, **kwargs):
|
||||
return 0.0
|
||||
|
||||
model.floating_point_ops = _no_flops
|
||||
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = editflow.EditFlowTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=editflow.utils.EditFlowCollator(
|
||||
tokenizer=tokenizer, x0_sampler=training_args.x0_sampler
|
||||
),
|
||||
scheduler=dllm.core.schedulers.make_kappa_scheduler(
|
||||
training_args.scheduler_cls
|
||||
),
|
||||
normalize_per_position=training_args.normalize_per_position,
|
||||
max_w=training_args.max_w,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
192
dllm/examples/editflow/sft.py
Normal file
192
dllm/examples/editflow/sft.py
Normal file
@ -0,0 +1,192 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import editflow
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = None # overwrite this
|
||||
lm_head_key: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"The key to the `lm_head` in the source model for initializing operation heads in the EditFlow model. "
|
||||
"Overwrite this when `init_editflow_from_src` = True"
|
||||
)
|
||||
},
|
||||
)
|
||||
init_editflow_from_src: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to initialize EditFlow model from the source model."
|
||||
},
|
||||
)
|
||||
init_editflow_from_editflow: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "tatsu-lab/alpaca"
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = None # overwrite this
|
||||
per_device_train_batch_size: int = 2
|
||||
per_device_eval_batch_size: int = 2
|
||||
learning_rate: float = 5e-5
|
||||
# EditFlow specific args
|
||||
scheduler_cls: str = field(
|
||||
default="LinearKappaScheduler",
|
||||
metadata={
|
||||
"help": (
|
||||
"The scheduler class controlling κ(t). "
|
||||
"Available options: see `dllm/utils/schedulers/kappa.py`"
|
||||
)
|
||||
},
|
||||
)
|
||||
normalize_per_position: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to normalize the loss per position."},
|
||||
)
|
||||
max_w: float = field(
|
||||
default=20.0,
|
||||
metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."},
|
||||
)
|
||||
x0_sampler: str = field(
|
||||
default="masks[length:128]",
|
||||
metadata={
|
||||
"help": (
|
||||
"Choose the x0 sampler. "
|
||||
"Available options: see `dllm/pipelines/editflow/utils.py`"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
|
||||
# - `input_ids`` = prompt + response
|
||||
# - `prompt_len` marks the prompt span to EXCLUDE from loss.
|
||||
# (Remove prompt_len to train on all tokens—if so, ensure a BOS is prepended.)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"],
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
if mask_prompt_loss:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can insert to the left of the very first token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
prompt_response_tokens = [tokenizer.bos_token_id] + prompt_response_tokens
|
||||
return {"input_ids": prompt_response_tokens}
|
||||
|
||||
|
||||
def train(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: TrainingArguments,
|
||||
ef_config_cls: type[transformers.PretrainedConfig],
|
||||
):
|
||||
# necessary when batch does not contain "labels" field
|
||||
training_args.label_names = []
|
||||
# necessary when batch contains customized fields
|
||||
training_args.remove_unused_columns = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Load EditFlow Model ----------------------------------------------------
|
||||
if model_args.init_editflow_from_editflow:
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
else:
|
||||
ef_cfg = ef_config_cls.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
dtype=model_args.dtype,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(ef_cfg)
|
||||
if model_args.init_editflow_from_src:
|
||||
# Load src model config & weights (bf16 on CUDA) for intializing EditFlow model
|
||||
src_model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, dtype=model_args.dtype
|
||||
)
|
||||
# Initialize EditFlow model from the src model: copies backbone & clones lm_head
|
||||
editflow.utils.init_editflow_from_src(
|
||||
model, src_model, lm_head_key=model_args.lm_head_key
|
||||
)
|
||||
del src_model
|
||||
model = dllm.utils.load_peft(model, model_args)
|
||||
|
||||
def _no_flops(*args, **kwargs):
|
||||
return 0.0
|
||||
|
||||
model.floating_point_ops = _no_flops
|
||||
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = editflow.EditFlowTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=editflow.utils.EditFlowCollator(
|
||||
tokenizer=tokenizer, x0_sampler=training_args.x0_sampler
|
||||
),
|
||||
scheduler=dllm.core.schedulers.make_kappa_scheduler(
|
||||
training_args.scheduler_cls
|
||||
),
|
||||
normalize_per_position=training_args.normalize_per_position,
|
||||
max_w=training_args.max_w,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
489
dllm/examples/editflow/viz.py
Normal file
489
dllm/examples/editflow/viz.py
Normal file
@ -0,0 +1,489 @@
|
||||
# ------------------------------ Visualization (NEW) ------------------------------
|
||||
# Diffusion-style consecutive output: only show the CURRENT output per frame.
|
||||
# ------------------ Visualization (sanitized, masks stripped) ------------------
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Optional, List, Tuple, Annotated
|
||||
|
||||
|
||||
def render_consecutive_trace_gif(
|
||||
trace: dict,
|
||||
tokenizer,
|
||||
out_path: str = "decode_trace.gif",
|
||||
font_size: int = 30,
|
||||
line_spacing: int = 12,
|
||||
frame_ms: int = 250,
|
||||
final_ms: int = 5000, # final clean frame duration (ms)
|
||||
max_width: int = 1400,
|
||||
max_height: int = 3000,
|
||||
margin: int = 32,
|
||||
title_color=(80, 80, 80),
|
||||
text_color=(0, 0, 0), # base black
|
||||
mask_color=(150, 150, 150),
|
||||
sub_nonmask_color=(200, 0, 0), # persistent red
|
||||
ins_color=(0, 0, 200), # persistent blue
|
||||
del_strike_color=(120, 120, 120),
|
||||
events_color=(30, 30, 30),
|
||||
box_color=(120, 120, 120),
|
||||
bg_color=(255, 255, 255),
|
||||
):
|
||||
"""
|
||||
Persistent coloring keyed by token *instance* (not token id):
|
||||
- Inserted tokens -> BLUE across frames (until deleted/substituted again).
|
||||
- Substitution nonmask→nonmask -> RED across frames (until deleted/substituted again).
|
||||
- Substitution mask→nonmask -> stays BLACK (no extra color).
|
||||
Adds a final clean frame (5s) with no events box.
|
||||
"""
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import unicodedata
|
||||
|
||||
# ---------- font ----------
|
||||
try:
|
||||
font = ImageFont.truetype(
|
||||
"assets/JetBrainsMono-VariableFont_wght.ttf", font_size
|
||||
)
|
||||
except Exception:
|
||||
print(f"fail to load target font")
|
||||
font = ImageFont.load_default()
|
||||
|
||||
# ---------- helpers ----------
|
||||
def _sanitize_token(s: str) -> str:
|
||||
vis_mask_token = "[m]"
|
||||
s = unicodedata.normalize("NFKC", s)
|
||||
s = s.replace("Ċ", "\n").replace("▁", " ").replace("Ġ", " ")
|
||||
s = s.replace("\t", " ")
|
||||
s = s.replace("\u00a0", " ").replace("\u2007", " ").replace("\u202f", " ")
|
||||
|
||||
# replace mask variants
|
||||
if "mdm_mask" in s.lower():
|
||||
s = re.sub(r"<[\|]?\s*mdm_mask\s*[\|]?>", "[m]", s, flags=re.IGNORECASE)
|
||||
s = s.replace("mdm_mask", "[m]")
|
||||
if "mask" in s.lower():
|
||||
s = re.sub(r"<[\|]?\s*mask\s*[\|]?>", "[m]", s, flags=re.IGNORECASE)
|
||||
s = s.replace("mask", "[m]")
|
||||
|
||||
# replace <|...|> format tokens with bracketed form
|
||||
s = re.sub(r"<\|\s*(.*?)\s*\|>", r"[\1]", s)
|
||||
return s
|
||||
|
||||
def _tok_str(tok_id: int) -> str:
|
||||
try:
|
||||
s = tokenizer.decode([int(tok_id)], skip_special_tokens=False)
|
||||
s = s if s.strip() else f"<{int(tok_id)}>"
|
||||
except Exception:
|
||||
s = f"<{int(tok_id)}>"
|
||||
return s.replace("\n", "\\n")
|
||||
|
||||
TOKEN_RE = re.compile(r"\s+|\S+")
|
||||
|
||||
def _wrap_text(draw: ImageDraw.ImageDraw, text: str, width_px: int) -> List[str]:
|
||||
if text == "":
|
||||
return [""]
|
||||
lines: List[str] = []
|
||||
for para in text.split("\n"):
|
||||
tokens = TOKEN_RE.findall(para)
|
||||
cur = ""
|
||||
for tok in tokens:
|
||||
candidate = cur + tok
|
||||
if draw.textlength(candidate, font=font) <= width_px:
|
||||
cur = candidate
|
||||
else:
|
||||
if cur:
|
||||
lines.append(cur)
|
||||
cur = tok
|
||||
while (
|
||||
draw.textlength(cur, font=font) > width_px and len(cur) > 0
|
||||
):
|
||||
lo, hi, fit = 1, len(cur), 1
|
||||
while lo <= hi:
|
||||
mid = (lo + hi) // 2
|
||||
if draw.textlength(cur[:mid], font=font) <= width_px:
|
||||
fit, lo = mid, mid + 1
|
||||
else:
|
||||
hi = mid - 1
|
||||
lines.append(cur[:fit])
|
||||
cur = cur[fit:]
|
||||
else:
|
||||
t = tok
|
||||
while draw.textlength(t, font=font) > width_px and len(t) > 0:
|
||||
lo, hi, fit = 1, len(t), 1
|
||||
while lo <= hi:
|
||||
mid = (lo + hi) // 2
|
||||
if draw.textlength(t[:mid], font=font) <= width_px:
|
||||
fit, lo = mid, mid + 1
|
||||
else:
|
||||
hi = mid - 1
|
||||
lines.append(t[:fit])
|
||||
t = t[fit:]
|
||||
cur = t
|
||||
lines.append(cur)
|
||||
return lines or [""]
|
||||
|
||||
tmp_img = Image.new("RGB", (10, 10), bg_color)
|
||||
tmp_draw = ImageDraw.Draw(tmp_img)
|
||||
text_width_budget = max_width - 2 * margin
|
||||
|
||||
# mask detection
|
||||
MASK_IDS = set()
|
||||
if getattr(tokenizer, "mask_token_id", None) is not None:
|
||||
MASK_IDS.add(int(tokenizer.mask_token_id))
|
||||
MASK_STRINGS = set()
|
||||
mt = getattr(tokenizer, "mask_token", None)
|
||||
if mt is not None:
|
||||
MASK_STRINGS.add(str(mt))
|
||||
MASK_STRINGS.add("<mdm_mask>")
|
||||
|
||||
def _is_mask_token(tok_id: int, tok_str_exact: str) -> bool:
|
||||
return (int(tok_id) in MASK_IDS) or (tok_str_exact in MASK_STRINGS)
|
||||
|
||||
def _wrap_tokens_with_index(tokens, deleted_flags):
|
||||
lines, cur, cur_w = [], [], 0
|
||||
for i, tok in enumerate(tokens):
|
||||
t = _sanitize_token(tok)
|
||||
parts = t.split("\n")
|
||||
for j, seg in enumerate(parts):
|
||||
seg_rest = seg
|
||||
while seg_rest:
|
||||
w = tmp_draw.textlength(seg_rest, font=font)
|
||||
if cur_w + w <= text_width_budget or not cur:
|
||||
cur.append((seg_rest, i, deleted_flags[i]))
|
||||
cur_w += w
|
||||
seg_rest = ""
|
||||
else:
|
||||
lines.append(cur)
|
||||
cur, cur_w = [], 0
|
||||
if j != len(parts) - 1:
|
||||
lines.append(cur)
|
||||
cur, cur_w = [], 0
|
||||
if cur:
|
||||
lines.append(cur)
|
||||
return lines
|
||||
|
||||
def _draw_dashed_rectangle(
|
||||
draw, xy, dash=8, gap=6, width=2, outline=(120, 120, 120)
|
||||
):
|
||||
x0, y0, x1, y1 = xy
|
||||
x = x0
|
||||
while x < x1:
|
||||
x2 = min(x + dash, x1)
|
||||
draw.line([(x, y0), (x2, y0)], fill=outline, width=width)
|
||||
draw.line([(x, y1), (x2, y1)], fill=outline, width=width)
|
||||
x += dash + gap
|
||||
y = y0
|
||||
while y < y1:
|
||||
y2 = min(y + dash, y1)
|
||||
draw.line([(x0, y), (x0, y2)], fill=outline, width=width)
|
||||
draw.line([(x1, y), (x1, y2)], fill=outline, width=width)
|
||||
y += dash + gap
|
||||
|
||||
def _ops_lines_for_step(st: dict):
|
||||
if st is None:
|
||||
return ["(no events)"]
|
||||
lines = []
|
||||
x_before = st["x_before_ids"]
|
||||
choose_del = st["choose_del"]
|
||||
choose_sub = st["choose_sub"]
|
||||
sub_samples = st["sub_samples"]
|
||||
ins_samples = st["ins_samples"]
|
||||
T = len(x_before)
|
||||
for i in range(T):
|
||||
if choose_del[i]:
|
||||
lines.append(f"DEL@{i}:{_tok_str(int(x_before[i]))}")
|
||||
elif choose_sub[i]:
|
||||
lines.append(
|
||||
f"SUB@{i}:{_tok_str(int(x_before[i]))}->{_tok_str(int(sub_samples[i]))}"
|
||||
)
|
||||
if ins_samples[i] is not None:
|
||||
lines.append(f"INS@{i}->{i+1}:{_tok_str(int(ins_samples[i]))}")
|
||||
if not lines:
|
||||
lines.append("(no events)")
|
||||
return lines
|
||||
|
||||
# ---- Instance-id machinery ----
|
||||
next_instance_id = 0
|
||||
|
||||
def _new_inst():
|
||||
nonlocal next_instance_id
|
||||
val = next_instance_id
|
||||
next_instance_id += 1
|
||||
return val
|
||||
|
||||
# Current sequence at the *start* (ids + instance_ids)
|
||||
curr_ids = list(trace["init"]["x_ids"])
|
||||
curr_inst = [_new_inst() for _ in curr_ids]
|
||||
|
||||
# Persistent color by instance_id: {"blue", "red"}
|
||||
color_by_inst = {}
|
||||
|
||||
# ---------- PASS 1: measure required heights per frame ----------
|
||||
measurement_payload = []
|
||||
|
||||
for step_idx, st in enumerate([None] + trace["steps"]):
|
||||
# build augmented view
|
||||
if st is None:
|
||||
aug_ids = list(curr_ids)
|
||||
deleted_flags = [False] * len(aug_ids)
|
||||
else:
|
||||
x_before = st["x_before_ids"]
|
||||
choose_del = st["choose_del"]
|
||||
after_ids = st["x_after_ids"]
|
||||
deleted_positions = [i for i, d in enumerate(choose_del) if d]
|
||||
|
||||
aug_ids = list(after_ids)
|
||||
deleted_flags = [False] * len(after_ids)
|
||||
for i in sorted(deleted_positions, reverse=True):
|
||||
aug_ids.insert(i, x_before[i])
|
||||
deleted_flags.insert(i, True)
|
||||
|
||||
tokens = tokenizer.convert_ids_to_tokens(aug_ids)
|
||||
wrapped_lines = _wrap_tokens_with_index(tokens, deleted_flags)
|
||||
|
||||
# estimate ops lines for this step
|
||||
if st:
|
||||
ops_text = " • " + " • ".join(_ops_lines_for_step(st))
|
||||
else:
|
||||
ops_text = "(no events)"
|
||||
ops_lines = _wrap_text(tmp_draw, ops_text, text_width_budget)
|
||||
|
||||
# compute height needed
|
||||
body_h = len(wrapped_lines) * (font_size + line_spacing)
|
||||
ops_h = len(ops_lines) * (font_size + line_spacing) + font_size # + 20
|
||||
required_h = margin + (font_size + line_spacing) + body_h + 20
|
||||
|
||||
measurement_payload.append(
|
||||
{
|
||||
"step_idx": step_idx,
|
||||
"st": st,
|
||||
"aug_ids": aug_ids,
|
||||
"tokens": tokens,
|
||||
"deleted_flags": deleted_flags,
|
||||
"wrapped_lines": wrapped_lines,
|
||||
"ops_lines": ops_lines,
|
||||
"required_h": required_h,
|
||||
}
|
||||
)
|
||||
|
||||
# Measure clean final frame (no events)
|
||||
final_text_ids = (
|
||||
trace["steps"][-1]["x_after_ids"] if trace["steps"] else trace["init"]["x_ids"]
|
||||
)
|
||||
final_tokens = tokenizer.convert_ids_to_tokens(final_text_ids)
|
||||
wrapped_clean = _wrap_tokens_with_index(final_tokens, [False] * len(final_tokens))
|
||||
clean_body_h = len(wrapped_clean) * (font_size + line_spacing)
|
||||
clean_required_h = margin + (font_size + line_spacing) + clean_body_h
|
||||
|
||||
# Pick a single uniform canvas height
|
||||
max_required_h = max(
|
||||
[p["required_h"] for p in measurement_payload] + [clean_required_h]
|
||||
) # + 20
|
||||
H = min(max_required_h, max_height)
|
||||
W = max_width
|
||||
|
||||
# For each frame we need an augmented view (with deleted placeholders) to draw
|
||||
frames = []
|
||||
|
||||
# Iterate steps; for step_idx==0 we still draw "initial state"
|
||||
steps_with_initial = [None] + trace["steps"]
|
||||
|
||||
for step_idx, st in enumerate(steps_with_initial):
|
||||
if st is None:
|
||||
# initial frame: augmented is just current tokens
|
||||
aug_ids = list(curr_ids)
|
||||
aug_inst = list(curr_inst)
|
||||
aug_deleted = [False] * len(aug_ids)
|
||||
ops_lines = ["(no events)"]
|
||||
title = "initial state"
|
||||
else:
|
||||
title = f"t = {st['t']:.3f}"
|
||||
x_before = list(st["x_before_ids"])
|
||||
choose_del = list(st["choose_del"])
|
||||
choose_sub = list(st["choose_sub"])
|
||||
sub_samples = list(st["sub_samples"])
|
||||
ins_samples = list(st["ins_samples"])
|
||||
assert (
|
||||
len(x_before) == len(curr_ids) == len(curr_inst)
|
||||
), "trace 'x_before' must match current sequence."
|
||||
|
||||
# Build augmented (drawn) and next (state-after) in one pass
|
||||
aug_ids, aug_inst, aug_deleted = [], [], []
|
||||
next_ids, next_inst = [], []
|
||||
|
||||
for i in range(len(x_before)):
|
||||
before_id = int(curr_ids[i])
|
||||
before_inst = curr_inst[i]
|
||||
|
||||
if choose_del[i]:
|
||||
# show deleted placeholder (strike-through)
|
||||
aug_ids.append(before_id)
|
||||
aug_inst.append(None)
|
||||
aug_deleted.append(True)
|
||||
# remove from next; also clear any persistent color
|
||||
color_by_inst.pop(before_inst, None)
|
||||
else:
|
||||
if choose_sub[i]:
|
||||
after_id = int(sub_samples[i])
|
||||
# in augmented we show the *after* token at same instance
|
||||
aug_ids.append(after_id)
|
||||
aug_inst.append(before_inst)
|
||||
aug_deleted.append(False)
|
||||
next_ids.append(after_id)
|
||||
next_inst.append(before_inst)
|
||||
|
||||
# update persistence by source type
|
||||
if int(before_id) in MASK_IDS:
|
||||
# mask → nonmask: no extra color (ensure cleared)
|
||||
color_by_inst.pop(before_inst, None)
|
||||
else:
|
||||
# nonmask → nonmask: mark RED
|
||||
color_by_inst[before_inst] = "red"
|
||||
else:
|
||||
# keep
|
||||
aug_ids.append(before_id)
|
||||
aug_inst.append(before_inst)
|
||||
aug_deleted.append(False)
|
||||
next_ids.append(before_id)
|
||||
next_inst.append(before_inst)
|
||||
|
||||
# insertion AFTER position i
|
||||
if ins_samples[i] is not None:
|
||||
ins_id = int(ins_samples[i])
|
||||
ins_inst = _new_inst()
|
||||
aug_ids.append(ins_id)
|
||||
aug_inst.append(ins_inst)
|
||||
aug_deleted.append(False)
|
||||
next_ids.append(ins_id)
|
||||
next_inst.append(ins_inst)
|
||||
# mark persistent BLUE for this *instance only*
|
||||
color_by_inst[ins_inst] = "blue"
|
||||
|
||||
# commit next state
|
||||
curr_ids, curr_inst = next_ids, next_inst
|
||||
ops_text = " • " + " • ".join(_ops_lines_for_step(st))
|
||||
ops_lines = _wrap_text(tmp_draw, ops_text, text_width_budget)
|
||||
|
||||
# ----- render this frame -----
|
||||
tokens = tokenizer.convert_ids_to_tokens(aug_ids)
|
||||
wrapped_lines = _wrap_tokens_with_index(tokens, aug_deleted)
|
||||
|
||||
img = Image.new("RGB", (W, H), bg_color)
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
y = margin
|
||||
draw.text((margin, y), title, fill=title_color, font=font)
|
||||
y += font_size + line_spacing
|
||||
|
||||
for line in wrapped_lines:
|
||||
x = margin
|
||||
for seg_text, tok_idx, is_deleted in line:
|
||||
tok_id = int(aug_ids[tok_idx])
|
||||
tok_str_exact = tokens[tok_idx]
|
||||
inst = aug_inst[tok_idx]
|
||||
|
||||
if is_deleted:
|
||||
# strike deleted — grey masks slightly different if desired
|
||||
strike_color = (
|
||||
mask_color
|
||||
if _is_mask_token(tok_id, tok_str_exact)
|
||||
else del_strike_color
|
||||
)
|
||||
strike = "".join(ch + "\u0336" for ch in seg_text)
|
||||
draw.text((x, y), strike, fill=strike_color, font=font)
|
||||
x += tmp_draw.textlength(strike, font=font)
|
||||
else:
|
||||
# choose color by *instance*
|
||||
color = text_color
|
||||
if inst is not None and inst in color_by_inst:
|
||||
color = (
|
||||
ins_color
|
||||
if color_by_inst[inst] == "blue"
|
||||
else sub_nonmask_color
|
||||
)
|
||||
elif _is_mask_token(tok_id, tok_str_exact):
|
||||
color = mask_color
|
||||
draw.text((x, y), seg_text, fill=color, font=font)
|
||||
x += tmp_draw.textlength(seg_text, font=font)
|
||||
y += font_size + line_spacing
|
||||
|
||||
# draw events box for all but the extra final-clean frame we'll add later
|
||||
# if step_idx != len(steps_with_initial) - 1:
|
||||
# y += 20
|
||||
# x0, y0 = margin, y
|
||||
# x1 = max_width - margin
|
||||
# box_h = len(ops_lines) * (font_size + line_spacing) + font_size + 20
|
||||
# y1 = y0 + box_h
|
||||
# _draw_dashed_rectangle(draw, (x0, y0, x1, y1), outline=box_color)
|
||||
# draw.text((x0 + 10, y0 + 10), "events", fill=events_color, font=font)
|
||||
# yy = y0 + font_size + 20
|
||||
# for l in ops_lines:
|
||||
# draw.text((x0 + 10, yy), l, fill=events_color, font=font)
|
||||
# yy += font_size + line_spacing
|
||||
# y += 10
|
||||
frames.append(img)
|
||||
|
||||
# ----- extra final clean frame (no events box), 5s -----
|
||||
final_ids = list(curr_ids)
|
||||
final_inst = list(curr_inst)
|
||||
final_tokens = tokenizer.convert_ids_to_tokens(final_ids)
|
||||
|
||||
# wrap without deleted flags
|
||||
def _wrap_clean(tokens):
|
||||
lines, cur, cur_w = [], [], 0
|
||||
for i, tok in enumerate(tokens):
|
||||
t = _sanitize_token(tok)
|
||||
parts = t.split("\n")
|
||||
for j, seg in enumerate(parts):
|
||||
seg_rest = seg
|
||||
while seg_rest:
|
||||
w = tmp_draw.textlength(seg_rest, font=font)
|
||||
if cur_w + w <= text_width_budget or not cur:
|
||||
cur.append((seg_rest, i))
|
||||
cur_w += w
|
||||
seg_rest = ""
|
||||
else:
|
||||
lines.append(cur)
|
||||
cur, cur_w = [], 0
|
||||
if j != len(parts) - 1:
|
||||
lines.append(cur)
|
||||
cur, cur_w = [], 0
|
||||
if cur:
|
||||
lines.append(cur)
|
||||
return lines
|
||||
|
||||
wrapped_clean = _wrap_clean(final_tokens)
|
||||
|
||||
clean_img = Image.new("RGB", (W, H), bg_color)
|
||||
draw = ImageDraw.Draw(clean_img)
|
||||
draw.text((margin, margin), "final text", fill=title_color, font=font)
|
||||
y = margin + font_size + line_spacing
|
||||
for line in wrapped_clean:
|
||||
x = margin
|
||||
for seg_text, tok_idx in line:
|
||||
tok_id = int(final_ids[tok_idx])
|
||||
tok_str_exact = final_tokens[tok_idx]
|
||||
inst = final_inst[tok_idx]
|
||||
color = text_color
|
||||
if inst in color_by_inst:
|
||||
color = (
|
||||
ins_color if color_by_inst[inst] == "blue" else sub_nonmask_color
|
||||
)
|
||||
elif _is_mask_token(tok_id, tok_str_exact):
|
||||
color = mask_color
|
||||
draw.text((x, y), seg_text, fill=color, font=font)
|
||||
x += tmp_draw.textlength(seg_text, font=font)
|
||||
y += font_size + line_spacing
|
||||
frames.append(clean_img)
|
||||
|
||||
# save GIF
|
||||
durations = [frame_ms] * (len(frames) - 1) + [final_ms]
|
||||
frames[0].save(
|
||||
out_path,
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
duration=durations,
|
||||
loop=0,
|
||||
disposal=2,
|
||||
optimize=True,
|
||||
)
|
||||
return out_path
|
||||
206
dllm/examples/llada/README.md
Normal file
206
dllm/examples/llada/README.md
Normal file
@ -0,0 +1,206 @@
|
||||
# LLaDA
|
||||
|
||||
> 📄 Paper: [Large Language Diffusion Models](https://arxiv.org/abs/2502.09992) | 💻 Code: [github.com/ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA)
|
||||
|
||||
Resources and examples for training (finetuning & pretraining) and evaluating diffusion language models **LLaDA**.
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Inference](#inference)
|
||||
- [Evaluation](#evaluation)
|
||||
|
||||
## Setup
|
||||
> [!IMPORTANT]
|
||||
> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logps`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
|
||||
>
|
||||
> **MoE checkpoints:** For models like [`LLaDA-MoE-7B-A1B-Base`](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base), set `"model_type"` to `"lladamoe"` in the checkpoint’s `config.json`:
|
||||
> ```diff
|
||||
> - "model_type": "llada",
|
||||
> + "model_type": "lladamoe",
|
||||
> ```
|
||||
>
|
||||
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# tools relevant with LLaDA
|
||||
dllm/pipelines/llada
|
||||
├── __init__.py # Package initialization
|
||||
├── models/
|
||||
│ ├── configuration_lladamoe.py # LLaDA-MoE model configuration
|
||||
│ ├── configuration_llada.py # LLaDA model configuration
|
||||
│ ├── modeling_lladamoe.py # LLaDA-MoE model architecture
|
||||
│ └── modeling_llada.py # LLaDA model architecture
|
||||
├── generator.py # Inference logic
|
||||
└── trainer.py # Training logic (pretraining and finetuning)
|
||||
|
||||
# example entry points for training / inference / evaluation
|
||||
examples/llada
|
||||
├── chat.py # Interactive inference example
|
||||
├── eval.sh # Automatic evaluation script
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
└── sft.py # Supervised finetuning example
|
||||
```
|
||||
<!-- > [!NOTE] -->
|
||||
<!-- > - We fixed attention mask bugs in [`modeling_lladamoe.py`](/dllm/pipelines/llada/models/modeling_lladamoe.py) and [`modeling_llada.py`](/dllm/pipelines/llada/models/modeling_llada.py). We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
|
||||
>
|
||||
> - We fixed bugs in `chat_template` and assign `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
|
||||
|
||||
<!-- > [!WARNING]
|
||||
> Before loading MoE checkpoints (e.g., [inclusionAI/LLaDA-MoE-7B-A1B-Base](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base)), first overwrite the `model_type` field from `inclusionAI/LLaDA-MoE-7B-A1B-Base/config.json`:
|
||||
> ```diff
|
||||
> - "model_type": "llada",
|
||||
> + "model_type": "lladamoe",
|
||||
> ``` -->
|
||||
|
||||
## Training
|
||||
### Finetuning
|
||||
|
||||
For example, to SFT [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) for instruction following on 8 GPUs, run:
|
||||
```shell
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/llada/sft.py \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
If you are using slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
|
||||
```shell
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
|
||||
<!-- **Reproducing [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)**. Though LLaDA is trained on proprietary data, we tried our best to reproduce LLaDA-8B-Instruct by finetuning LLaDA-8B-Base using our training pipeline on public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
|
||||
|
||||
#### Reproducing [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)
|
||||
Though LLaDA is trained on proprietary data, we tried our best to reproduce [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) by finetuning [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) using our training pipeline on public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
|
||||
|
||||
```shell
|
||||
# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
|
||||
python dllm/tools/preprocess_sft_dataset.py \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--sft_map_fn_path "dllm.utils.default_sft_map_fn" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "data/sft/llada/tulu-3-sft-mixture" \
|
||||
--num_proc 64
|
||||
|
||||
# train on 24*8=192 A100s with FSDP, take about 8 hours
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "data/sft/llada/tulu-3-sft-mixture" \
|
||||
--load_preprocessed_data True \
|
||||
--output_dir "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
|
||||
--max_length 2048 \
|
||||
--truncation "right" \
|
||||
--group_by_length True \
|
||||
--num_train_epochs 5 \
|
||||
--learning_rate 1e-5 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_eval_batch_size 4 \
|
||||
--eval_on_start False \
|
||||
--eval_steps 0.1 \
|
||||
--save_steps 0.05
|
||||
```
|
||||
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
|
||||
|
||||
|
||||
### Pretraining
|
||||
|
||||
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
|
||||
```shell
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/pt.py" \
|
||||
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
|
||||
--dataset_args "mlfoundations/dclm-baseline-1.0" \
|
||||
--output_dir "models/LLaDA-8B-PT/dclm-baseline-1.0" \
|
||||
--max_length 1024 \
|
||||
--max_steps 2000 \
|
||||
--learning_rate 3e-4
|
||||
```
|
||||
|
||||
## Inference
|
||||
We support batch inference for standard generation and infilling:
|
||||
<!-- See [`examples/llada/generate.py`](/examples/llada/generate.py) for a full example: -->
|
||||
```shell
|
||||
python examples/llada/generate.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
|
||||
```
|
||||
We also support interactive multi-turn dialogue with visualization:
|
||||
<!-- See [`examples/llada/chat.py`](/examples/llada/chat.py) for a full example. -->
|
||||
```shell
|
||||
python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [MMLU-Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
|
||||
```shell
|
||||
# Use model_args to adjust the generation arguments for evalution.
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/llada/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "llada" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
|
||||
```
|
||||
|
||||
To automatically evaluate [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) and [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on all benchmarks, run:
|
||||
```shell
|
||||
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True
|
||||
bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False
|
||||
```
|
||||
|
||||
### Evaluation results
|
||||
|
||||
<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results refer to those from the original paper.
|
||||
> All evaluation parameters follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository.
|
||||
> Since the original evaluations were conducted with OpenCompass, task settings were adjusted for compatibility with the LLaDA model under `lm-eval`.
|
||||
> Complete evaluation results will be released soon. -->
|
||||
|
||||
> Results (evaluated) are evaluated using our framework, while results (reported) come from the original paper. All evaluation settings follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.
|
||||
|
||||
<!-- <div align="center" style="min-width:1300px;"> -->
|
||||
|
||||
| | MMLU | BBH | ARC‑C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU |
|
||||
|:----------------|:----:|:---:|:-----:|:-----------:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:------:|
|
||||
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base)(reported)| 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 |
|
||||
| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base)(evaluated)| 65.8 | – | 45.7 | 69.3 | 45.6 | 70.7 | 70.6 | 70.4 | – | – | 32.3 | 38.8 | 70.2 | 69.9 |
|
||||
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 1. Evaluation results of
|
||||
<a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Base" style="color: #808080; text-decoration: none;">
|
||||
<code>LLaDA-8B-Base</code>
|
||||
</a>.
|
||||
</p>
|
||||
|
||||
| | MMLU | MMLU‑Pro | ARC‑C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP |
|
||||
|:----------------|:----:|:---------:|:-----:|:-----------:|:-----:|:----:|:----:|:-----------:|:----:|
|
||||
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)(reported) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 |
|
||||
| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)(evaluated) | 67.3 | 36.2 | 86.6 | 76.7 | 81.1 | – | – | 65.0 | 70.2 |
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 2. Evaluation results of
|
||||
<a href="https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct" style="color: #808080; text-decoration: none;">
|
||||
<code>LLaDA-8B-Instruct</code>
|
||||
</a>.
|
||||
</p>
|
||||
74
dllm/examples/llada/chat.py
Normal file
74
dllm/examples/llada/chat.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""
|
||||
Interactive chat / generation script for LLaDA models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Chat mode (multi-turn, chat template)
|
||||
python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
|
||||
# Raw single-turn generation
|
||||
python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import llada
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||
151
dllm/examples/llada/eval.sh
Normal file
151
dllm/examples/llada/eval.sh
Normal file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env bash
|
||||
# ===== Mandatory for proper import and evaluation =====
|
||||
export PYTHONPATH=.:$PYTHONPATH
|
||||
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
|
||||
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
|
||||
|
||||
|
||||
# ===== Optional but recommended for stability and debugging =====
|
||||
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
|
||||
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
|
||||
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
|
||||
|
||||
|
||||
# ===== Input Arguments =====
|
||||
model_name_or_path="GSAI-ML/LLaDA-8B-Instruct"
|
||||
instruct=True
|
||||
num_gpu=4
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--model_name_or_path)
|
||||
model_name_or_path="$2"; shift 2 ;;
|
||||
--instruct)
|
||||
instruct="$2"; shift 2 ;;
|
||||
--num_gpu)
|
||||
num_gpu="$2"; shift 2 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ===== Conditional Configurations =====
|
||||
if [ "$instruct" = "True" ]; then
|
||||
echo ">>> Running in INSTRUCT mode"
|
||||
common_args="--model llada --apply_chat_template"
|
||||
else
|
||||
echo ">>> Running in BASE mode"
|
||||
common_args="--model llada"
|
||||
fi
|
||||
|
||||
|
||||
# =======================
|
||||
# Generation Tasks
|
||||
# =======================
|
||||
|
||||
if [ "$instruct" = "True" ]; then
|
||||
# Instruct Generation Tasks
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks gsm8k_cot --num_fewshot 8 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks bbh --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks minerva_math --num_fewshot 4 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks humaneval_instruct --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks mbpp_llada_instruct --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
else
|
||||
# Base Generation Tasks
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks gsm8k --num_fewshot 8 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks bbh --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks minerva_math --num_fewshot 4 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks humaneval --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks mbpp --num_fewshot 3 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
fi
|
||||
|
||||
|
||||
# =======================
|
||||
# Likelihood Tasks
|
||||
# =======================
|
||||
|
||||
if [ "$instruct" = "True" ]; then
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks mmlu_generative --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks mmlu_pro --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks hellaswag_gen --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=5,steps=5,block_length=5,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks gpqa_n_shot_gen --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=32,steps=32,block_length=32,cfg=0.0"
|
||||
|
||||
else
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks truthfulqa_mc2 --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=2.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks arc_challenge --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks hellaswag --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks winogrande --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks piqa --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks mmlu --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks cmmlu --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \
|
||||
--tasks ceval-valid --num_fewshot 5 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0"
|
||||
fi
|
||||
116
dllm/examples/llada/generate.py
Normal file
116
dllm/examples/llada/generate.py
Normal file
@ -0,0 +1,116 @@
|
||||
"""
|
||||
python -u examples/llada/generate.py --model_name_or_path "YOUR_MODEL_PATH"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.tools.chat import decode_trim
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
seed: int = 42
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
# Load model & tokenizer
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# --- Example 1: Batch generation ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: llada.generate()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for iter, s in enumerate(sequences):
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print(s.strip() if s.strip() else "<empty>")
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
|
||||
# --- Example 2: Batch fill-in-the-blanks ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: llada.infilling()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
masked_messages = [
|
||||
[
|
||||
{"role": "user", "content": tokenizer.mask_token * 20},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Sorry, I do not have answer to this question.",
|
||||
},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": "Please write an educational python function."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "def hello_" + tokenizer.mask_token * 20 + " return",
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
masked_messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for idx, (i, s) in enumerate(zip(inputs, sequences)):
print("\n" + "-" * 80)
print(f"[Case {idx}]")
|
||||
print("-" * 80)
|
||||
print("[Masked]:\n" + tokenizer.decode(i))
|
||||
print("\n[Filled]:\n" + (s.strip() if s.strip() else "<empty>"))
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
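`GeneratorConfig` above defaults to `remasking="low_confidence"`. In LLaDA-style samplers this means that at each denoising step the model fills every masked position, keeps only the most confident predictions, and re-masks the rest for later steps. The function below is a hedged, self-contained sketch of one such step (not the `LLaDAGenerator` internals); it assumes `mask_id` marks masked positions and that `n_keep` does not exceed the number of currently masked tokens.

```python
import torch

def low_confidence_remask_step(logits: torch.Tensor, x: torch.Tensor, mask_id: int, n_keep: int) -> torch.Tensor:
    # logits: (seq_len, vocab_size); x: (seq_len,) token ids, with mask_id at undecoded positions.
    probs = torch.softmax(logits, dim=-1)
    confidence, prediction = probs.max(dim=-1)       # best candidate token and its probability per position
    confidence[x != mask_id] = -1.0                  # never re-decide positions that are already fixed
    keep = torch.topk(confidence, k=n_keep).indices  # unmask the n_keep most confident positions
    x = x.clone()
    x[keep] = prediction[keep]
    return x
```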
174 dllm/examples/llada/pt.py Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/llada/pt.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/llada/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 24 Nodes, 192 GPUs (FSDP):
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/pt.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
# Uses only the configuration from model_name_or_path to initialize the model from scratch
|
||||
model_name_or_path: str = (
|
||||
"GSAI-ML/LLaDA-8B-Base" # "inclusionAI/LLaDA-MoE-7B-A1B-Base"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
text_field: str = "text"
|
||||
streaming: bool = True
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
random_length_ratio: float = field(
|
||||
default=0.01,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cut sequences during training. "
|
||||
"See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training for reference."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = (
|
||||
"models/LLaDA-8B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
|
||||
)
|
||||
learning_rate: float = 3e-4
|
||||
max_steps: int = 2_000
|
||||
per_device_train_batch_size: int = 4
|
||||
gradient_accumulation_steps: int = 4
|
||||
eval_steps: float = 0.05
|
||||
save_steps: float = 0.05
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary for streaming dataset
|
||||
if data_args.streaming:
|
||||
training_args.accelerator_config.dispatch_batches = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
# initialize model weights from scratch
|
||||
config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with dllm.utils.init_device_context_manager():
|
||||
model = transformers.AutoModel.from_config(
|
||||
config, dtype=torch.bfloat16, init_params=True
|
||||
)
|
||||
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optional PEFT: LoRA ----------------------------------------------------
|
||||
model = dllm.utils.load_peft(model=model, model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
@dataclass
|
||||
class LLaDAPTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
# Reference: https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training
|
||||
# By default, 1% of the pre-training data are truncated to a random length
|
||||
random_length_ratio: float = 0.01
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
if torch.rand(1) < self.random_length_ratio:
|
||||
random_length = torch.randint(
|
||||
1, outputs["input_ids"].shape[1] + 1, (1,)
|
||||
)
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in outputs:
|
||||
outputs[key] = outputs[key][:, :random_length]
|
||||
# If the attention_mask is all ones, drop it so the batch is treated as unpadded
if torch.all(outputs["attention_mask"] == 1):
outputs.pop("attention_mask")
|
||||
return outputs
|
||||
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=LLaDAPTCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
random_length_ratio=data_args.random_length_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
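The pretraining script maps the corpus through `dllm.utils.tokenize_and_group` with `seq_length`, `insert_eos`, and `drop_tail`. As a rough illustration of that kind of packing (a sketch under the assumption of simple concatenate-and-chunk behaviour, not the library's exact implementation):

```python
def pack_into_blocks(texts, tokenizer, seq_length, insert_eos=True, drop_tail=True):
    # Tokenize documents, optionally separate them with EOS, concatenate, and cut fixed-size blocks.
    ids = []
    for text in texts:
        ids.extend(tokenizer(text, add_special_tokens=False)["input_ids"])
        if insert_eos and tokenizer.eos_token_id is not None:
            ids.append(tokenizer.eos_token_id)  # mark the document boundary
    blocks = [ids[i : i + seq_length] for i in range(0, len(ids), seq_length)]
    if drop_tail and blocks and len(blocks[-1]) < seq_length:
        blocks.pop()  # discard the incomplete final block
    return {"input_ids": blocks}
```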
120 dllm/examples/llada/sft.py Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU (4bit quant & LoRA, useful for testing):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/llada/sft.py \
|
||||
--load_in_4bit True --lora True
|
||||
|
||||
- 8 GPUs (FSDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/llada/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (FSDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (FSDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/LLaDA-8B-SFT/tulu-3-sft-mixture[train:10000,test:1000]"
|
||||
group_by_length: bool = True
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
dllm.utils.default_sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=dllm.utils.NoAttentionMaskCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
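Note that the SFT collator above sets `label_pad_token_id=tokenizer.pad_token_id` instead of the usual `-100`, so padded positions remain in the loss and the model also learns to emit the padding/`<eos>` token after its answer. A minimal sketch of that padding convention (illustrative only, assuming right-padding; not the `NoAttentionMaskCollator` source):

```python
def pad_with_token(features, pad_id):
    # Pad both input_ids and labels with the pad token itself rather than -100,
    # so the padded tail is still supervised.
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "labels": []}
    for f in features:
        pad = [pad_id] * (max_len - len(f["input_ids"]))
        batch["input_ids"].append(f["input_ids"] + pad)
        batch["labels"].append(f["labels"] + pad)
    return batch
```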
1 dllm/examples/rnd/README.md Normal file
@@ -0,0 +1 @@
|
||||
WIP
|
||||
114 dllm/examples/rnd/preprocess.py Normal file
@@ -0,0 +1,114 @@
|
||||
# """
|
||||
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --cpus-per-task=12 --time=03:00:000
|
||||
|
||||
# python examples/rnd/preprocess.py --dataset_args "HuggingFaceTB/smoltalk" --output_dir "data/sft_proc/rnd/smoltalk"
|
||||
# """
|
||||
# import os
|
||||
# from dataclasses import dataclass
|
||||
# from typing import Dict, Any
|
||||
|
||||
# import datasets
|
||||
# import transformers
|
||||
# import accelerate
|
||||
# import tyro
|
||||
|
||||
# import dllm
|
||||
|
||||
|
||||
# # --- tyro: define dataclass for CLI args ---
|
||||
# @dataclass
|
||||
# class ScriptArguments:
|
||||
# """Preprocess SFT dataset (batch_size=1 only)"""
|
||||
# model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
|
||||
# dataset_args: str = "HuggingFaceTB/smoltalk" # required
|
||||
# output_dir: str = "data/sft_proc/rnd/smoltalk" # required
|
||||
# mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
|
||||
# # TODO: strip_cols
|
||||
|
||||
# def __post_init__(self):
|
||||
# self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
# self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
# )
|
||||
|
||||
|
||||
# def dataset_offline_preprocess(dataset: datasets.DatasetDict, map_fn: callable, output_dir: str):
|
||||
# # Map with batch_size=1 and num_proc=1 (no batching, single process).
|
||||
# state = accelerate.PartialState()
|
||||
# with state.local_main_process_first():
|
||||
# processed = dataset.map(
|
||||
# map_fn,
|
||||
# batched=False,
|
||||
# num_proc=16,
|
||||
# load_from_cache_file=True,
|
||||
# writer_batch_size=512,
|
||||
# desc="offline preprocessing",
|
||||
# )
|
||||
|
||||
# # # Keep only the three required columns to save space.
|
||||
# # keep = {"input_ids", "labels", "prompt_len"}
|
||||
# # def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
||||
# # drop = [c for c in ds.column_names if c not in keep]
|
||||
# # return ds.remove_columns(drop) if drop else ds
|
||||
|
||||
# # if isinstance(processed, datasets.DatasetDict):
|
||||
# # for split in list(processed.keys()):
|
||||
# # processed[split] = strip_cols(processed[split])
|
||||
# # else:
|
||||
# # processed = strip_cols(processed)
|
||||
|
||||
# os.makedirs(output_dir, exist_ok=True)
|
||||
# processed.save_to_disk(output_dir)
|
||||
# print(f"[OK] Saved to: {output_dir}")
|
||||
|
||||
|
||||
# def main():
|
||||
# # Parse with tyro
|
||||
# args = tyro.cli(ScriptArguments)
|
||||
|
||||
# # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
# tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# # Load your raw dataset (must contain a "messages" field per example).
|
||||
# dataset = dllm.data.load_sft_dataset(args.dataset_args)
|
||||
|
||||
# dataset_offline_preprocess(dataset=dataset, map_fn=None, output_dir=args.output_dir)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
|
||||
from functools import partial
|
||||
import tyro
|
||||
|
||||
import dllm
|
||||
from dllm.tools.preprocess_sft_dataset import ScriptArguments, preprocess_sft_dataset
|
||||
|
||||
|
||||
def main():
|
||||
from examples.rnd.sft import sft_map_fn
|
||||
|
||||
# Parse with tyro
|
||||
args = tyro.cli(ScriptArguments)
|
||||
|
||||
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# Load your raw dataset (must contain a "messages" field per example).
|
||||
dataset = dllm.data.load_sft_dataset(args.dataset_args)
|
||||
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=args.mask_prompt_loss,
|
||||
)
|
||||
preprocess_sft_dataset(
|
||||
dataset=dataset,
|
||||
map_fn=map_fn,
|
||||
output_dir=args.output_dir,
|
||||
remove_columns=args.remove_columns,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
199 dllm/examples/rnd/sft.py Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/rnd/sft.py
|
||||
|
||||
- 8 GPUs (DeepSpeed ZeRO-2):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/rnd/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 GPU:
|
||||
sbatch --gres=gpu:1 scripts/train.slurm.sh \
|
||||
--accelerate_config "single_gpu" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "zero2" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
import peft
|
||||
import datasets
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import rnd
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
|
||||
moe_backend: str = "hf"
|
||||
attn_implementation: str = "sdpa"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
|
||||
truncation: str = "right"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
|
||||
# rnd specific
|
||||
group_by_length: bool = True
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
freeze_gate: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
|
||||
},
|
||||
)
|
||||
freeze_embedding: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "If True, freeze embedding parameters."},
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
config = transformers.AutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
moe_backend=model_args.moe_backend,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
)
|
||||
model = dllm.utils.get_model(model_args=model_args, config=config)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optionally freeze modules ----------------------------------------------
|
||||
if not isinstance(model, peft.PeftModel):
|
||||
if getattr(training_args, "freeze_gate", False):
|
||||
for n, m in model.named_modules():
|
||||
if n.endswith(".gate"): # only router gate, not gate_proj
|
||||
for p in m.parameters(recurse=False):
|
||||
p.requires_grad_(False)
|
||||
|
||||
if getattr(training_args, "freeze_embedding", False):
|
||||
# model.model.embed_tokens.requires_grad_(False)
|
||||
model.model.embed_tokens.weight.requires_grad_(False)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
def sft_map_fn(row) -> dict:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
if training_args.mask_prompt_loss:
|
||||
# use -100 in labels to indicate positions where tokens should not be masked
|
||||
# and loss is ignored; all other positions match `input_ids`
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can make predictions for the first mask token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
bos = [tokenizer.bos_token_id]
|
||||
prompt_response_tokens = bos + prompt_response_tokens
|
||||
prompt_tokens = bos + prompt_tokens
|
||||
labels = bos + labels
|
||||
labels[0] = -100 # ignore loss on the BOS token
|
||||
# `prompt_len` helps `post_process_dataset` truncate long sequences properly
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
# "attention_mask": [1.0] * len(prompt_response_tokens),
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
if not data_args.load_from_disk:
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
|
||||
dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
else:
|
||||
from datasets import disable_caching
|
||||
|
||||
disable_caching()
|
||||
dataset = datasets.load_from_disk(data_args.dataset_args)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
@dataclass
|
||||
class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
# RND is finetuned on padding <eos_token>
|
||||
outputs.pop("attention_mask")
|
||||
# temp fix here (`group_by_length=True` leads to shape mismatch)
|
||||
# clip seq_len (second dim) to the same for outputs `input_ids, labels`
|
||||
import torch
|
||||
|
||||
keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
|
||||
if keys_to_clip:
|
||||
# Get smallest seq_len to avoid out-of-bounds
|
||||
min_len = min(
|
||||
outputs[k].size(1)
|
||||
for k in keys_to_clip
|
||||
if isinstance(outputs[k], torch.Tensor)
|
||||
)
|
||||
for k in keys_to_clip:
|
||||
t = outputs[k]
|
||||
if isinstance(t, torch.Tensor) and t.size(1) != min_len:
|
||||
outputs[k] = t[:, :min_len]
|
||||
return outputs
|
||||
|
||||
trainer = rnd.RNDTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
args=training_args,
|
||||
data_collator=RNDSFTCollator(
|
||||
tokenizer,
|
||||
# pad_to_multiple_of=8,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id, # RND is finetuned on padding <eos_token>
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
199 dllm/examples/rnd/sft_v2.py Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/rnd/sft.py
|
||||
|
||||
- 8 GPUs (DeepSpeed ZeRO-2):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/rnd/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 GPU:
|
||||
sbatch --gres=gpu:1 scripts/train.slurm.sh \
|
||||
--accelerate_config "single_gpu" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "zero2" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
import peft
|
||||
import datasets
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import rnd
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
|
||||
moe_backend: str = "hf"
|
||||
attn_implementation: str = "sdpa"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
|
||||
truncation: str = "right"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
|
||||
# rnd specific
|
||||
group_by_length: bool = True
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
freeze_gate: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
|
||||
},
|
||||
)
|
||||
freeze_embedding: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "If True, freeze embedding parameters."},
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
config = transformers.AutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
moe_backend=model_args.moe_backend,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
)
|
||||
model = dllm.utils.get_model(model_args=model_args, config=config)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optionally freeze modules ----------------------------------------------
|
||||
if not isinstance(model, peft.PeftModel):
|
||||
if getattr(training_args, "freeze_gate", False):
|
||||
for n, m in model.named_modules():
|
||||
if n.endswith(".gate"): # only router gate, not gate_proj
|
||||
for p in m.parameters(recurse=False):
|
||||
p.requires_grad_(False)
|
||||
|
||||
if getattr(training_args, "freeze_embedding", False):
|
||||
# model.model.embed_tokens.requires_grad_(False)
|
||||
model.model.embed_tokens.weight.requires_grad_(False)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
def sft_map_fn(row) -> dict:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
if training_args.mask_prompt_loss:
|
||||
# use -100 in labels to indicate positions where tokens should not be masked
|
||||
# and loss is ignored; all other positions match `input_ids`
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can make predictions for the first mask token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
bos = [tokenizer.bos_token_id]
|
||||
prompt_response_tokens = bos + prompt_response_tokens
|
||||
prompt_tokens = bos + prompt_tokens
|
||||
labels = bos + labels
|
||||
labels[0] = -100 # ignore loss on the BOS token
|
||||
# `prompt_len` helps `post_process_dataset` truncate long sequences properly
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
# "attention_mask": [1.0] * len(prompt_response_tokens),
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
if not data_args.load_from_disk:
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
|
||||
dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
else:
|
||||
dataset = datasets.load_from_disk(data_args.dataset_args)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
@dataclass
|
||||
class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
# RND is finetuned on padding <eos_token>
|
||||
outputs.pop("attention_mask")
|
||||
# temp fix here (`group_by_length=True` leads to shape mismatch)
|
||||
# clip seq_len (second dim) to the same for outputs `input_ids, labels`
|
||||
# TODO -> FIXED: clip all relevant tensors to a common seq_len
|
||||
# Determine common length across present tensors
|
||||
import torch
|
||||
|
||||
keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
|
||||
if keys_to_clip:
|
||||
# Get smallest seq_len to avoid out-of-bounds
|
||||
min_len = min(
|
||||
outputs[k].size(1)
|
||||
for k in keys_to_clip
|
||||
if isinstance(outputs[k], torch.Tensor)
|
||||
)
|
||||
for k in keys_to_clip:
|
||||
t = outputs[k]
|
||||
if isinstance(t, torch.Tensor) and t.size(1) != min_len:
|
||||
outputs[k] = t[:, :min_len]
|
||||
return outputs
|
||||
|
||||
tokenizer.pad_token_id = tokenizer.mask_token_id  # pad with the mask token (the standard HF attribute is mask_token_id, not mask_token_ids)
|
||||
trainer = rnd.RNDTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
args=training_args,
|
||||
data_collator=RNDSFTCollator(
|
||||
tokenizer,
|
||||
# pad_to_multiple_of=8,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=-100,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
219 dllm/examples/rnd/sft_v3.py Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/rnd/sft.py
|
||||
|
||||
- 8 GPUs (DeepSpeed ZeRO-2):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/rnd/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 GPU:
|
||||
sbatch --gres=gpu:1 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "zero2" \
|
||||
--script_path "examples/rnd/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
import peft
|
||||
import datasets
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import rnd
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
|
||||
moe_backend: str = "hf"
|
||||
attn_implementation: str = "sdpa"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
|
||||
truncation: str = "right"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
|
||||
# rnd specific
|
||||
# group_by_length: bool = True
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
freeze_gate: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
|
||||
},
|
||||
)
|
||||
freeze_embedding: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "If True, freeze embedding parameters."},
|
||||
)
|
||||
perbatch_cutoff: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": (
|
||||
"Randomly pick a response length from batch and trim other responses. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
resp_cutoff_ratio: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": (
|
||||
"The probability of randomly cutting sequences during training. "
|
||||
"See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# necessary when batch contains customized fields
|
||||
training_args.remove_unused_columns = False
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
config = transformers.AutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
moe_backend=model_args.moe_backend,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
)
|
||||
model = dllm.utils.get_model(model_args=model_args, config=config)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Optionally freeze modules ----------------------------------------------
|
||||
if not isinstance(model, peft.PeftModel):
|
||||
if getattr(training_args, "freeze_gate", False):
|
||||
for n, m in model.named_modules():
|
||||
if n.endswith(".gate"): # only router gate, not gate_proj
|
||||
for p in m.parameters(recurse=False):
|
||||
p.requires_grad_(False)
|
||||
|
||||
if getattr(training_args, "freeze_embedding", False):
|
||||
# model.model.embed_tokens.requires_grad_(False)
|
||||
model.model.embed_tokens.weight.requires_grad_(False)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
def sft_map_fn(row) -> dict:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
if training_args.mask_prompt_loss:
|
||||
# use -100 in labels to indicate positions where tokens should not be masked
|
||||
# and loss is ignored; all other positions match `input_ids`
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
else:
|
||||
# When training on all tokens, prepend a BOS token (if missing)
|
||||
# so the model can make predictions for the first mask token.
|
||||
if prompt_response_tokens[0] != tokenizer.bos_token_id:
|
||||
bos = [tokenizer.bos_token_id]
|
||||
prompt_response_tokens = bos + prompt_response_tokens
|
||||
prompt_tokens = bos + prompt_tokens
|
||||
labels = bos + labels
|
||||
labels[0] = -100 # ignore loss on the BOS token
|
||||
# `prompt_len` helps `post_process_dataset` truncate long sequences properly
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(prompt_response_tokens),
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
if not data_args.load_from_disk:
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
|
||||
dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
else:
|
||||
dataset = datasets.load_from_disk(data_args.dataset_args)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
# @dataclass
|
||||
# class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
# def __call__(self, features, return_tensors=None):
|
||||
# outputs = super().__call__(features, return_tensors)
|
||||
# # RND is finetuned on padding <eos_token>
|
||||
# outputs.pop("attention_mask")
|
||||
# # temp fix here (`group_by_length=True` leads to shape mismatch)
|
||||
# # clip seq_len (second dim) to the same for outputs `input_ids, labels`
|
||||
# import torch
|
||||
# keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
|
||||
# if keys_to_clip:
|
||||
# # Get smallest seq_len to avoid out-of-bounds
|
||||
# min_len = min(outputs[k].size(1) for k in keys_to_clip if isinstance(outputs[k], torch.Tensor))
|
||||
# for k in keys_to_clip:
|
||||
# t = outputs[k]
|
||||
# if isinstance(t, torch.Tensor) and t.size(1) != min_len:
|
||||
# outputs[k] = t[:, :min_len]
|
||||
# return outputs
|
||||
trainer = rnd.RNDTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
args=training_args,
|
||||
# data_collator=RNDSFTCollator(
|
||||
# tokenizer,
|
||||
# # pad_to_multiple_of=8,
|
||||
# return_tensors="pt",
|
||||
# padding=True,
|
||||
# label_pad_token_id=-100, # RND is finetuned on padding <eos_token>
|
||||
# ),
|
||||
data_collator=dllm.pipelines.dream.utils.DreamSFTCollator(
|
||||
tokenizer,
|
||||
# pad_to_multiple_of=8,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=-100,
|
||||
perbatch_cutoff=training_args.perbatch_cutoff,
|
||||
resp_cutoff_ratio=training_args.resp_cutoff_ratio,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
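`sft_v3.py` delegates batching to `dllm.pipelines.dream.utils.DreamSFTCollator` with `perbatch_cutoff` and `resp_cutoff_ratio`. As the help strings describe, the per-batch cutoff samples one response length from the batch and trims the other responses to it, while `resp_cutoff_ratio` separately controls how often sequences are randomly cut. Since the collator itself is not part of this diff, the following is only a hedged sketch of the per-batch cutoff step, using the `input_ids`/`labels`/`prompt_len` fields produced by `sft_map_fn`:

```python
import random

def perbatch_response_cutoff(batch):
    # batch: list of dicts with "input_ids", "labels", and "prompt_len".
    # Pick one response length from the batch and trim longer responses down to it.
    resp_lens = [len(ex["input_ids"]) - ex["prompt_len"] for ex in batch]
    target = random.choice(resp_lens)
    for ex in batch:
        keep = min(len(ex["input_ids"]), ex["prompt_len"] + target)
        ex["input_ids"] = ex["input_ids"][:keep]
        ex["labels"] = ex["labels"][:keep]
    return batch
```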