import torch
import accelerate
import transformers
from peft import prepare_model_for_kbit_training

from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
from dllm.utils.configs import ModelArguments, TrainingArguments


def get_model(
    model_args,
    config: transformers.PretrainedConfig | None = None,
) -> transformers.PreTrainedModel:
    """
    Load a model from the fields of a ``ModelArguments``-style object.

    Args:
        model_args: Dataclass or namespace with model parameters. Must provide
            ``model_name_or_path``; may provide ``dtype`` (string or
            ``torch.dtype``, default ``"bfloat16"``), ``load_in_4bit``
            (4-bit quantization, default ``False``), and
            ``attn_implementation``.
        config: Optional ``transformers.PretrainedConfig`` forwarded to
            ``from_pretrained``.

    Returns:
        transformers.PreTrainedModel
    """
    model_name_or_path = getattr(model_args, "model_name_or_path")
    dtype = getattr(model_args, "dtype", "bfloat16")
    load_in_4bit = getattr(model_args, "load_in_4bit", False)
    attn_implementation = getattr(model_args, "attn_implementation", None)

    # Device map: place the whole model on the local GPU. Skip under DeepSpeed
    # ZeRO-3 (which shards parameters itself) and on CPU-only machines.
    device_map = (
        {"": accelerate.PartialState().local_process_index}
        if not transformers.modeling_utils.is_deepspeed_zero3_enabled()
        and torch.cuda.is_available()
        else None
    )

    # Optional 4-bit NF4 quantization (QLoRA-style), only when bitsandbytes is
    # installed; otherwise the model is loaded unquantized.
    quant_config = None
    if load_in_4bit and transformers.utils.is_bitsandbytes_available():
        quant_config = transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    params = {
        "dtype": dtype,
        "device_map": device_map,
        "quantization_config": quant_config,
        "attn_implementation": attn_implementation,
        "config": config,
    }

    # Prefer the masked-LM head; fall back to the bare AutoModel for
    # architectures without an AutoModelForMaskedLM mapping.
    try:
        model = transformers.AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, **params
        )
    except Exception:
        model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)

    # --- if quantized, prepare for LoRA / QLoRA training ---
    if load_in_4bit and quant_config is not None:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

    # Optionally wrap the model with LoRA adapters for PEFT training
    model = load_peft(model, model_args)

    return model
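
# Usage sketch (illustrative, not part of the training pipeline): the
# checkpoint name is a placeholder, and the `ModelArguments` fields are
# assumed to match the attributes read above. Loading a 4-bit, QLoRA-ready
# model might look like:
#
#   model_args = ModelArguments(
#       model_name_or_path="GSAI-ML/LLaDA-8B-Base",
#       dtype="bfloat16",
#       load_in_4bit=True,
#   )
#   model = get_model(model_args)

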
def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
    """
    Load a tokenizer and apply model-specific customization.

    Args:
        model_args: Dataclass or namespace with model parameters; must provide
            ``model_name_or_path``.

    Returns:
        transformers.PreTrainedTokenizer
    """
    # Lazy imports to avoid circular dependencies
    from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
    from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
    from dllm.pipelines.dream.models.modeling_dream import DreamModel
    from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
    from transformers import (
        BertPreTrainedModel,
        RobertaPreTrainedModel,
        ModernBertPreTrainedModel,
    )

    model_name_or_path = getattr(model_args, "model_name_or_path")

    # ---------------- Tokenizer loading ----------------
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
    )

    # At least one of eos/pad must exist so the missing ones can be backfilled.
    assert tokenizer.eos_token is not None or tokenizer.pad_token is not None

    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    if not tokenizer.eos_token:
        tokenizer.eos_token = tokenizer.pad_token
    if not tokenizer.bos_token:
        tokenizer.bos_token = tokenizer.pad_token

    # Resolve the model class from the config alone, so the customization
    # below can dispatch on architecture without instantiating the model.
    model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path)
    model_cls = transformers.AutoModel._model_mapping[type(model_cfg)]

    # ---------------- Model-specific customization ----------------
    if issubclass(model_cls, LLaDAModelLM):
        tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"})
        tokenizer.eot_token = "<|eot_id|>"
        # tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)  # cannot do this for LLaDA base directly
        # TODO: for LLaDA base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348}
        # Fix bugs in the upstream chat template
        tokenizer.chat_template = """\
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}
<|start_header_id|>{{ message['role'] }}<|end_header_id|>

{{ message['content'] | trim }}<|eot_id|>
{%- endfor %}
{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
<|start_header_id|>assistant<|end_header_id|>

{% endif %}
"""
    elif issubclass(model_cls, LLaDAMoEModelLM):
        tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
        tokenizer.eot_token = "<|role_end|>"
        tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
    elif issubclass(model_cls, DreamModel):
        tokenizer.eot_token = "<|im_end|>"
        tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
    elif issubclass(model_cls, RND1LM):
        tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
    elif issubclass(
        model_cls,
        (BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel),
    ):
        tokenizer.eot_token = "[/Answer]"
        tokenizer.chat_template = """\
{% if messages[0]['role'] == 'system' %}
[SYS]
{{ messages[0]['content'] | trim }}
[/SYS]

{% set loop_messages = messages[1:] %}
{% else %}
{% set loop_messages = messages %}
{% endif -%}
{%- for message in loop_messages %}
{% if message['role'] == 'user' %}
[Question]
{{ message['content'] | trim }}
[/Question]

{% elif message['role'] == 'assistant' %}
[Answer]
{{ message['content'] | trim }}
[/Answer]

{% endif %}
{% endfor -%}
{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
[Answer]
{% endif %}
"""
    else:
        print_main("No tokenizer customization for model class:", model_cls)

    return tokenizer
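

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module):
    # load a tokenizer for a checkpoint given on the command line and render a
    # one-turn chat with the customized template. Only the
    # `model_name_or_path` field read by `get_tokenizer` is assumed.
    import sys
    from types import SimpleNamespace

    _args = SimpleNamespace(model_name_or_path=sys.argv[1])
    _tokenizer = get_tokenizer(_args)
    print_main(
        _tokenizer.apply_chat_template(
            [{"role": "user", "content": "Hello!"}],
            add_generation_prompt=True,
            tokenize=False,
        )
    )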