67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
r"""
Local users
------------
- 1 GPU (LoRA, useful for testing):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/editflow/dream/sft.py \
        --lora True

- 8 GPUs (DeepSpeed ZeRO-2):
    accelerate launch \
        --config_file scripts/accelerate_configs/zero2.yaml \
        examples/editflow/dream/sft.py

Slurm users
# Note: run `mkdir logs` before running sbatch, and adjust
#       `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
------------
- 1 Node, 8 GPUs (FSDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/editflow/dream/sft.py"

- 2 Nodes, 16 GPUs (FSDP):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/editflow/dream/sft.py"
"""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import transformers
|
|
|
|
from examples.editflow import sft as editflow_sft
|
|
|
|
|
|
@dataclass
class ModelArguments(editflow_sft.ModelArguments):
    """Model options for the Dream EditFlow SFT stage.

    Only the default of ``model_name_or_path`` is changed here; any other
    fields come unchanged from ``editflow_sft.ModelArguments``.
    """

    # Default checkpoint to fine-tune. Judging by the path, it is the final
    # checkpoint of an earlier EditFlow-Dream-7B-Base run — TODO confirm.
    # The adjacent literals concatenate to the same single path string.
    model_name_or_path: str = (
        "models/EditFlow-Dream-7B-Base/"
        "dclm-baseline-1.0[train:10_000_000,test:10_000]/"
        "checkpoint-final"
    )
|
|
|
|
|
|
@dataclass
class DataArguments(editflow_sft.DataArguments):
    """Data options for the Dream EditFlow SFT stage.

    Only the default of ``dataset_args`` differs from
    ``editflow_sft.DataArguments``.
    """

    # Dataset spec string. The bracketed suffix appears to select split sizes;
    # its exact syntax is interpreted by the base pipeline — TODO confirm.
    dataset_args: str = (
        "allenai/tulu-3-sft-mixture"
        "[train:10000,test:1000]"
    )
|
|
|
|
|
|
@dataclass
class TrainingArguments(editflow_sft.TrainingArguments):
    """Training options for the Dream EditFlow SFT stage.

    Only the default of ``output_dir`` differs from
    ``editflow_sft.TrainingArguments``.
    """

    # Where checkpoints for this run are written by default; the adjacent
    # literals concatenate to one path string.
    output_dir: str = (
        "models/EditFlow-Dream-7B-Instruct-SFT/"
        "tulu-3-sft-mixture[train:10000,test:1000]"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Build a CLI from the three argument dataclasses above. Each field becomes a
    # flag whose default is the value declared on the dataclass.
    arg_schema = (ModelArguments, DataArguments, TrainingArguments)
    cli = transformers.HfArgumentParser(arg_schema)

    # Unpack explicitly so an unexpected number of parsed groups still raises.
    parsed_groups = cli.parse_args_into_dataclasses()
    model_args, data_args, training_args = parsed_groups

    # Delegate the actual training loop to the shared EditFlow SFT entry point.
    editflow_sft.train(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
    )
|