update to latest

dllm/examples/editflow/sft.py (new file, +192 lines)

@@ -0,0 +1,192 @@
import os
from functools import partial
from dataclasses import dataclass, field

import transformers
import accelerate

import dllm
from dllm.pipelines import editflow

logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = None  # overwrite this
    lm_head_key: str = field(
        default=None,
        metadata={
            "help": (
                "The key to the `lm_head` in the source model, used for initializing the "
                "operation heads of the EditFlow model. "
                "Overwrite this when `init_editflow_from_src` is True."
            )
        },
    )
    init_editflow_from_src: bool = field(
        default=True,
        metadata={
            "help": "Whether to initialize the EditFlow model from the source model."
        },
    )
    init_editflow_from_editflow: bool = False
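
# Note: `lm_head_key` is the attribute path of the output projection inside the
# source checkpoint ("lm_head" is a common choice for decoder-style models, but
# the exact key depends on the source architecture).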


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "tatsu-lab/alpaca"
    load_preprocessed_data: bool = False
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens."},
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = None  # overwrite this
    per_device_train_batch_size: int = 2
    per_device_eval_batch_size: int = 2
    learning_rate: float = 5e-5
    # EditFlow specific args
    scheduler_cls: str = field(
        default="LinearKappaScheduler",
        metadata={
            "help": (
                "The scheduler class controlling κ(t). "
                "Available options: see `dllm/utils/schedulers/kappa.py`"
            )
        },
    )
    normalize_per_position: bool = field(
        default=True,
        metadata={"help": "Whether to normalize the loss per position."},
    )
    max_w: float = field(
        default=20.0,
        metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."},
    )
    x0_sampler: str = field(
        default="masks[length:128]",
        metadata={
            "help": (
                "Choose the x0 sampler. "
                "Available options: see `dllm/pipelines/editflow/utils.py`"
            )
        },
    )
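
# Note on `max_w`: the per-token loss weight is κ'(t) / (1 - κ(t)). Assuming the
# linear scheduler uses κ(t) = t, this weight becomes 1 / (1 - t) and diverges as
# t -> 1, so it is clipped at `max_w` to keep late-time updates bounded.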


def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
    # - `input_ids` = prompt + response
    # - `prompt_len` marks the prompt span to EXCLUDE from loss.
    #   (Remove prompt_len to train on all tokens; if so, ensure a BOS is prepended.)
    prompt_response_tokens = tokenizer.apply_chat_template(
        row["messages"],
        tokenize=True,
        add_generation_prompt=False,
    )
    if mask_prompt_loss:
        prompt_tokens = tokenizer.apply_chat_template(
            row["messages"][:-1],
            tokenize=True,
            add_generation_prompt=True,
        )
        return {
            "input_ids": prompt_response_tokens,
            "prompt_len": len(prompt_tokens),
        }
    else:
        # When training on all tokens, prepend a BOS token (if missing)
        # so the model can insert to the left of the very first token.
        if prompt_response_tokens[0] != tokenizer.bos_token_id:
            prompt_response_tokens = [tokenizer.bos_token_id] + prompt_response_tokens
        return {"input_ids": prompt_response_tokens}
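
# Illustrative example (hypothetical data), assuming a standard chat template:
#     row = {"messages": [{"role": "user", "content": "Hi"},
#                         {"role": "assistant", "content": "Hello!"}]}
#     sft_map_fn(row, tokenizer=tokenizer, mask_prompt_loss=True)
#     # -> {"input_ids": <prompt + response token ids>,
#     #     "prompt_len": <number of prompt tokens, excluded from the loss>}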


def train(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    ef_config_cls: type[transformers.PretrainedConfig],
):
    # necessary when the batch does not contain a "labels" field
    training_args.label_names = []
    # necessary when the batch contains customized fields
    training_args.remove_unused_columns = False
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Load EditFlow Model ----------------------------------------------------
    if model_args.init_editflow_from_editflow:
        model = dllm.utils.get_model(model_args=model_args)
    else:
        ef_cfg = ef_config_cls.from_pretrained(
            model_args.model_name_or_path,
            dtype=model_args.dtype,
            attn_implementation=model_args.attn_implementation,
        )
        with dllm.utils.init_device_context_manager():
            model = transformers.AutoModel.from_config(ef_cfg)
        if model_args.init_editflow_from_src:
            # Load the source model config & weights for initializing the EditFlow model
            src_model = transformers.AutoModelForMaskedLM.from_pretrained(
                model_args.model_name_or_path, dtype=model_args.dtype
            )
            # Initialize the EditFlow model from the source model: copies backbone & clones lm_head
            editflow.utils.init_editflow_from_src(
                model, src_model, lm_head_key=model_args.lm_head_key
            )
            del src_model
    model = dllm.utils.load_peft(model, model_args)

    # The Trainer calls model.floating_point_ops to track total FLOPs; return 0 to skip it.
    def _no_flops(*args, **kwargs):
        return 0.0

    model.floating_point_ops = _no_flops

    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_sft_dataset(
            data_args.dataset_args,
            load_preprocessed_data=data_args.load_preprocessed_data,
        )
        if not data_args.load_preprocessed_data:
            map_fn = partial(
                sft_map_fn,
                tokenizer=tokenizer,
                mask_prompt_loss=data_args.mask_prompt_loss,
            )
            dataset = dataset.map(
                map_fn,
                num_proc=data_args.num_proc,
                desc="Mapping dataset to SFT format",
            )
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training ---------------------------------------------------------------
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
    trainer = editflow.EditFlowTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=editflow.utils.EditFlowCollator(
            tokenizer=tokenizer, x0_sampler=training_args.x0_sampler
        ),
        scheduler=dllm.core.schedulers.make_kappa_scheduler(
            training_args.scheduler_cls
        ),
        normalize_per_position=training_args.normalize_per_position,
        max_w=training_args.max_w,
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )
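
# Hypothetical entry point (a minimal sketch; model-specific scripts may wire this
# up differently, and `MyEditFlowConfig` is a placeholder, not an actual class):
#
#     if __name__ == "__main__":
#         parser = transformers.HfArgumentParser(
#             (ModelArguments, DataArguments, TrainingArguments)
#         )
#         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
#         train(model_args, data_args, training_args, ef_config_cls=MyEditFlowConfig)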