1127 update to latest

1  dllm/examples/rnd/README.md  Normal file
@@ -0,0 +1 @@
WIP

114  dllm/examples/rnd/preprocess.py  Normal file
@@ -0,0 +1,114 @@
# """
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --cpus-per-task=12 --time=03:00:00
#
# python examples/rnd/preprocess.py --dataset_args "HuggingFaceTB/smoltalk" --output_dir "data/sft_proc/rnd/smoltalk"
# """
# import os
# from dataclasses import dataclass
# from typing import Dict, Any
#
# import datasets
# import transformers
# import accelerate
# import tyro
#
# import dllm
#
#
# # --- tyro: define dataclass for CLI args ---
# @dataclass
# class ScriptArguments:
#     """Preprocess SFT dataset (batch_size=1 only)"""
#     model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
#     dataset_args: str = "HuggingFaceTB/smoltalk"  # required
#     output_dir: str = "data/sft_proc/rnd/smoltalk"  # required
#     mask_prompt_loss: bool = True  # Mask prompt tokens in labels with -100
#     # TODO: strip_cols
#
#     def __post_init__(self):
#         self.model_name_or_path = dllm.utils.resolve_with_base_env(
#             self.model_name_or_path, "BASE_MODELS_DIR"
#         )
#
#
# def dataset_offline_preprocess(dataset: datasets.DatasetDict, map_fn: callable, output_dir: str):
#     # Map example-by-example (batched=False) across multiple worker processes.
#     state = accelerate.PartialState()
#     with state.local_main_process_first():
#         processed = dataset.map(
#             map_fn,
#             batched=False,
#             num_proc=16,
#             load_from_cache_file=True,
#             writer_batch_size=512,
#             desc="offline preprocessing",
#         )
#
#     # # Keep only the three required columns to save space.
#     # keep = {"input_ids", "labels", "prompt_len"}
#     # def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
#     #     drop = [c for c in ds.column_names if c not in keep]
#     #     return ds.remove_columns(drop) if drop else ds
#
#     # if isinstance(processed, datasets.DatasetDict):
#     #     for split in list(processed.keys()):
#     #         processed[split] = strip_cols(processed[split])
#     # else:
#     #     processed = strip_cols(processed)
#
#     os.makedirs(output_dir, exist_ok=True)
#     processed.save_to_disk(output_dir)
#     print(f"[OK] Saved to: {output_dir}")
#
#
# def main():
#     # Parse with tyro
#     args = tyro.cli(ScriptArguments)
#
#     # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
#     tokenizer = dllm.utils.get_tokenizer(args)
#
#     # Load your raw dataset (must contain a "messages" field per example).
#     dataset = dllm.data.load_sft_dataset(args.dataset_args)
#
#     dataset_offline_preprocess(dataset=dataset, map_fn=None, output_dir=args.output_dir)
#
#
# if __name__ == "__main__":
#     main()


from functools import partial
import tyro

import dllm
from dllm.tools.preprocess_sft_dataset import ScriptArguments, preprocess_sft_dataset


def main():
    from examples.rnd.sft import sft_map_fn

    # Parse with tyro
    args = tyro.cli(ScriptArguments)

    # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer = dllm.utils.get_tokenizer(args)

    # Load your raw dataset (must contain a "messages" field per example).
    dataset = dllm.data.load_sft_dataset(args.dataset_args)

    map_fn = partial(
        sft_map_fn,
        tokenizer=tokenizer,
        mask_prompt_loss=args.mask_prompt_loss,
    )
    preprocess_sft_dataset(
        dataset=dataset,
        map_fn=map_fn,
        output_dir=args.output_dir,
        remove_columns=args.remove_columns,
    )


if __name__ == "__main__":
    main()
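
Note (not part of the commit): a minimal, hypothetical sanity check for the preprocessed output, assuming the default output_dir above and using only the public `datasets` API.

# sanity check for the preprocessed dataset (hypothetical helper, not in this commit)
import datasets

ds = datasets.load_from_disk("data/sft_proc/rnd/smoltalk")  # assumed output_dir
print(ds)  # expect columns such as "input_ids", "labels", "prompt_len"

split = ds["train"] if isinstance(ds, datasets.DatasetDict) else ds
row = split[0]
# With mask_prompt_loss=True, the first `prompt_len` label positions should be -100.
print(row["prompt_len"], row["labels"][: row["prompt_len"]][:5])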

199  dllm/examples/rnd/sft.py  Normal file
@@ -0,0 +1,199 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/rnd/sft.py

- 8 GPUs (DeepSpeed ZeRO-2):
    accelerate launch \
        --config_file scripts/accelerate_configs/zero2.yaml \
        examples/rnd/sft.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 GPU:
    sbatch --gres=gpu:1 scripts/train.slurm.sh \
        --accelerate_config "single_gpu" \
        --script_path "examples/rnd/sft.py"

- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "zero2" \
        --script_path "examples/rnd/sft.py"
"""

import os
from dataclasses import dataclass, field

import transformers
import accelerate
import peft
import datasets

import dllm
from dllm.pipelines import rnd


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
    moe_backend: str = "hf"
    attn_implementation: str = "sdpa"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
    truncation: str = "right"


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
    # rnd specific
    group_by_length: bool = True
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )
    freeze_gate: bool = field(
        default=True,
        metadata={
            "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
        },
    )
    freeze_embedding: bool = field(
        default=False,
        metadata={"help": "If True, freeze embedding parameters."},
    )


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        moe_backend=model_args.moe_backend,
        attn_implementation=model_args.attn_implementation,
    )
    model = dllm.utils.get_model(model_args=model_args, config=config)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
    # ----- Optionally freeze modules ----------------------------------------------
    if not isinstance(model, peft.PeftModel):
        if getattr(training_args, "freeze_gate", False):
            for n, m in model.named_modules():
                if n.endswith(".gate"):  # only the router gate, not gate_proj
                    for p in m.parameters(recurse=False):
                        p.requires_grad_(False)

        if getattr(training_args, "freeze_embedding", False):
            # model.model.embed_tokens.requires_grad_(False)
            model.model.embed_tokens.weight.requires_grad_(False)

    # ----- Dataset ----------------------------------------------------------------
    def sft_map_fn(row) -> dict:
        prompt_tokens = tokenizer.apply_chat_template(
            row["messages"][:-1],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompt_response_tokens = tokenizer.apply_chat_template(
            row["messages"], tokenize=True, add_generation_prompt=False
        )
        labels = prompt_response_tokens.copy()
        if training_args.mask_prompt_loss:
            # -100 in `labels` marks prompt positions: they are not masked (noised)
            # and their loss is ignored; all other positions match `input_ids`
            labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
        else:
            # When training on all tokens, prepend a BOS token (if missing)
            # so the model can make predictions for the first mask token.
            if prompt_response_tokens[0] != tokenizer.bos_token_id:
                bos = [tokenizer.bos_token_id]
                prompt_response_tokens = bos + prompt_response_tokens
                prompt_tokens = bos + prompt_tokens
                labels = bos + labels
            labels[0] = -100  # ignore loss on the BOS token
        # `prompt_len` helps `post_process_dataset` truncate long sequences properly
        return {
            "input_ids": prompt_response_tokens,
            "labels": labels,
            # "attention_mask": [1.0] * len(prompt_response_tokens),
            "prompt_len": len(prompt_tokens),
        }

    if not data_args.load_from_disk:
        with accelerate.PartialState().local_main_process_first():
            dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
            dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
            # truncate / filter long sequences if needed
            dataset = dllm.utils.post_process_dataset(dataset, data_args)
    else:
        from datasets import disable_caching

        disable_caching()
        dataset = datasets.load_from_disk(data_args.dataset_args)
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training ---------------------------------------------------------------
    @dataclass
    class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
        def __call__(self, features, return_tensors=None):
            outputs = super().__call__(features, return_tensors)
            # RND is finetuned on padding <eos_token>
            outputs.pop("attention_mask")
            # temp fix (`group_by_length=True` leads to a shape mismatch):
            # clip seq_len (second dim) so that `input_ids` and `labels` agree
            import torch

            keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
            if keys_to_clip:
                # Get the smallest seq_len to avoid out-of-bounds
                min_len = min(
                    outputs[k].size(1)
                    for k in keys_to_clip
                    if isinstance(outputs[k], torch.Tensor)
                )
                for k in keys_to_clip:
                    t = outputs[k]
                    if isinstance(t, torch.Tensor) and t.size(1) != min_len:
                        outputs[k] = t[:, :min_len]
            return outputs

    trainer = rnd.RNDTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        data_collator=RNDSFTCollator(
            tokenizer,
            # pad_to_multiple_of=8,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=tokenizer.pad_token_id,  # RND is finetuned on padding <eos_token>
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()
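
Note (not part of the commit): a toy, self-contained illustration of the -100 prompt-masking convention used by `sft_map_fn` above; the token ids are made up and no tokenizer is involved.

prompt_tokens = [101, 7, 8, 9]                             # stand-in for the templated prompt
prompt_response_tokens = prompt_tokens + [20, 21, 22, 2]   # prompt + response (+ eos)

labels = prompt_response_tokens.copy()
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)  # ignore loss on prompt positions

assert labels == [-100, -100, -100, -100, 20, 21, 22, 2]
example = {
    "input_ids": prompt_response_tokens,
    "labels": labels,
    "prompt_len": len(prompt_tokens),
}
print(example)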

199  dllm/examples/rnd/sft_v2.py  Normal file
@@ -0,0 +1,199 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/rnd/sft_v2.py

- 8 GPUs (DeepSpeed ZeRO-2):
    accelerate launch \
        --config_file scripts/accelerate_configs/zero2.yaml \
        examples/rnd/sft_v2.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 GPU:
    sbatch --gres=gpu:1 scripts/train.slurm.sh \
        --accelerate_config "single_gpu" \
        --script_path "examples/rnd/sft_v2.py"

- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "zero2" \
        --script_path "examples/rnd/sft_v2.py"
"""

import os
from dataclasses import dataclass, field

import transformers
import accelerate
import peft
import datasets

import dllm
from dllm.pipelines import rnd


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
    moe_backend: str = "hf"
    attn_implementation: str = "sdpa"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
    truncation: str = "right"


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
    # rnd specific
    group_by_length: bool = True
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )
    freeze_gate: bool = field(
        default=True,
        metadata={
            "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
        },
    )
    freeze_embedding: bool = field(
        default=False,
        metadata={"help": "If True, freeze embedding parameters."},
    )


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        moe_backend=model_args.moe_backend,
        attn_implementation=model_args.attn_implementation,
    )
    model = dllm.utils.get_model(model_args=model_args, config=config)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
    # ----- Optionally freeze modules ----------------------------------------------
    if not isinstance(model, peft.PeftModel):
        if getattr(training_args, "freeze_gate", False):
            for n, m in model.named_modules():
                if n.endswith(".gate"):  # only the router gate, not gate_proj
                    for p in m.parameters(recurse=False):
                        p.requires_grad_(False)

        if getattr(training_args, "freeze_embedding", False):
            # model.model.embed_tokens.requires_grad_(False)
            model.model.embed_tokens.weight.requires_grad_(False)

    # ----- Dataset ----------------------------------------------------------------
    def sft_map_fn(row) -> dict:
        prompt_tokens = tokenizer.apply_chat_template(
            row["messages"][:-1],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompt_response_tokens = tokenizer.apply_chat_template(
            row["messages"], tokenize=True, add_generation_prompt=False
        )
        labels = prompt_response_tokens.copy()
        if training_args.mask_prompt_loss:
            # -100 in `labels` marks prompt positions: they are not masked (noised)
            # and their loss is ignored; all other positions match `input_ids`
            labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
        else:
            # When training on all tokens, prepend a BOS token (if missing)
            # so the model can make predictions for the first mask token.
            if prompt_response_tokens[0] != tokenizer.bos_token_id:
                bos = [tokenizer.bos_token_id]
                prompt_response_tokens = bos + prompt_response_tokens
                prompt_tokens = bos + prompt_tokens
                labels = bos + labels
            labels[0] = -100  # ignore loss on the BOS token
        # `prompt_len` helps `post_process_dataset` truncate long sequences properly
        return {
            "input_ids": prompt_response_tokens,
            "labels": labels,
            # "attention_mask": [1.0] * len(prompt_response_tokens),
            "prompt_len": len(prompt_tokens),
        }

    if not data_args.load_from_disk:
        with accelerate.PartialState().local_main_process_first():
            dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
            dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
            # truncate / filter long sequences if needed
            dataset = dllm.utils.post_process_dataset(dataset, data_args)
    else:
        dataset = datasets.load_from_disk(data_args.dataset_args)
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training ---------------------------------------------------------------
    @dataclass
    class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
        def __call__(self, features, return_tensors=None):
            outputs = super().__call__(features, return_tensors)
            # RND is finetuned on padding <eos_token>
            outputs.pop("attention_mask")
            # temp fix (`group_by_length=True` leads to a shape mismatch):
            # clip seq_len (second dim) so that `input_ids` and `labels` agree
            # TODO -> FIXED: clip all relevant tensors to a common seq_len
            # Determine the common length across the present tensors
            import torch

            keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
            if keys_to_clip:
                # Get the smallest seq_len to avoid out-of-bounds
                min_len = min(
                    outputs[k].size(1)
                    for k in keys_to_clip
                    if isinstance(outputs[k], torch.Tensor)
                )
                for k in keys_to_clip:
                    t = outputs[k]
                    if isinstance(t, torch.Tensor) and t.size(1) != min_len:
                        outputs[k] = t[:, :min_len]
            return outputs

    tokenizer.pad_token_id = tokenizer.mask_token_ids
    trainer = rnd.RNDTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        data_collator=RNDSFTCollator(
            tokenizer,
            # pad_to_multiple_of=8,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=-100,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()
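
Note (not part of the commit): a toy illustration of the seq_len clipping performed in `RNDSFTCollator.__call__` above, run on dummy tensors instead of real collator output.

import torch

outputs = {
    "input_ids": torch.zeros(2, 12, dtype=torch.long),       # padded to length 12
    "labels": torch.full((2, 10), -100, dtype=torch.long),   # padded to length 10
}
keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
min_len = min(outputs[k].size(1) for k in keys_to_clip)
for k in keys_to_clip:
    if outputs[k].size(1) != min_len:
        outputs[k] = outputs[k][:, :min_len]
assert outputs["input_ids"].shape == outputs["labels"].shape == (2, 10)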

219  dllm/examples/rnd/sft_v3.py  Normal file
@@ -0,0 +1,219 @@
"""
Local users
------------
- 1 GPU:
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/rnd/sft_v3.py

- 8 GPUs (DeepSpeed ZeRO-2):
    accelerate launch \
        --config_file scripts/accelerate_configs/zero2.yaml \
        examples/rnd/sft_v3.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 GPU:
    sbatch --gres=gpu:1 scripts/train.slurm.sh \
        --accelerate_config "ddp" \
        --script_path "examples/rnd/sft_v3.py"

- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2):
    sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "zero2" \
        --script_path "examples/rnd/sft_v3.py"
"""

import os
from dataclasses import dataclass, field

import transformers
import accelerate
import peft
import datasets

import dllm
from dllm.pipelines import rnd


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
    moe_backend: str = "hf"
    attn_implementation: str = "sdpa"


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]"
    truncation: str = "right"


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]"
    # rnd specific
    # group_by_length: bool = True
    mask_prompt_loss: bool = field(
        default=True,
        metadata={"help": "Whether to mask the loss on the prompt tokens"},
    )
    freeze_gate: bool = field(
        default=True,
        metadata={
            "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)."
        },
    )
    freeze_embedding: bool = field(
        default=False,
        metadata={"help": "If True, freeze embedding parameters."},
    )
    perbatch_cutoff: bool = field(
        default=True,
        metadata={
            "help": (
                "Randomly pick a response length from the batch and trim the other responses. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )
    resp_cutoff_ratio: float = field(
        default=0.0,
        metadata={
            "help": (
                "The probability of randomly cutting sequences during training. "
                "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml."
            )
        },
    )


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # necessary when the batch contains customized fields (e.g., `prompt_len`)
    training_args.remove_unused_columns = False
    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        moe_backend=model_args.moe_backend,
        attn_implementation=model_args.attn_implementation,
    )
    model = dllm.utils.get_model(model_args=model_args, config=config)
    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
    # ----- Optionally freeze modules ----------------------------------------------
    if not isinstance(model, peft.PeftModel):
        if getattr(training_args, "freeze_gate", False):
            for n, m in model.named_modules():
                if n.endswith(".gate"):  # only the router gate, not gate_proj
                    for p in m.parameters(recurse=False):
                        p.requires_grad_(False)

        if getattr(training_args, "freeze_embedding", False):
            # model.model.embed_tokens.requires_grad_(False)
            model.model.embed_tokens.weight.requires_grad_(False)

    # ----- Dataset ----------------------------------------------------------------
    def sft_map_fn(row) -> dict:
        prompt_tokens = tokenizer.apply_chat_template(
            row["messages"][:-1],
            tokenize=True,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        prompt_response_tokens = tokenizer.apply_chat_template(
            row["messages"], tokenize=True, add_generation_prompt=False
        )
        labels = prompt_response_tokens.copy()
        if training_args.mask_prompt_loss:
            # -100 in `labels` marks prompt positions: they are not masked (noised)
            # and their loss is ignored; all other positions match `input_ids`
            labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
        else:
            # When training on all tokens, prepend a BOS token (if missing)
            # so the model can make predictions for the first mask token.
            if prompt_response_tokens[0] != tokenizer.bos_token_id:
                bos = [tokenizer.bos_token_id]
                prompt_response_tokens = bos + prompt_response_tokens
                prompt_tokens = bos + prompt_tokens
                labels = bos + labels
            labels[0] = -100  # ignore loss on the BOS token
        # `prompt_len` helps `post_process_dataset` truncate long sequences properly
        return {
            "input_ids": prompt_response_tokens,
            "labels": labels,
            "attention_mask": [1] * len(prompt_response_tokens),
            "prompt_len": len(prompt_tokens),
        }

    if not data_args.load_from_disk:
        with accelerate.PartialState().local_main_process_first():
            dataset = dllm.data.load_sft_dataset(data_args.dataset_args)
            dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc)
            # truncate / filter long sequences if needed
            dataset = dllm.utils.post_process_dataset(dataset, data_args)
    else:
        dataset = datasets.load_from_disk(data_args.dataset_args)
        # truncate / filter long sequences if needed
        dataset = dllm.utils.post_process_dataset(dataset, data_args)

    # ----- Training ---------------------------------------------------------------
    # @dataclass
    # class RNDSFTCollator(transformers.DataCollatorForSeq2Seq):
    #     def __call__(self, features, return_tensors=None):
    #         outputs = super().__call__(features, return_tensors)
    #         # RND is finetuned on padding <eos_token>
    #         outputs.pop("attention_mask")
    #         # temp fix (`group_by_length=True` leads to a shape mismatch):
    #         # clip seq_len (second dim) so that `input_ids` and `labels` agree
    #         import torch
    #         keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs]
    #         if keys_to_clip:
    #             # Get the smallest seq_len to avoid out-of-bounds
    #             min_len = min(outputs[k].size(1) for k in keys_to_clip if isinstance(outputs[k], torch.Tensor))
    #             for k in keys_to_clip:
    #                 t = outputs[k]
    #                 if isinstance(t, torch.Tensor) and t.size(1) != min_len:
    #                     outputs[k] = t[:, :min_len]
    #         return outputs
    trainer = rnd.RNDTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        # data_collator=RNDSFTCollator(
        #     tokenizer,
        #     # pad_to_multiple_of=8,
        #     return_tensors="pt",
        #     padding=True,
        #     label_pad_token_id=-100,  # RND is finetuned on padding <eos_token>
        # ),
        data_collator=dllm.pipelines.dream.utils.DreamSFTCollator(
            tokenizer,
            # pad_to_multiple_of=8,
            return_tensors="pt",
            padding=True,
            label_pad_token_id=-100,
            perbatch_cutoff=training_args.perbatch_cutoff,
            resp_cutoff_ratio=training_args.resp_cutoff_ratio,
        ),
    )
    trainer.train()
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()
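
Note (not part of the commit): a hypothetical sketch of the per-batch response cutoff described in the `perbatch_cutoff` help text (pick one response length from the batch at random and trim the other responses to it). This illustrates the idea only; it is not the actual DreamSFTCollator implementation.

import random

def perbatch_cutoff_sketch(features, rng=random):
    # features: list of dicts with "input_ids", "labels", "prompt_len" (as produced by sft_map_fn)
    resp_lens = [len(f["input_ids"]) - f["prompt_len"] for f in features]
    cut = rng.choice(resp_lens)  # response length randomly picked for this batch
    trimmed = []
    for f in features:
        keep = f["prompt_len"] + cut  # never trims into the prompt
        trimmed.append(
            {
                "input_ids": f["input_ids"][:keep],
                "labels": f["labels"][:keep],
                "prompt_len": f["prompt_len"],
            }
        )
    return trimmed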