115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
# """
|
|
# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --cpus-per-task=12 --time=03:00:00
|
|
|
|
# python examples/rnd/preprocess.py --dataset_args "HuggingFaceTB/smoltalk" --output_dir "data/sft_proc/rnd/smoltalk"
|
|
# """
|
|
# import os
|
|
# from dataclasses import dataclass
|
|
# from typing import Dict, Any
|
|
|
|
# import datasets
|
|
# import transformers
|
|
# import accelerate
|
|
# import tyro
|
|
|
|
# import dllm
|
|
|
|
|
|
# # --- tyro: define dataclass for CLI args ---
|
|
# @dataclass
|
|
# class ScriptArguments:
|
|
# """Preprocess SFT dataset (batch_size=1 only)"""
|
|
# model_name_or_path: str = "radicalnumerics/RND1-Base-0910"
|
|
# dataset_args: str = "HuggingFaceTB/smoltalk" # required
|
|
# output_dir: str = "data/sft_proc/rnd/smoltalk" # required
|
|
# mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
|
|
# # TODO: strip_cols
|
|
|
|
# def __post_init__(self):
|
|
# self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
|
# self.model_name_or_path, "BASE_MODELS_DIR"
|
|
# )
|
|
|
|
|
|
# def dataset_offline_preprocess(dataset: datasets.DatasetDict, map_fn: callable, output_dir: str):
|
|
# # Map one example at a time (batched=False) across 16 worker processes (num_proc=16).
|
|
# state = accelerate.PartialState()
|
|
# with state.local_main_process_first():
|
|
# processed = dataset.map(
|
|
# map_fn,
|
|
# batched=False,
|
|
# num_proc=16,
|
|
# load_from_cache_file=True,
|
|
# writer_batch_size=512,
|
|
# desc="offline preprocessing",
|
|
# )
|
|
|
|
# # # Keep only the three required columns to save space.
|
|
# # keep = {"input_ids", "labels", "prompt_len"}
|
|
# # def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
|
# # drop = [c for c in ds.column_names if c not in keep]
|
|
# # return ds.remove_columns(drop) if drop else ds
|
|
|
|
# # if isinstance(processed, datasets.DatasetDict):
|
|
# # for split in list(processed.keys()):
|
|
# # processed[split] = strip_cols(processed[split])
|
|
# # else:
|
|
# # processed = strip_cols(processed)
|
|
|
|
# os.makedirs(output_dir, exist_ok=True)
|
|
# processed.save_to_disk(output_dir)
|
|
# print(f"[OK] Saved to: {output_dir}")
|
|
|
|
|
|
# def main():
|
|
# # Parse with tyro
|
|
# args = tyro.cli(ScriptArguments)
|
|
|
|
# # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
|
|
# tokenizer = dllm.utils.get_tokenizer(args)
|
|
|
|
# # Load your raw dataset (must contain a "messages" field per example).
|
|
# dataset = dllm.data.load_sft_dataset(args.dataset_args)
|
|
|
|
# dataset_offline_preprocess(dataset=dataset, map_fn=None, output_dir=args.output_dir)
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# main()
|
|
|
|
|
|
from functools import partial
|
|
import tyro
|
|
|
|
import dllm
|
|
from dllm.tools.preprocess_sft_dataset import ScriptArguments, preprocess_sft_dataset
|
|
|
|
|
|
def main():
    """Run offline SFT preprocessing for the RND example from the command line.

    Steps, in order:
      1. Import ``sft_map_fn`` from ``examples.rnd.sft`` (function-scope import).
      2. Parse ``ScriptArguments`` from the CLI with ``tyro``.
      3. Build the tokenizer via ``dllm.utils.get_tokenizer``.
      4. Load the raw dataset via ``dllm.data.load_sft_dataset``.
      5. Bind the tokenizer and the ``mask_prompt_loss`` flag into
         ``sft_map_fn`` and pass the result, together with the output
         directory and the columns to remove, to ``preprocess_sft_dataset``.
    """
    # Kept as a function-scope import, so it is only resolved when main() runs.
    from examples.rnd.sft import sft_map_fn

    cli_args = tyro.cli(ScriptArguments)

    tok = dllm.utils.get_tokenizer(cli_args)

    # The raw dataset must contain a "messages" field per example.
    raw_dataset = dllm.data.load_sft_dataset(cli_args.dataset_args)

    # Pre-bind everything except the example itself, so the mapper can be
    # handed over as a single callable.
    bound_map_fn = partial(
        sft_map_fn,
        tokenizer=tok,
        mask_prompt_loss=cli_args.mask_prompt_loss,
    )

    preprocess_sft_dataset(
        dataset=raw_dataset,
        map_fn=bound_map_fn,
        output_dir=cli_args.output_dir,
        remove_columns=cli_args.remove_columns,
    )
|
|
|
|
|
|
# Script entry point: run the preprocessing only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|