MIDIFoundationModel/dllm/examples/editflow/bert/sft.py

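"""EditFlow supervised fine-tuning (SFT) of ModernBERT on Alpaca.

A thin wrapper over `examples/editflow/sft.py`: it only overrides default
model, data, and training arguments, then delegates to `editflow_sft.train`.
"""
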
from dataclasses import dataclass

import transformers

import dllm
from examples.editflow import sft as editflow_sft


@dataclass
class ModelArguments(editflow_sft.ModelArguments):
    model_name_or_path: str = "answerdotai/ModernBERT-large"
    lm_head_key: str = "decoder"  # module name of the LM head on the ModernBERT checkpoint


@dataclass
class DataArguments(editflow_sft.DataArguments):
    dataset_args: str = "tatsu-lab/alpaca"  # HF Hub dataset used for SFT
    max_length: int = 512


@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
    output_dir: str = "models/EditFlow/ModernBERT-large/alpaca"
    num_train_epochs: float = 20
    learning_rate: float = 3e-4
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    eval_steps: float = 0.1  # float in (0, 1): fraction of total training steps
    save_steps: float = 0.1  # float in (0, 1): fraction of total training steps
    x0_sampler: str = "masks[length:64]"  # (assumption) x0 sampled as mask runs capped at length 64
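
# Back-of-envelope for the fractional eval/save intervals (assuming ~52k Alpaca
# examples on a single GPU, no gradient accumulation): 52_000 / 64 * 20 epochs
# ~= 16_250 optimizer steps, so 0.1 means an eval and a checkpoint roughly
# every 1_625 steps.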

if __name__ == "__main__":
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # ----- Training ---------------------------------------------------------------
    editflow_sft.train(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig,
    )
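
# Usage sketch (the launcher is an assumption; the repo may prescribe its own):
#
#   accelerate launch examples/editflow/bert/sft.py \
#       --num_train_epochs 10 \
#       --learning_rate 1e-4
#
# Every field on the three dataclasses above is exposed as a CLI flag by
# HfArgumentParser, so the defaults in this file are overridable per run.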