1127 update to latest

dllm/examples/editflow/bert/README.md (new file)

# Edit Flows - BERT

> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018)

## Warmup

In this section, we show toy examples of pretraining and SFT-ing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text with EditFlow.
You can use any other BERT-style model instead; for example, pass `--model_name_or_path "FacebookAI/roberta-large"`.

### Pretrain

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:

```shell
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
    examples/editflow/bert/pt.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "Trelis/tiny-shakespeare" \
    --text_field "Text" \
    --insert_eos False \
    --max_length 128 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --x0_sampler "masks[length:64]" \
    --output_dir "models/EditFlow/ModernBERT-large/tiny-shakespeare"
```
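
Note: the `--x0_sampler "masks[length:64]"` setting appears to draw the source sequence `x0` as a run of 64 mask tokens, which matches the `--mask_length 64` passed at inference below.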

To run inference with the model:

```shell
PYTHONPATH=. python examples/editflow/generate.py \
    --model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
    --tau 0.01 --mask_length 64 --seed 42 --make_gif

# see `decode_trace.gif`
```
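
Decoding is stochastic, so different seeds yield different samples. A minimal sketch that reuses only the flags shown above (`--make_gif` is dropped so successive runs don't overwrite `decode_trace.gif`):

```shell
# Draw three samples from the same checkpoint by varying the seed.
# --tau is presumably the sampler step size; smaller values take finer steps.
for seed in 0 1 2; do
  PYTHONPATH=. python examples/editflow/generate.py \
      --model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
      --tau 0.01 --mask_length 64 --seed "$seed"
done
```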

### SFT

To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:

```shell
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/editflow/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "tatsu-lab/alpaca" \
    --max_length 512 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --save_steps 0.1 \
    --x0_sampler "masks[length:64]" \
    --output_dir "models/EditFlow/ModernBERT-large/alpaca"
```
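
With `--num_processes 8` and `--per_device_train_batch_size 64`, the effective global batch size is 8 × 64 = 512 sequences per step (assuming no gradient accumulation).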

To run inference with the model:

```shell
PYTHONPATH=. python examples/editflow/generate.py \
    --model_name_or_path "models/EditFlow/ModernBERT-large/alpaca/checkpoint-final" \
    --prompt "Could you please write a poem for me?" --tau 0.01 --mask_length 64 --seed 42 --make_gif

# see `decode_trace.gif`
```

<!-- ```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/editflow/bert/sft.py \
    --model_name_or_path "answerdotai/ModernBERT-large" \
    --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
    --max_length 1024 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 48 \
    --per_device_eval_batch_size 48 \
    --save_steps 0.1 \
    --x0_sampler "masks[length:64]" \
    --output_dir "models/EditFlow/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
``` -->

dllm/examples/editflow/bert/pt.py (new file)

```python
from dataclasses import dataclass

import transformers

import dllm
from examples.editflow import pt as editflow_pt


@dataclass
class ModelArguments(editflow_pt.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"
    lm_head_key: str = "decoder"  # attribute name of the base model's LM head


@dataclass
class DataArguments(editflow_pt.DataArguments):
    dataset_args: str = "Trelis/tiny-shakespeare"
    text_field: str = "Text"
    max_length: int = 128
    streaming: bool = False
    drop_tail: bool = True
    insert_eos: bool = False


@dataclass
class TrainingArguments(editflow_pt.TrainingArguments):
    output_dir: str = "models/EditFlow/ModernBERT-large/tiny-shakespeare"
    num_train_epochs: float = 20
    learning_rate: float = 3e-4
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1
    x0_sampler: str = "masks[length:64]"


if __name__ == "__main__":
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Delegate to the shared EditFlow pretraining loop with a ModernBERT config.
    editflow_pt.train(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
    )
```
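
Every field above has a default, and `HfArgumentParser` turns each dataclass field into a CLI flag, so any default can be overridden from the command line. A minimal single-GPU sketch (assuming the script also runs standalone, without `accelerate`):

```shell
# Override a couple of dataclass defaults; all other values keep
# the defaults defined in pt.py above.
PYTHONPATH=. python examples/editflow/bert/pt.py \
    --learning_rate 1e-4 \
    --num_train_epochs 5
```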

dllm/examples/editflow/bert/sft.py (new file)

```python
from dataclasses import dataclass

import transformers

import dllm
from examples.editflow import sft as editflow_sft


@dataclass
class ModelArguments(editflow_sft.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"
    lm_head_key: str = "decoder"  # attribute name of the base model's LM head


@dataclass
class DataArguments(editflow_sft.DataArguments):
    dataset_args: str = "tatsu-lab/alpaca"
    max_length: int = 512


@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
    output_dir: str = "models/EditFlow/ModernBERT-large/alpaca"
    num_train_epochs: float = 20
    learning_rate: float = 3e-4
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1
    save_steps: float = 0.1
    x0_sampler: str = "masks[length:64]"


if __name__ == "__main__":
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Delegate to the shared EditFlow SFT loop with a ModernBERT config.
    editflow_sft.train(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
    )
```
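
As the README notes, other BERT-family encoders can be swapped in. A hedged sketch for RoBERTa, assuming its masked-LM head attribute is `lm_head` (hence the different `--lm_head_key`) and that the SFT pipeline otherwise handles it unchanged:

```shell
# Fine-tune RoBERTa-large instead of ModernBERT-large; all other
# defaults from sft.py above are kept.
PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
    examples/editflow/bert/sft.py \
    --model_name_or_path "FacebookAI/roberta-large" \
    --lm_head_key "lm_head" \
    --output_dir "models/EditFlow/roberta-large/alpaca"
```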