1127 update to latest
dllm/examples/dream/sft.py (new file, +192 lines)
@@ -0,0 +1,192 @@
"""
Local users
------------
- 1 GPU (4-bit quant & LoRA, useful for testing):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/dream/sft.py \
        --load_in_4bit True --lora True

- 8 GPUs (FSDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/fsdp.yaml \
        examples/dream/sft.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (FSDP):
    sbatch --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/dream/sft.py"

- 2 Nodes, 16 GPUs (FSDP):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/dream/sft.py"
"""

import os
from dataclasses import dataclass, field
from functools import partial

import transformers
import accelerate

import dllm
from dllm.pipelines import dream

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "Dream-org/Dream-v0-Base-7B"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]"
    load_preprocessed_data: bool = False
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens."},
    )
    # Dream SFT specific args
    perbatch_cutoff: bool = field(
        default=True,
        metadata={
            "help": (
                "Randomly pick a response length from the batch and trim other responses. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )
    resp_cutoff_ratio: float = field(
        default=0.0,
        metadata={
            "help": (
                "The probability of randomly cutting sequences during training. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )
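    # Toy illustration of the two knobs above (hypothetical numbers; the actual
    # trimming happens in dream.utils.DreamSFTCollator at collation time):
    #   response lengths in a batch: [120, 400, 87, 250]
    #   perbatch_cutoff=True  -> pick one length at random (say 250) and trim
    #                            longer responses in the batch down to it
    #   resp_cutoff_ratio=0.1 -> independently, with probability 0.1, cut a
    #                            sequence short at a random point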


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/Dream-7B-SFT"
    group_by_length: bool = True
    # Dream SFT specific args
    loss_weight_type: str = field(
        default="cart[geo_p:0.3]",
        metadata={
            "help": (
                "The loss weight type. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )
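    # Note: the bracketed suffix appears to follow the same `key:value` syntax
    # as `dataset_args` above, i.e. `cart` names the weighting scheme and
    # geo_p=0.3 is its parameter; the linked Dream config lists the available
    # weight types.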


# ------------------------------------------------------------------------------
# SFT mapping function
# ------------------------------------------------------------------------------
def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict:
    """
    Build Dream SFT features from a chat-format row.

    Returns:
        dict with input_ids, labels, prompt_len
    """
    prompt_tokens = tokenizer.apply_chat_template(
        row["messages"][:-1], tokenize=True, add_generation_prompt=True
    )
    prompt_response_tokens = tokenizer.apply_chat_template(
        row["messages"], tokenize=True, add_generation_prompt=False
    )
    labels = prompt_response_tokens.copy()

    if mask_prompt_loss:
        labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
    else:
        # When training on all tokens, prepend a BOS token (if missing)
        # so the model can predict the first token.
        if prompt_response_tokens[0] != tokenizer.bos_token_id:
            bos = [tokenizer.bos_token_id]
            prompt_response_tokens = bos + prompt_response_tokens
            prompt_tokens = bos + prompt_tokens
            labels = bos + labels
        labels[0] = -100  # ignore loss on BOS

    return {
        "input_ids": prompt_response_tokens,
        "labels": labels,
        "prompt_len": len(prompt_tokens),
    }
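
# Example of what sft_map_fn returns (illustrative ids; the real values depend
# on the tokenizer's chat template):
#   row = {"messages": [{"role": "user", "content": "Hi"},
#                       {"role": "assistant", "content": "Hello!"}]}
#   sft_map_fn(row, tokenizer=tokenizer, mask_prompt_loss=True)
#   -> {"input_ids":  [p0, p1, p2, r0, r1],        # prompt + response tokens
#       "labels":     [-100, -100, -100, r0, r1],  # prompt positions masked out
#       "prompt_len": 3}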


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # necessary when the batch contains customized fields (e.g. prompt_len)
    training_args.remove_unused_columns = False
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    model = dllm.utils.get_model(model_args=model_args)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_sft_dataset(
            data_args.dataset_args,
            load_preprocessed_data=data_args.load_preprocessed_data,
        )
        if not data_args.load_preprocessed_data:
            map_fn = partial(
                sft_map_fn,
                tokenizer=tokenizer,
                mask_prompt_loss=data_args.mask_prompt_loss,
            )
            dataset = dataset.map(
                map_fn,
                num_proc=data_args.num_proc,
                desc="Mapping dataset to SFT format",
            )
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training ---------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = dream.DreamTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        loss_weight_type=training_args.loss_weight_type,
        data_collator=dream.utils.DreamSFTCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            perbatch_cutoff=data_args.perbatch_cutoff,
            resp_cutoff_ratio=data_args.resp_cutoff_ratio,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()