1127 update to latest
dllm/examples/bert/README.md
@ -0,0 +1,190 @@
# Generative BERT

[Hugging Face collection](https://huggingface.co/collections/dllm-collection/bert-chat)
[W&B report](https://api.wandb.ai/links/asap-zzhou/101h5xvg)

This directory provides two sets of resources:

1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERT models finetuned as chatbots. For experimental results, lessons learned, and more reproduction details, see the full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).

<p align="center" style="margin-top: 15px;">
    <img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
</p>
<p align="center">
    <em>
        Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
    </em>
</p>

## Files overview
```
# example entry points for training / inference / evaluation
examples/bert
├── chat.py      # Interactive inference example
├── eval.sh      # Automatic evaluation script
├── generate.py  # Inference example
├── pt.py        # Pretraining example
├── README.md    # Documentation (you are here)
└── sft.py       # Supervised finetuning example
```

## Warmup

In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
You can use any other BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.

### Pretrain

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
    examples/bert/pt.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "Trelis/tiny-shakespeare" \
    --text_field "Text" \
    --insert_eos False \
    --max_length 128 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/tiny-shakespeare"
```
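
The `--max_length` and `--insert_eos` flags determine how raw text is cut into training sequences. As a rough illustration only, the sketch below packs already-tokenized documents into fixed-length blocks. The function name `pack_token_ids` and its exact behavior are assumptions made for this example; it borrows the flag names that `pt.py` passes to `dllm.utils.tokenize_and_group`, but it is not that helper's implementation.

```python
from typing import Iterable, List


def pack_token_ids(
    docs: Iterable[List[int]],
    seq_length: int,
    eos_id: int,
    insert_eos: bool = True,
    drop_tail: bool = True,
) -> List[List[int]]:
    """Concatenate tokenized documents and split them into fixed-size blocks.

    Hypothetical helper: the real dllm function may differ in details.
    """
    stream: List[int] = []
    for ids in docs:
        stream.extend(ids)
        if insert_eos:
            # Mark the document boundary so unrelated texts are not fused.
            stream.append(eos_id)

    blocks = [stream[i : i + seq_length] for i in range(0, len(stream), seq_length)]
    if drop_tail and blocks and len(blocks[-1]) < seq_length:
        # Discard the final, shorter block instead of padding it.
        blocks.pop()
    return blocks


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_token_ids(docs, seq_length=4, eos_id=0, insert_eos=False))
# [[1, 2, 3, 4], [5, 6, 7, 8]]
print(pack_token_ids(docs, seq_length=4, eos_id=0, insert_eos=True))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

The help text of `insert_eos` in `pt.py` recommends `False` when adjacent samples are semantically coherent, which is the setting the tiny-shakespeare command uses.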

To run inference with the model:
```shell
# just press enter (empty prompt) if you want the model to generate text from scratch
python -u examples/bert/chat.py \
    --model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
    --chat False --remasking "random" --steps 128 --max_new_tokens 128
```
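
`--remasking` chooses which predicted tokens are kept at each denoising step. The toy sketch below illustrates the two strategies, assuming the generator behaves like the LLaDA-style sampler these scripts import (`llada.LLaDAGenerator`): every masked position receives a prediction and a score, the top-scoring positions are committed, and the rest stay masked for later steps. `select_positions_to_unmask` is a hypothetical name, and a real sampler works on logits tensors rather than Python lists.

```python
import random
from typing import List, Optional, Sequence


def select_positions_to_unmask(
    masked_positions: Sequence[int],
    confidences: Sequence[float],
    num_to_unmask: int,
    remasking: str = "low_confidence",
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Pick which masked positions to commit in one denoising step.

    `confidences[i]` is the model's probability for its prediction at
    `masked_positions[i]`. Positions that are not returned remain masked.
    """
    if remasking == "low_confidence":
        # Keep the most confident predictions; low-confidence ones are remasked.
        scores = list(confidences)
    elif remasking == "random":
        # Ignore the model's confidence and commit a random subset.
        rng = rng or random.Random()
        scores = [rng.random() for _ in masked_positions]
    else:
        raise ValueError(f"unknown remasking strategy: {remasking!r}")

    ranked = sorted(zip(scores, masked_positions), reverse=True)
    return sorted(pos for _, pos in ranked[:num_to_unmask])


masked = [2, 5, 7, 9]
conf = [0.9, 0.2, 0.6, 0.4]
print(select_positions_to_unmask(masked, conf, num_to_unmask=2))
# [2, 7]
print(select_positions_to_unmask(masked, conf, 2, "random", random.Random(0)))
```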

### SFT

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "tatsu-lab/alpaca" \
    --max_length 512 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/alpaca"
```
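
`sft.py` enables `mask_prompt_loss` by default, so only response tokens contribute to the loss. The sketch below shows that idea in isolation, assuming the common Hugging Face convention that a label of `-100` is ignored by the loss. `build_labels` is a hypothetical helper; `dllm.utils.default_sft_map_fn` may construct its labels differently.

```python
from typing import Dict, List, Sequence

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_labels(
    prompt_ids: Sequence[int],
    response_ids: Sequence[int],
    mask_prompt_loss: bool = True,
) -> Dict[str, List[int]]:
    """Build input_ids and labels for one prompt/response pair."""
    input_ids = list(prompt_ids) + list(response_ids)
    if mask_prompt_loss:
        # Prompt positions get IGNORE_INDEX so they add nothing to the loss.
        labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    else:
        # Train on every token, prompt included.
        labels = list(input_ids)
    return {"input_ids": input_ids, "labels": labels}


example = build_labels(prompt_ids=[11, 12, 13], response_ids=[21, 22])
print(example["input_ids"])  # [11, 12, 13, 21, 22]
print(example["labels"])     # [-100, -100, -100, 21, 22]
```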

To chat with the model:
```shell
python -u examples/bert/chat.py \
    --model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
```

## BERT Chat
Here we show the exact commands we use to train and interact with the BERT Chat models:
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).

### Training

To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-base" \
    --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
    --max_length 1024 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 48 \
    --per_device_eval_batch_size 48 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```

To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
    --max_length 1024 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 48 \
    --per_device_eval_batch_size 48 \
    --save_steps 0.1 \
    --output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```
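
The `bs-384` in the output directory names matches the global batch size implied by the flags: 48 sequences per device across 8 processes. The snippet below spells out that arithmetic, assuming no gradient accumulation (the commands set none). It also sketches how a fractional `save_steps` such as `0.1` could translate into a checkpoint interval, assuming `dllm.utils.TrainingArguments` inherits the Hugging Face behavior of reading values below 1 as a fraction of total training steps. Both function names and the dataset size are hypothetical placeholders, not measured figures.

```python
import math


def global_batch_size(per_device: int, num_processes: int, grad_accum: int = 1) -> int:
    """Sequences consumed per optimizer step across all processes."""
    return per_device * num_processes * grad_accum


def checkpoint_interval(
    num_examples: int, batch_size: int, epochs: int, save_steps: float
) -> int:
    """Optimizer steps between checkpoints (rounding up is an assumption)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    if save_steps < 1:
        # Fractional values are treated as a share of the whole run.
        return math.ceil(total_steps * save_steps)
    return int(save_steps)


bs = global_batch_size(per_device=48, num_processes=8)
print(bs)  # 384

# Placeholder dataset size, chosen only to make the arithmetic concrete.
print(checkpoint_interval(num_examples=1_000_000, batch_size=bs, epochs=10, save_steps=0.1))
```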

### Inference

To chat with the model:
```shell
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
```
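
`chat.py` defaults to `steps=128`, `max_new_tokens=128`, and `block_length=32`. Assuming the generator keeps the semi-autoregressive schedule of the original LLaDA sampler (blocks decoded left to right, with the step budget split evenly across blocks), those numbers translate into a per-step unmasking quota as sketched below. `block_schedule` is a hypothetical helper and not part of dllm.

```python
from typing import Dict, List, Union


def block_schedule(
    max_new_tokens: int, block_length: int, steps: int
) -> Dict[str, Union[int, List[int]]]:
    """Split a generation budget into blocks and per-step token quotas."""
    if max_new_tokens % block_length != 0:
        raise ValueError("max_new_tokens must be a multiple of block_length")
    num_blocks = max_new_tokens // block_length
    if steps % num_blocks != 0:
        raise ValueError("steps must be a multiple of the number of blocks")
    steps_per_block = steps // num_blocks

    # Spread the block's tokens over its steps; early steps absorb the remainder.
    base, extra = divmod(block_length, steps_per_block)
    tokens_per_step = [base + 1 if i < extra else base for i in range(steps_per_block)]
    return {
        "num_blocks": num_blocks,
        "steps_per_block": steps_per_block,
        "tokens_per_step": tokens_per_step,
    }


plan = block_schedule(max_new_tokens=128, block_length=32, steps=128)
print(plan["num_blocks"], plan["steps_per_block"])  # 4 32
print(set(plan["tokens_per_step"]))                 # {1}

# Halving the steps means two tokens are committed per step.
print(block_schedule(128, 32, 64)["tokens_per_step"][:3])  # [2, 2, 2]
```

Under this assumption, fewer steps trade quality for speed, because more tokens are committed per forward pass.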

## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.

For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
    dllm/pipelines/bert/eval.py \
    --tasks "mmlu_pro" \
    --model "bert" \
    --apply_chat_template \
    --num_fewshot 0 \
    --model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
```
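
`--model_args` packs the checkpoint name and generation settings into one comma-separated `key=value` string. The sketch below shows one way such a string could be turned into typed Python values. `parse_model_args` is a hypothetical helper written for illustration; the evaluation harness's own parser may coerce types differently.

```python
from typing import Dict, Union

Value = Union[bool, int, float, str]


def _coerce(raw: str) -> Value:
    """Convert a raw string into bool, int, or float when it looks like one."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def parse_model_args(arg_string: str) -> Dict[str, Value]:
    """Parse 'k1=v1,k2=v2' into a dictionary with coerced values."""
    parsed: Dict[str, Value] = {}
    for item in arg_string.split(","):
        if not item.strip():
            continue  # tolerate trailing commas
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        parsed[key.strip()] = _coerce(raw.strip())
    return parsed


args = parse_model_args(
    "pretrained=dllm-collection/ModernBERT-large-chat-v0,"
    "is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
)
print(args["pretrained"])       # dllm-collection/ModernBERT-large-chat-v0
print(args["is_check_greedy"])  # False
print(args["steps"])            # 256
```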

To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
```shell
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
```

### Evaluation results

<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->

<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | – | | | | | | | |
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | – | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
<!-- <div align="center" style="min-width:1500px;"> -->

| | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 | 53.1 | 39.4 | 0.3 |

<p align="left" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-base-chat-v0</code>
</a>,
<a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-large-chat-v0</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B-Chat</code>
</a>,
<a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;">
<code>gpt2</code>
</a>, and
<a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;">
<code>gpt2-medium</code>
</a>.
<ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
</p>
dllm/examples/bert/assets/chat.gif
Binary file not shown. Size: 7.1 MiB
dllm/examples/bert/chat.py
@ -0,0 +1,71 @@
"""
Interactive chat / generation script for Bert models.

Examples
--------
# Raw multi-turn generation (default)
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
"""

import sys
from dataclasses import dataclass
import transformers

import dllm
from dllm.pipelines import llada
from dllm.tools.chat import multi_turn_chat, single_turn_generate


@dataclass
class ScriptArguments:
    model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
    seed: int = 42
    chat: bool = True
    visualize: bool = True

    def __post_init__(self):
        # same base-path resolution logic as in generate.py
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    block_length: int = 32
    temperature: float = 0.0
    remasking: str = "low_confidence"


def main():
    parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
    script_args, gen_config = parser.parse_args_into_dataclasses()
    transformers.set_seed(script_args.seed)

    model = dllm.utils.get_model(model_args=script_args).eval()
    tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
    generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

    if script_args.chat:
        multi_turn_chat(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )
    else:
        print("\nSingle-turn generation (no chat template).")
        single_turn_generate(
            generator=generator,
            gen_config=gen_config,
            visualize=script_args.visualize,
        )


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted. Bye!")
        sys.exit(0)
dllm/examples/bert/eval.sh
@ -0,0 +1,50 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1                 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True   # For cmmlu dataset

# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0                   # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1          # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn                      # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL       # Provide detailed logging for PyTorch distributed debugging

# ===== Basic Settings =====
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
num_gpu=4
while [[ $# -gt 0 ]]; do
  case "$1" in
    --model_name_or_path)
      model_name_or_path="$2"; shift 2 ;;
    --num_gpu)
      num_gpu="$2"; shift 2 ;;
    *)
      # Without this branch an unrecognized flag is never shifted and the loop never ends.
      echo "Unknown argument: $1" >&2; exit 1 ;;
  esac
done

# ===== Common arguments =====
common_args="--model bert --apply_chat_template"  # BERT models use the chat template by default

# =======================
# BERT Instruct (Chat) Tasks
# =======================

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks hellaswag_gen --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks mmlu_generative --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks mmlu_pro --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
    --tasks winogrande --num_fewshot 0 ${common_args} \
    --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
dllm/examples/bert/generate.py
@ -0,0 +1,73 @@
"""
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""

from dataclasses import dataclass

import transformers

import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import llada


@dataclass
class ScriptArguments:
    model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
    seed: int = 42
    visualize: bool = True

    def __post_init__(self):
        self.model_name_or_path = dllm.utils.resolve_with_base_env(
            self.model_name_or_path, "BASE_MODELS_DIR"
        )


@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
    steps: int = 128
    max_new_tokens: int = 128
    block_length: int = 64
    temperature: float = 0.0
    remasking: str = "low_confidence"


parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)

# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
    tokenizer=tokenizer
)

# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: bert.generate()".center(80))
print("=" * 80)

messages = [
    [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
    [{"role": "user", "content": "Please write an educational python function."}],
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)

# `idx` avoids shadowing the built-in `iter`
for idx, s in enumerate(sequences):
    print("\n" + "-" * 80)
    print(f"[Case {idx}]")
    print("-" * 80)
    print(s.strip() if s.strip() else "<empty>")
print("\n" + "=" * 80 + "\n")

if script_args.visualize:
    terminal_visualizer.visualize(outputs.histories, rich=True)
dllm/examples/bert/pt.py
@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/bert/pt.py

- 8 GPUs (DDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml \
        examples/bert/pt.py

Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
#       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 8 GPUs (DDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/pt.py"
"""

import os
import functools
from dataclasses import dataclass, field

import transformers
import accelerate

import dllm

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "Trelis/tiny-shakespeare"
    text_field: str = "Text"
    max_length: int = 128
    streaming: bool = False
    drop_tail: bool = True
    insert_eos: bool = field(
        default=True,
        metadata={
            "help": "False when adjacent samples from the datasets are semantically coherent."
        },
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    # NOTE: this default path says "ModernBERT-base" although the default model
    # above is ModernBERT-large; the README command overrides it via --output_dir.
    output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
    num_train_epochs: int = 20
    learning_rate: float = 1e-4
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_pt_dataset(
            data_args.dataset_args,
            streaming=data_args.streaming,
        )
        dataset = dataset.map(
            functools.partial(
                dllm.utils.tokenize_and_group,
                tokenizer=tokenizer,
                text_field=data_args.text_field,
                seq_length=data_args.max_length,
                insert_eos=data_args.insert_eos,
                drop_tail=data_args.drop_tail,
            ),
            batched=True,
            remove_columns=dataset["train"].column_names,
            **({} if data_args.streaming else {"num_proc": data_args.num_proc}),
            **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
        )
        if data_args.streaming:
            dataset = dataset.shuffle(seed=training_args.seed)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()
dllm/examples/bert/sft.py
@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/bert/sft.py

- 8 GPUs (DDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml \
        examples/bert/sft.py

Slurm users
# Note: run `mkdir logs` before running sbatch; and adjust
#       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (DDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/sft.py"

- 2 Nodes, 16 GPUs (DDP):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/bert/sft.py"
"""

import os
from dataclasses import dataclass, field
from functools import partial

import transformers
import accelerate

import dllm

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "tatsu-lab/alpaca"
    max_length: int = 512
    load_preprocessed_data: bool = False
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/ModernBERT-large/alpaca"
    group_by_length: bool = True
    learning_rate: float = 1e-4
    num_train_epochs: int = 20
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_sft_dataset(
            data_args.dataset_args,
            load_preprocessed_data=data_args.load_preprocessed_data,
        )
        if not data_args.load_preprocessed_data:
            map_fn = partial(
                dllm.utils.default_sft_map_fn,
                tokenizer=tokenizer,
                mask_prompt_loss=data_args.mask_prompt_loss,
            )
            dataset = dataset.map(
                map_fn,
                num_proc=data_args.num_proc,
                desc="Mapping dataset to SFT format",
            )
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training --------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=dllm.utils.NoAttentionMaskCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,  # finetune on padding <eos_token>
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()