1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

dllm/examples/llada/pt.py Normal file

@@ -0,0 +1,174 @@
"""
Local users
------------
- 1 GPU (4bit quant & LoRA, useful for testing):
    accelerate launch \
        --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
        examples/llada/pt.py \
        --load_in_4bit True --lora True
- 8 GPUs (FSDP):
    accelerate launch \
        --config_file scripts/accelerate_configs/fsdp.yaml \
        examples/llada/pt.py

Slurm users
------------
# Note: run `mkdir logs` before running sbatch, and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 24 Nodes, 192 GPUs (FSDP):
    sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
        --accelerate_config "fsdp" \
        --script_path "examples/llada/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import torch
import transformers
import accelerate
import dllm
logger = dllm.utils.get_default_logger(__name__)


@dataclass
class ModelArguments(dllm.utils.ModelArguments):
    # Uses only the configuration from model_name_or_path to initialize the model from scratch
    model_name_or_path: str = (
        "GSAI-ML/LLaDA-8B-Base"  # "inclusionAI/LLaDA-MoE-7B-A1B-Base"
    )


@dataclass
class DataArguments(dllm.utils.DataArguments):
    dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]"
    text_field: str = "text"
    streaming: bool = True
    drop_tail: bool = True
    insert_eos: bool = field(
        default=True,
        metadata={
            "help": "Set to False when adjacent samples in the dataset are semantically coherent."
        },
    )
    random_length_ratio: float = field(
        default=0.01,
        metadata={
            "help": (
                "The probability of randomly truncating sequences during training. "
                "See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training for reference."
            )
        },
    )


@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
    output_dir: str = (
        "models/LLaDA-8B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]"
    )
    learning_rate: float = 3e-4
    max_steps: int = 2_000
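    # Effective global batch size = per_device_train_batch_size
    # * gradient_accumulation_steps * num_processes
    # (4 * 4 * 8 = 128 sequences per optimizer step on a single 8-GPU node).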
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 4
    eval_steps: float = 0.05
    save_steps: float = 0.05


def train():
    # ----- Argument parsing -------------------------------------------------------
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Streaming datasets are iterable: disable batch dispatching so each process
    # iterates its own dataloader instead of receiving batches from the main process.
    if data_args.streaming:
        training_args.accelerator_config.dispatch_batches = False

    dllm.utils.print_args_main(model_args, data_args, training_args)
    dllm.utils.initial_training_setup(model_args, data_args, training_args)

    # ----- Model ------------------------------------------------------------------
    # initialize model weights from scratch
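    # (AutoModel.from_config builds the architecture from the config alone; no
    # pretrained weights are downloaded or loaded.)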
    config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
    with dllm.utils.init_device_context_manager():
        model = transformers.AutoModel.from_config(
            config, dtype=torch.bfloat16, init_params=True
        )

    # ----- Tokenizer --------------------------------------------------------------
    tokenizer = dllm.utils.get_tokenizer(model_args=model_args)

    # ----- Optional PEFT: LoRA ----------------------------------------------------
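    # Expected to wrap the model with LoRA adapters only when LoRA is enabled in
    # ModelArguments (e.g., `--lora True` in the 1-GPU example above).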
    model = dllm.utils.load_peft(model=model, model_args=model_args)

    # ----- Dataset ----------------------------------------------------------------
    with accelerate.PartialState().local_main_process_first():
        dataset = dllm.data.load_pt_dataset(
            data_args.dataset_args,
            streaming=data_args.streaming,
        )
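        # tokenize_and_group tokenizes `text_field` and packs the result into
        # fixed-length blocks of `max_length`, inserting EOS between documents when
        # `insert_eos` is set and dropping the trailing partial block when
        # `drop_tail` is set.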
        dataset = dataset.map(
            functools.partial(
                dllm.utils.tokenize_and_group,
                tokenizer=tokenizer,
                text_field=data_args.text_field,
                seq_length=data_args.max_length,
                insert_eos=data_args.insert_eos,
                drop_tail=data_args.drop_tail,
            ),
            batched=True,
            remove_columns=dataset["train"].column_names,
            **({} if data_args.streaming else {"num_proc": data_args.num_proc}),
            **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
        )
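        # IterableDatasets are not shuffled by the Trainer, so shuffle the stream
        # explicitly (buffered shuffle, seeded for reproducibility).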
        if data_args.streaming:
            dataset = dataset.shuffle(seed=training_args.seed)

    # ----- Training ---------------------------------------------------------------
    @dataclass
    class LLaDAPTCollator(transformers.DataCollatorForSeq2Seq):
        # Reference: https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training
        # By default, 1% of the pre-training data are truncated to a random length
        random_length_ratio: float = 0.01

        def __call__(self, features, return_tensors=None):
            outputs = super().__call__(features, return_tensors)
            if torch.rand(1) < self.random_length_ratio:
                random_length = torch.randint(
                    1, outputs["input_ids"].shape[1] + 1, (1,)
                )
                for key in ["input_ids", "labels", "attention_mask"]:
                    if key in outputs:
                        outputs[key] = outputs[key][:, :random_length]
            # If attention_mask is all ones (i.e., no padding), drop it entirely
            if torch.all(outputs["attention_mask"] == 1):
                outputs.pop("attention_mask")
            return outputs
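
    # Make sure all processes have finished dataset preparation before training starts.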
    accelerate.PartialState().wait_for_everyone()
    logger.info("Start training...")
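    # MDLMTrainer (see dllm.core.trainers) implements the masked-diffusion
    # language-modeling objective used to pre-train LLaDA, computing the loss on
    # randomly masked tokens.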
    trainer = dllm.core.trainers.MDLMTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=dataset.get("test", None),
        args=training_args,
        data_collator=LLaDAPTCollator(
            tokenizer,
            return_tensors="pt",
            padding=True,
            random_length_ratio=data_args.random_length_ratio,
        ),
    )
    trainer.train()
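
    # Save the final weights and the tokenizer (stored by the Trainer as `processing_class`).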
    trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
    trainer.processing_class.save_pretrained(
        os.path.join(training_args.output_dir, "checkpoint-final")
    )


if __name__ == "__main__":
    train()