diff --git a/Amadeus/sub_decoder_utils.py b/Amadeus/sub_decoder_utils.py
index 3109ef4..ffe2fc1 100644
--- a/Amadeus/sub_decoder_utils.py
+++ b/Amadeus/sub_decoder_utils.py
@@ -1,8 +1,18 @@
 from math import ceil
+from typing import Optional, Union, Literal
+from typing_extensions import Unpack
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad
+
 
 class MLP(nn.Module):
     def __init__(self, in_size, out_size, hidden_size, dropout):
@@ -49,6 +59,64 @@ class extendedMLP(nn.Module):
             x = layer(x)
         return x
 
+class extendedMLP(nn.Module):
+    def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
+        super().__init__()
+        self.input_size = in_size
+
+        self.layers = nn.ModuleList()
+        if num_layers == 1:
+            # Only one layer
+            self.layers.append(nn.Linear(in_size, out_size))
+            return
+        elif num_layers > 1:
+            # First layer
+            self.layers.append(nn.Linear(in_size, hidden_size))
+            self.layers.append(nn.Dropout(dropout))
+            self.layers.append(nn.ReLU())
+            # Intermediate layers
+            if num_layers > 2:
+                for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
+                    self.layers.append(nn.Linear(hidden_size, hidden_size))
+                    self.layers.append(nn.Dropout(dropout))
+                    self.layers.append(nn.ReLU())
+            # Last layer
+            self.layers.append(nn.Linear(hidden_size, out_size))
+        else:
+            raise ValueError("num_layers should be a positive integer")
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+class SwiGLUFFN(nn.Module):
+    def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
+        super().__init__()
+        self.input_size = in_size
+
+        if num_layers == 1:
+            # Single layer: a plain linear projection
+            self.ffn = nn.Linear(in_size, out_size)
+        elif num_layers == 2:
+            # Two layers: use SwiGLU
+            self.w1 = nn.Linear(in_size, 2 * hidden_size)  # first half is the main branch, second half is the gate
+            self.w2 = nn.Linear(hidden_size, out_size)
+            self.dropout = nn.Dropout(dropout)
+        else:
+            raise ValueError("SwiGLU FFN only supports num_layers=1 or 2")
+
+    def forward(self, x):
+        if hasattr(self, "ffn"):
+            return self.ffn(x)
+        else:
+            x_proj = self.w1(x)
+            x_main, x_gate = x_proj.chunk(2, dim=-1)  # split into the two halves
+            x = F.silu(x_main) * x_gate  # SwiGLU: silu(a) * b
+            x = self.dropout(x)
+            x = self.w2(x)
+            return x
+
 class multiMLP(nn.Module):
     def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
         super().__init__()
@@ -209,6 +277,28 @@ class TransformerLayer(nn.Module):
         output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
         return output_dict
 
+class TransformerLayerV2(nn.Module):
+    def __init__(self, dim, num_heads, dropout):
+        super().__init__()
+        self.self_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
+        self.cross_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
+        self.residual_FF = ResidualLayerNormModule(SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4*dim, dropout=dropout))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, input_dict):
+        '''
+        
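+        Self-attention over input_seq, then cross-attention against memory, then a SwiGLU feed-forward,
+        each wrapped in ResidualLayerNormModule (see __init__ above), with a final dropout.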
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask} + ''' + # self attention + attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self') + + input_dict['input_seq'] = attn_output + # cross attention + attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross') + attn_output = self.residual_FF.forward_mlp(attn_output) + attn_output = self.dropout(attn_output) + output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']} + return output_dict class FeatureEnricher(nn.Module): def __init__(self, dim, num_heads, dropout): super().__init__() @@ -225,4 +315,483 @@ class FeatureEnricher(nn.Module): attn_output = self.residual_FF.forward_mlp(attn_output) attn_output = self.dropout(attn_output) output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']} - return output_dict \ No newline at end of file + return output_dict + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information from different representation subspaces. + + .. note:: + See `this tutorial `_ + for an in depth discussion of the performant building blocks PyTorch offers for building your own + transformer layers. + + Method described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O + + where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``nn.MultiheadAttention`` will use the optimized implementations of + ``scaled_dot_product_attention()`` when possible. + + In addition to support for the new ``scaled_dot_product_attention()`` + function, for speeding up Inference, MHA will use + fastpath inference with support for Nested Tensors, iff: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor). + - inputs are batched (3D) with ``batch_first==True`` + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + - autocast is disabled + + If the optimized inference fastpath implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. 
+ Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`: + https://arxiv.org/abs/2205.14135 + + """ + + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: + if embed_dim <= 0 or num_heads <= 0: + raise ValueError( + f"embed_dim and num_heads must be greater than 0," + f" got embed_dim={embed_dim} and num_heads={num_heads} instead" + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super().__setstate__(state) + + def forward( + self, + query: 
Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> tuple[Tensor, Optional[Tensor]]: + r"""Compute attention outputs using query, key, and value embeddings. + + Supports optional parameters for padding, masks and attention weights. + + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and float masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` + and achieve the best performance for MHA. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + If both attn_mask and key_padding_mask are supplied, their types should match. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``attn_mask`` is the + causal mask. 
Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ # noqa: B950 + why_not_fast_path = "" + if ( + (attn_mask is not None and torch.is_floating_point(attn_mask)) + or (key_padding_mask is not None) + and torch.is_floating_point(key_padding_mask) + ): + why_not_fast_path = "floating-point masks are not supported for fast path." + + is_batched = query.dim() == 3 + + key_padding_mask = F._canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=F._none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype, + ) + + attn_mask = F._canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() + + if not is_fastpath_enabled: + why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" + elif not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif self.in_proj_weight is None: + why_not_fast_path = "in_proj_weight was None" + elif query.dtype != self.in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. 
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif (self.num_heads % 2) != 0: + why_not_fast_path = "self.num_heads is not even" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif query.is_nested and ( + key_padding_mask is not None or attn_mask is not None + ): + why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif _is_make_fx_tracing(): + why_not_fast_path = "we are running make_fx tracing" + elif not all(_check_arg_device(x) for x in tensor_args): + why_not_fast_path = ( + "some Tensor argument's device is neither one of " + f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}" + ) + elif torch.is_grad_enabled() and any( + _arg_requires_grad(x) for x in tensor_args + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + merged_mask, mask_type = self.merge_masks( + attn_mask, key_padding_mask, query + ) + + if self.in_proj_bias is not None and self.in_proj_weight is not None: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + merged_mask, + need_weights, + average_attn_weights, + mask_type, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. 
" + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = (x.transpose(1, 0) for x in (query, key)) + value = key + else: + query, key, value = (x.transpose(1, 0) for x in (query, key, value)) + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + is_causal=is_causal, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + def merge_masks( + self, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + query: Tensor, + ) -> tuple[Optional[Tensor], Optional[int]]: + r"""Determine mask type and combine masks if necessary. + + If only one mask is provided, that mask + and the corresponding mask type will be returned. 
If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + + if attn_mask is not None: + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + + # Always expands attn_mask to 4D + if attn_mask.dim() == 3: + attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) + else: # attn_mask.dim() == 2: + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( + batch_size, self.num_heads, -1, -1 + ) + merged_mask = attn_mask_expanded + + if key_padding_mask is not None: + key_padding_mask_expanded = key_padding_mask.view( + batch_size, 1, 1, seq_len + ).expand(-1, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded + key_padding_mask_expanded + + # no attn_mask and no key_padding_mask, returns None, None + return merged_mask, mask_type + + + diff --git a/Amadeus/sub_decoder_zoo.py b/Amadeus/sub_decoder_zoo.py index db7ba31..3718703 100644 --- a/Amadeus/sub_decoder_zoo.py +++ b/Amadeus/sub_decoder_zoo.py @@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass): causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size) self.register_buffer('causal_mask', causal_mask) + # self.transformer_decoder = Decoder( + # dim = dim, + # depth = sub_decoder_depth, + # heads = heads, + # attn_dropout = dropout, + # ff_dropout = dropout, + # attn_flash = True) self.transformer_decoder = Decoder( - dim = dim, + dim = dim, depth = sub_decoder_depth, heads = heads, attn_dropout = dropout, ff_dropout = dropout, - attn_flash = True) + attn_flash = True, + use_rmsnorm=True, + ff_swish = True, # set this to True + ff_glu = True, # set to true to use for all feedforwards + ) # add final dropout print('Applying Xavier Uniform Init to x-transformer following torch.Transformer') self._apply_xavier_init() @@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass): dropout:float, sub_decoder_enricher_use:bool, MASK_IDX:int = 126336, - denoising_steps:int = 6, + denoising_steps:int = 8, eps:float = 1e-3, method:str = 'low-confidence', # or random or auto-regressive ): @@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass): logits_dict[feature] = logit return logits_dict, (masked_indices, p_mask) -class DiffusionDecoder(SubDecoderClass): +class DiffusionDecoderV2(SubDecoderClass): def __init__( self, prediction_order:list, @@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass): dropout:float, sub_decoder_enricher_use:bool, MASK_IDX:int = 126336, - denoising_steps:int = 6, + denoising_steps:int = 8, eps:float = 1e-3, method:str = 'low-confidence', # or random or auto-regressive ): @@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass): self.input_norm = nn.LayerNorm(dim) - self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout)) + self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, 
dropout=dropout)) if sub_decoder_enricher_use: self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True) @@ -1138,14 +1149,21 @@ class DiffusionDecoder(SubDecoderClass): self.register_buffer('causal_mask', causal_mask) self.register_buffer('causal_ca_mask', causal_ca_mask) - # get depth of the sub-decoder if sub_decoder_depth > 1: - self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)]) + self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)]) else: - self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout)) + self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout)) if sub_decoder_enricher_use: self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout)) + + self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order, + vocab=vocab, + sub_decoder_depth=1, + dim=dim, + heads=heads, + dropout=dropout, + sub_decoder_enricher_use=False) # simplified version of the forward process in diffusion model def _forward_process(self, input_ids, eps=1e-3, mask_idx=None): @@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass): # print("sampled_token_dict", sampled_token_dict) return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings - def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None): + def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False): logits_dict = {} hidden_vec = input_dict['hidden_vec'] # B x T x d_model + copy_input_dict = input_dict.copy() + target = input_dict['target'] #B x T x d_model bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder @@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass): memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False) # ---- Generate(Inference) ---- # if target is None: + if aux_ar: # inference with auxiliary auto-regressive decoder + aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step) + # print("aux_ar_logits", aux_ar_logits) + return aux_ar_logits, sampled_token_dict sampled_token_dict = {} b,t,d = hidden_vec.shape # B x T x d_model l = len(self.prediction_order) # num_sub_tokens @@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass): logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :]) logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size logits_dict[feature] = logit - return logits_dict, (masked_indices, p_mask) \ No newline at end of file + # get aux ar decoder logits + aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T + return (logits_dict, aux_ar_logits), (masked_indices, p_mask) + # return logits_dict, (masked_indices, p_mask) \ No newline at end of file diff --git a/Amadeus/symbolic_encoding/data_utils.py b/Amadeus/symbolic_encoding/data_utils.py index af0ebca..f11139f 100644 --- a/Amadeus/symbolic_encoding/data_utils.py +++ 
b/Amadeus/symbolic_encoding/data_utils.py @@ -3,6 +3,7 @@ import random from pathlib import Path from collections import OrderedDict from typing import Union, List, Tuple, Dict +import torch import numpy as np import matplotlib.pyplot as plt @@ -157,6 +158,8 @@ class IterTuneCompiler(IterableDataset): segment = self.augmentor(segment) # use input_ids replace tune_name tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption + # print(segment.shape, mask.shape, tune_name.shape) + # segment = segment[torch.randperm(segment.size(0))] yield segment, mask, tune_name, encoded_caption def __len__(self): diff --git a/Amadeus/symbolic_yamls/config-accelerate.yaml b/Amadeus/symbolic_yamls/config-accelerate.yaml index 7fd217f..2afb195 100644 --- a/Amadeus/symbolic_yamls/config-accelerate.yaml +++ b/Amadeus/symbolic_yamls/config-accelerate.yaml @@ -1,7 +1,9 @@ defaults: # - nn_params: nb8_embSum_NMT # - nn_params: remi8 - - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2 + # - nn_params: oct8_embSum_diff_t2m_300M_pretrainingv3 + # - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2 + - nn_params: oct8_embSum_har_t2m_600M_pretrainingv3 # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2 # - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2 # - nn_params: nb8_embSum_subPararell @@ -15,7 +17,7 @@ defaults: # - nn_params: remi8_main12_head_16_dim512 # - nn_params: nb5_embSum_diff_main12head16dim768_sub3 -dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset +dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset captions_path: dataset/midicaps/train_set.json # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean @@ -23,28 +25,28 @@ captions_path: dataset/midicaps/train_set.json use_ddp: True # True, False | distributed data parallel use_fp16: True # True, False | mixed precision training -use_diff: True # True,use diffusion in subdecoder +use_diff: False # True,use diffusion in subdecoder diff_steps: 8 # number of diffusion steps -use_dispLoss: True +use_dispLoss: False lambda_weight: 0.5 tau: 0.5 train_params: device: cuda - batch_size: 10 + batch_size: 9 grad_clip: 1.0 num_iter: 300000 # total number of iterations num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint iterations_per_training_cycle: 10 # number of iterations for logging training loss iterations_per_validation_cycle: 3000 # number of iterations for validation process - input_length: 3072 # input sequence length3072 + input_length: 2048 # input sequence length3072 # you can use focal loss, it it's not used, set focal_gamma to 0 focal_alpha: 1 focal_gamma: 0 # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details scheduler : cosinelr - initial_lr: 0.00001 + initial_lr: 0.0003 decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts' warmup_steps: 2000 #number of warmup steps diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml index 0e16086..46dd009 100644 --- 
a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml @@ -5,7 +5,7 @@ model_name: AmadeusModel input_embedder_name: SummationEmbedder main_decoder_name: XtransformerNewPretrainingDecoder sub_decoder_name: DiffusionDecoder -model_dropout: 0.2 +model_dropout: 0 input_embedder: num_layer: 1 num_head: 8 diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv3.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv3.yaml new file mode 100644 index 0000000..26cdb79 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv3.yaml @@ -0,0 +1,19 @@ +encoding_scheme: oct +num_features: 8 +vocab_name: MusicTokenVocabOct +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: DiffusionDecoderV2 +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 768 + num_layer: 16 + num_head: 12 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_300M_pretrainingv3.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_300M_pretrainingv3.yaml new file mode 100644 index 0000000..2b51707 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_300M_pretrainingv3.yaml @@ -0,0 +1,19 @@ +encoding_scheme: oct +num_features: 8 +vocab_name: MusicTokenVocabOct +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: DiffusionDecoderV2 +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 1080 + num_layer: 13 + num_head: 12 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3 copy.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3 copy.yaml new file mode 100644 index 0000000..2b51707 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3 copy.yaml @@ -0,0 +1,19 @@ +encoding_scheme: oct +num_features: 8 +vocab_name: MusicTokenVocabOct +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: DiffusionDecoderV2 +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 1080 + num_layer: 13 + num_head: 12 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3.yaml new file mode 100644 index 0000000..f97652e --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_300M_pretrainingv3.yaml @@ -0,0 +1,19 @@ +encoding_scheme: oct +num_features: 8 +vocab_name: MusicTokenVocabOct +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: SelfAttention +model_dropout: 0 +input_embedder: + num_layer: 1 + 
num_head: 8 +main_decoder: + dim_model: 1080 + num_layer: 13 + num_head: 12 +sub_decoder: + decout_window_size: 3 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_600M_pretrainingv3.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_600M_pretrainingv3.yaml new file mode 100644 index 0000000..97ceb08 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_har_t2m_600M_pretrainingv3.yaml @@ -0,0 +1,19 @@ +encoding_scheme: oct +num_features: 8 +vocab_name: MusicTokenVocabOct +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: SelfAttention +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 1272 + num_layer: 20 + num_head: 12 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/toy_train.py b/Amadeus/toy_train.py new file mode 100644 index 0000000..e050382 --- /dev/null +++ b/Amadeus/toy_train.py @@ -0,0 +1,454 @@ +from math import ceil +from typing import Optional, Union, Literal +from typing_extensions import Unpack + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from transformers import Qwen2Config +from transformers.activations import get_activation +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6): + super().__init__() + + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x: Tensor) -> Tensor: + dtype = x.dtype + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + x = x.float() + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + x = self.weight * x + return x.to(dtype) + + +class FeedForward(nn.Module): + def __init__(self, + hidden_size: int, + intermediate_size: Optional[int] = None, + dropout: float = 0., + hidden_act: Literal['silu', 'gelu'] = 'silu'): + super().__init__() + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = get_activation(hidden_act) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> Tensor: + gate, hidden = self.up_proj(x).chunk(2, dim=-1) + gate = self.act_fn(gate) + + return self.down_proj(self.dropout(gate * hidden)) + + +class MultiHeadAttention(nn.Module): + def __init__(self, + hidden_size: int, + head_dim: int = 64, + num_attention_heads: Optional[int] = None, + attention_dropout: float = 0., + bias: bool = False, + rms_norm_eps: float = 1e-6, + attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa" + ): + super().__init__() + + self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim) + self.head_dim = head_dim + self.scaling = 
head_dim ** -0.5 + self.attention_dropout = attention_dropout + + self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, + bias=bias) + self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, + bias=bias) + self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, + bias=bias) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, hidden_size, + bias=bias) + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation] + + # To be compatible with huggingface + self.config = Qwen2Config() + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + key_states: Optional[Tensor] = None, + value_states: Optional[Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + + q = self.q_proj(hidden_states) + + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2) + + if(key_states is None or value_states is None): + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2) + value_states = v.view(hidden_shape).transpose(1, 2) + + attn_output, attn_weights = self.attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class DecoderLayerWithCrossAttention(nn.Module): + def __init__(self, + hidden_size: int, + + attn_head_dim: int = 64, + num_attention_heads: Optional[int] = None, + attention_dropout: float = 0., + attn_bias: bool = False, + + mlp_intermediate_size: Optional[int] = None, + mlp_dropout: float = 0., + mlp_act: Literal['gelu', 'silu'] = 'silu', + + rms_norm_eps: float = 1e-6, + attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa" + ): + super().__init__() + self.hidden_size = hidden_size + + self.self_attn = MultiHeadAttention( + hidden_size, + attn_head_dim, + num_attention_heads, + attention_dropout, + attn_bias, + rms_norm_eps, + attn_implementation) + + self.cross_attn = MultiHeadAttention( + hidden_size, + attn_head_dim, + num_attention_heads, + attention_dropout, + attn_bias, + rms_norm_eps, + attn_implementation) + + self.mlp = FeedForward( + hidden_size, + mlp_intermediate_size, + mlp_dropout, + mlp_act) + + self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + input_dict, + ): + ''' + input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask} + ''' + hidden_states = input_dict['input_seq'] + cross_attn_states = input_dict['memory'] + cross_attn_mask = input_dict['memory_mask'] + attention_mask = input_dict['memory_mask'] + output_attentions = False + # Self Attention + residual = hidden_states + hidden_states = self.self_attn_layernorm(hidden_states) + hidden_states, dec_attn_weight = self.self_attn( + 
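+            # note: attention_mask here is input_dict['memory_mask'] (assigned above), reused for self-attention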
hidden_states=hidden_states, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + # Cross Attention + if(cross_attn_states is not None): + residual = hidden_states + hidden_states = self.cross_attn_layernorm(hidden_states) + cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim) + key_states = self.cross_attn.k_proj(cross_attn_states) + value_states = self.cross_attn.v_proj(cross_attn_states) + + key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2) + value_states = value_states.view(cross_attn_shape).transpose(1, 2) + hidden_states, cross_attn_weight = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attn_mask, + key_states=key_states, + value_states=value_states + ) + hidden_states = residual + hidden_states + + # Feed Forward + residual = hidden_states + hidden_states = self.ffn_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask} + if(output_attentions): + if(cross_attn_states is not None): + # return (hidden_states, (dec_attn_weight, cross_attn_weight)) + output_dict['dec_attn_weight'] = dec_attn_weight + output_dict['cross_attn_weight'] = cross_attn_weight + return output_dict + else: + # return (hidden_states, dec_attn_weight) + output_dict['dec_attn_weight'] = dec_attn_weight + return output_dict + else: + return output_dict + +def sinusoidal_positional_embedding( + x: Tensor, + n: float = 10000.0) -> Tensor: + device, dtype = x.device, x.dtype + T = x.shape[-2] + D = x.shape[-1] + + positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1) + embeddings = torch.zeros(T, D, device=device, dtype=dtype) + + denominators = torch.pow(n, 2*torch.arange(0, D//2, device=device, dtype=dtype)/D) # 10000^(2i/d_model), i is the index of embedding + embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model)) + embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model)) + + return x + embeddings + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +import numpy as np + +# 创建一个简单的数据集来测试收敛性 +class SimpleSeq2SeqDataset(Dataset): + def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256): + self.num_samples = num_samples + self.seq_len = seq_len + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + # 模拟编码器输出 (memory) + memory = torch.randn(self.seq_len, self.hidden_size) + + # 模拟解码器输入和目标 + decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,)) + target = torch.randint(0, self.vocab_size, (self.seq_len,)) + + # 创建简单的注意力掩码 + memory_mask = torch.ones(self.seq_len, self.seq_len) + + return { + 'memory': memory, + 'decoder_input': decoder_input, + 'target': target, + 'memory_mask': memory_mask + } + +# 创建简单的模型 +class SimpleDecoderModel(nn.Module): + def __init__(self, vocab_size, hidden_size=256, num_layers=3): + super(SimpleDecoderModel, self).__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + # 嵌入层 + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # 位置编码 + self.pos_encoding = sinusoidal_positional_embedding + + # 多个解码器层 + self.layers = nn.ModuleList([ + DecoderLayerWithCrossAttention( + hidden_size=hidden_size, + attn_head_dim=64, + 
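+                # with head_dim=64, the 2 heads below give a 128-dim attention width inside the 256-dim model; o_proj maps it back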
num_attention_heads=2, + attention_dropout=0.1, + attn_bias=False, + mlp_intermediate_size=512, + mlp_dropout=0.1, + mlp_act='silu', + rms_norm_eps=1e-6, + attn_implementation="sdpa" + ) for _ in range(num_layers) + ]) + + # 输出层 + self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6) + self.output_proj = nn.Linear(hidden_size, vocab_size) + + def forward(self, decoder_input, memory, memory_mask): + # 嵌入 + 位置编码 + x = self.embedding(decoder_input) + x = self.pos_encoding(x) + + # 通过多个解码器层 + input_dict = { + 'input_seq': x, + 'memory': memory, + 'memory_mask': memory_mask + } + + for layer in self.layers: + output_dict = layer(input_dict) + input_dict['input_seq'] = output_dict['input_seq'] + + # 最终输出 + hidden_states = self.output_norm(output_dict['input_seq']) + logits = self.output_proj(hidden_states) + + return logits + +# 训练函数 +def train_decoder_model(): + # 超参数 + vocab_size = 100 + hidden_size = 256 + seq_len = 10 + batch_size = 32 + num_epochs = 10 + learning_rate = 0.001 + + # 设备设置 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"使用设备: {device}") + + # 创建数据集和数据加载器 + dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len, + vocab_size=vocab_size, hidden_size=hidden_size) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + # 初始化模型 + model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size) + model.to(device) + + # 损失函数和优化器 + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + # 训练记录 + train_losses = [] + + print("开始训练...") + model.train() + + for epoch in range(num_epochs): + epoch_loss = 0.0 + num_batches = 0 + + for batch in dataloader: + # 移动到设备 + memory = batch['memory'].to(device) + decoder_input = batch['decoder_input'].long().to(device) + target = batch['target'].long().to(device) + memory_mask = batch['memory_mask'].to(device) + + # 前向传播 + optimizer.zero_grad() + logits = model(decoder_input, memory, memory_mask) + + # 计算损失 - 将logits和target重塑为(batch_size * seq_len, vocab_size)和(batch_size * seq_len) + loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1)) + + # 反向传播 + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + num_batches += 1 + + avg_loss = epoch_loss / num_batches + train_losses.append(avg_loss) + + print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}') + + # 早期停止检查:如果损失明显下降,说明模型在收敛 + if epoch >= 2 and avg_loss < 4.0: # 交叉熵损失的合理阈值 + print("✅ 模型正常收敛!") + return True + + # 检查最终损失 + final_loss = train_losses[-1] + if final_loss < 5.0: + print("✅ 训练完成,模型正常收敛!") + return True + else: + print("❌ 损失下降不明显,可能需要调整模型或超参数") + return False + +# 运行训练 +if __name__ == "__main__": + # 先测试单个decoder层的前向传播 + print("测试DecoderLayerWithCrossAttention前向传播...") + + # 创建测试数据 + batch_size, seq_len, hidden_size = 2, 64, 256 + decoder_input = torch.randn(seq_len, hidden_size) + memory = torch.randn(seq_len, hidden_size) + memory_mask = torch.ones(seq_len, seq_len) + + # 测试单层 + decoder_layer = DecoderLayerWithCrossAttention( + hidden_size=hidden_size, + attn_head_dim=64, + num_attention_heads=4 + ) + + input_dict = { + 'input_seq': decoder_input, + 'memory': memory, + 'memory_mask': memory_mask + } + + try: + output = decoder_layer(input_dict) + print("✅ DecoderLayer前向传播测试通过") + print(f"输入形状: {decoder_input.shape}") + print(f"输出形状: {output['input_seq'].shape}") + + # 运行完整训练 + success = train_decoder_model() + if success: + print("\n🎉 测试完成!DecoderLayerWithCrossAttention能够正常训练和收敛") + else: + print("\n⚠️ 训练过程中可能存在问题,建议检查模型结构或超参数") + + 
except Exception as e:
+        print(f"❌ Forward pass test failed: {e}")
\ No newline at end of file
diff --git a/Amadeus/train_utils.py b/Amadeus/train_utils.py
index 57fe28b..c6375da 100644
--- a/Amadeus/train_utils.py
+++ b/Amadeus/train_utils.py
@@ -228,19 +228,39 @@ class DiffusionLoss4CompoundToken():
             loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
 
         return loss
-    
+
+    def get_aux_ar_nll_loss(self, logits, target, mask):
+        probs = logits.softmax(dim=-1)
+        if probs.ndim == 3:
+            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
+        if target.ndim == 2:
+            target = target.flatten(0, 1) # [batch_size*seq_len]
+        # clamp min value to 1e-7 to avoid log(0)
+        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
+        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
+        loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
+        loss = loss.sum() / mask.sum() # calculating mean loss considering mask
+        return loss
+
     def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
         train_loss_list = []
         log_loss_dict_normal = {}
         mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
         p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
         disp_loss = None
+        aux_ar_logits = None
+        # print(len(logits_dict))
+        if len(logits_dict) == 2: # has aux ar loss
+            logits_dict, aux_ar_logits = logits_dict
         if input_dict is not None:
             hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
             feat = hidden_vec.mean(dim=1) #bs,dim
             disp_loss = dispersive_loss(feat, tau=tau) # scalar
 
         for idx, key in enumerate(self.feature_list):
             training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
+            if aux_ar_logits is not None:
+                aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
+                training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
             train_loss_list.append(training_loss)
             if valid:
                 if key == 'type' or key == 'timesig':
diff --git a/Amadeus/trainer_accelerate.py b/Amadeus/trainer_accelerate.py
index 94054cf..ce31805 100644
--- a/Amadeus/trainer_accelerate.py
+++ b/Amadeus/trainer_accelerate.py
@@ -74,8 +74,8 @@ class LanguageModelTrainer:
         sampling_threshold: float, # Threshold for sampling decisions
         sampling_temperature: float, # Temperature for controlling sampling randomness
         config, # Configuration parameters (contains general, training, and inference settings)
-        model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
-        # model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
+        # model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt", # Path to a pre-trained model checkpoint (optional)
+        model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
     ):
         # Save model, optimizer, and other configurations
         self.model = model
@@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
         segment, mask, caption,encoded_caption = batch
         input_seq, target = segment[:, :-1], segment[:, 1:]
         total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
+        try:
+            aux_ar_logits, logits_dict = logits_dict
+        except:
+            logits_dict = logits_dict
         probs_dict = {key:torch.softmax(value, dim=-1) for key, value in 
logits_dict.items()} num_nonmask_tokens = torch.sum(mask) input_seq = input_seq.to(self.device) diff --git a/Amadeus/transformer_utils.py b/Amadeus/transformer_utils.py index 8b970c3..b2001d9 100644 --- a/Amadeus/transformer_utils.py +++ b/Amadeus/transformer_utils.py @@ -729,11 +729,6 @@ class XtransformerNewPretrainingDecoder(nn.Module): ): super().__init__() self._make_decoder_layer(dim, depth, heads, dropout) - self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base') - # frozen text encoder - for param in self.text_encoder.parameters(): - param.requires_grad = False - def _make_decoder_layer(self, dim, depth, heads, dropout): self.transformer_decoder = Decoder( dim = dim, diff --git a/data_representation/octuple2tuneinidx.py b/data_representation/octuple2tuneinidx.py index 3b7f282..f2638c8 100644 --- a/data_representation/octuple2tuneinidx.py +++ b/data_representation/octuple2tuneinidx.py @@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed from tqdm import tqdm from symusic import Score from miditok import Octuple, TokenizerConfig +from itertools import groupby, chain +from random import shuffle, seed lock = RLock() +def shuffled(seq): + shuffle(seq) + return seq + +def permute_inside_and_across_tracks(seq): + seq_sorted = sorted(seq, key=lambda x: x[5]) + tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program + return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks))) + + def convert_event_dicts(dict_list): """ 将 event 词表列表按顺序转换为结构化输出 @@ -59,7 +71,7 @@ def convert_event_dicts(dict_list): return result -def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str): +def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool): try: score = Score(midi_path, ttype="tick") @@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str): # 分词 tok_seq = tokenizer(score) token_ids = tok_seq.ids + if whether_shuffle: + token_ids = permute_inside_and_across_tracks(token_ids) # add sos token at the beginning vocab = tokenizer.vocab sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1) @@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str): print(f" × 处理文件时出错:{midi_path} -> {e}") -def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)): +def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False): # === 1. 初始化分词器并保存词表 === print("初始化分词器 Octuple...") config = TokenizerConfig( @@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu os.makedirs(output_dir, exist_ok=True) # === 3. 收集 MIDI 文件 === - midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \ - glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True) - midi_paths = list(midi_paths) + # midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \ + # glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True) + # midi_paths = list(midi_paths) + midi_paths = [] + for root, _, files in os.walk(midi_dir): + for file in files: + if file.lower().endswith(('.mid', '.midi')): + midi_paths.append(os.path.join(root, file)) print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n") # === 4. 
并行处理 === results = [] with ProcessPoolExecutor(max_workers=num_threads) as executor: - futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths} + futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths} for future in tqdm(as_completed(futures), total=len(futures)): res = future.result() @@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu if __name__ == "__main__": + import argparse + midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录 + + parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)") + parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录") + parser.add_argument("--shuffle", action="store_true", help="是否在处理前打乱文件顺序") + dataset_name = midi_directory.split("/")[-1] tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8" + + output_dir = tuneidx_prefix - preprocess_midi_directory(midi_directory, output_dir) \ No newline at end of file + args = parser.parse_args() + preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle) \ No newline at end of file diff --git a/data_representation/permute.py b/data_representation/permute.py new file mode 100644 index 0000000..1e9946d --- /dev/null +++ b/data_representation/permute.py @@ -0,0 +1,39 @@ +from itertools import groupby, chain +from random import shuffle, seed + +def shuffled(seq): + shuffle(seq) + return seq + +def permute_inside_and_across_tracks(seq): + seq_sorted = sorted(seq, key=lambda x: x[5]) + tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program + return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks))) + +# (PitchDrum, Position, Bar, Velocity, Duration, Program, Tempo, TimeSignature) +seq = [ + # Program 0 + (60, 0, 0, 90, 96, 0, 120, 16), + (64, 48,0, 88, 96, 0, 120, 16), + (67, 96,0, 92, 96, 0, 120, 16), + + # Program 32 + (40, 0, 0, 80, 192, 32, 120, 16), + (43, 0, 1, 78, 192, 32, 120, 16), + + # Program 40 + (72, 24,0, 85, 72, 40, 120, 16), + (74, 72,0, 83, 72, 40, 120, 16), + (76, 24,1, 86, 72, 40, 120, 16), +] + +# seed(42) + +inside_track_permuted_and_track_permuted = permute_inside_and_across_tracks(seq) + +print("原始 seq:") +for e in seq: + print(e) +print("\n打乱后的 seq:") +for e in inside_track_permuted_and_track_permuted: + print(e) diff --git a/data_representation/resample.py b/data_representation/resample.py new file mode 100644 index 0000000..ac0d771 --- /dev/null +++ b/data_representation/resample.py @@ -0,0 +1,634 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +基于token分布距离的重采样脚本 +读取octuple_token_analysis_report.json,计算每个数据与整体分布的距离, +按照距离加权采样,距离越远的越容易被采样 +""" + +import os +import numpy as np +from pathlib import Path +from collections import Counter +from tqdm import tqdm +import json +from scipy.stats import entropy, wasserstein_distance +from scipy.spatial.distance import jensenshannon +from concurrent.futures import ThreadPoolExecutor, as_completed +import multiprocessing as mp + +# Octuple的列名定义 +COLUMN_NAMES = [ + "pitch", # 0: Pitch/PitchDrum + "position", # 1: Position + "bar", # 2: Bar + "velocity", # 3: Velocity + "duration", # 4: Duration + "program", # 5: Program + "tempo", # 6: Tempo + "timesig" # 7: TimeSignature +] + + +def load_distribution_from_json(json_path): + """ + 从JSON文件中加载整体token分布 + + Args: + json_path: JSON文件路径 + + Returns: + dict: {column_name: {token: probability}} + """ + 
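+    # Expected report layout (a sketch, inferred from the fields read below):
+    #   {"columns": {"pitch": {"token_counts": {"60": 1234, ...}, "total_tokens": 56789}, ...}}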
print(f"读取分布文件: {json_path}") + with open(json_path, 'r', encoding='utf-8') as f: + report = json.load(f) + + distributions = {} + columns = report.get('columns', {}) + + for col_name in COLUMN_NAMES: + if col_name not in columns: + print(f"警告: 列 {col_name} 不在报告中") + distributions[col_name] = {} + continue + + col_data = columns[col_name] + token_counts = col_data.get('token_counts', {}) + total_tokens = col_data.get('total_tokens', 1) + + # 转换为概率分布 + distribution = {} + for token_str, count in token_counts.items(): + token = int(token_str) + distribution[token] = count / total_tokens + + distributions[col_name] = distribution + print(f" 列 {col_name}: {len(distribution)} 个唯一token, 总token数: {total_tokens:,}") + + return distributions + + +def compute_data_distribution(data, col_idx): + """ + 计算单个数据在指定列的token分布 + + Args: + data: numpy数组 (num_tokens, num_columns) + col_idx: 列索引 + + Returns: + dict: {token: probability} + """ + if data.size == 0: + return {} + + tokens = data[:, col_idx] + unique, counts = np.unique(tokens, return_counts=True) + total = len(tokens) + + distribution = {} + for token, count in zip(unique, counts): + distribution[int(token)] = count / total + + return distribution + + +def compute_emd_distance(dist1, dist2): + """ + 使用推土机距离(Earth Mover's Distance / Wasserstein距离)计算两个分布之间的距离 + + Args: + dist1: 分布1,dict {token: probability},已归一化 + dist2: 分布2,dict {token: probability},已归一化 + + Returns: + float: EMD距离 + """ + # 获取所有token的并集,并排序 + all_tokens = sorted(set(dist1.keys()) | set(dist2.keys())) + + if not all_tokens: + return 0.0 + + # 构建概率向量和token值向量 + p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens]) + q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens]) + token_values = np.array(all_tokens, dtype=float) + + # 归一化(处理数值误差) + p_sum = p_weights.sum() + q_sum = q_weights.sum() + + if p_sum < 1e-10 or q_sum < 1e-10: + return 0.0 + + p_weights = p_weights / p_sum + q_weights = q_weights / q_sum + + # 使用Wasserstein距离(1-Wasserstein距离,即推土机距离) + # wasserstein_distance需要两个分布的样本值位置和权重 + # 对于离散分布,我们使用token值作为位置 + emd = wasserstein_distance(token_values, token_values, p_weights, q_weights) + + return emd + + +def compute_distribution_distance(dist1, dist2, method='emd'): + """ + 计算两个分布之间的距离 + + Args: + dist1: 分布1,dict {token: probability} + dist2: 分布2,dict {token: probability} + method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度) + + Returns: + float: 分布距离 + """ + if method == 'emd': + return compute_emd_distance(dist1, dist2) + + # 获取所有token的并集 + all_tokens = set(dist1.keys()) | set(dist2.keys()) + + if not all_tokens: + return 0.0 + + # 构建概率向量 + p = np.array([dist1.get(token, 0.0) for token in all_tokens]) + q = np.array([dist2.get(token, 0.0) for token in all_tokens]) + + # 归一化(处理数值误差) + p = p / (p.sum() + 1e-10) + q = q / (q.sum() + 1e-10) + + if method == 'js': + # Jensen-Shannon散度(对称,范围[0, 1]) + return jensenshannon(p, q) + elif method == 'kl': + # KL散度(非对称,需要处理零值) + # 添加小的平滑项避免log(0) + p = p + 1e-10 + q = q + 1e-10 + p = p / p.sum() + q = q / q.sum() + return entropy(p, q) + else: + raise ValueError(f"未知的距离方法: {method}") + + +def extract_subdistribution(global_dist, data_tokens): + """ + 从全局分布中提取只包含数据中出现的token的子分布,并归一化 + + Args: + global_dist: 全局分布,dict {token: probability} + data_tokens: 数据中出现的token集合,set或list + + Returns: + dict: 子分布,dict {token: probability},已归一化 + """ + if not data_tokens or not global_dist: + return {} + + # 提取子分布 + sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens} + + # 归一化 + total = 
sum(sub_dist.values()) + if total < 1e-10: + return {} + + normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()} + + return normalized_sub_dist + + +def compute_data_distance(data, global_distributions, method='emd'): + """ + 计算单个数据与整体分布的距离 + 对每首歌,从数据集分布中找出和这首歌的分布包含的数据相同的子分布, + 都进行归一化然后计算推土机距离 + + Args: + data: numpy数组 (num_tokens, num_columns) 或文件路径(如果是延迟加载) + global_distributions: 整体分布,dict {column_name: {token: probability}} + method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度) + + Returns: + float: 平均距离(跨所有列) + """ + # 如果data是路径,则加载它 + if isinstance(data, (str, Path)): + try: + data = np.load(data)['arr_0'] + except Exception as e: + # 不打印错误,让调用者处理 + raise RuntimeError(f"加载文件 {data} 时出错: {e}") + + distances = [] + + for col_idx, col_name in enumerate(COLUMN_NAMES): + # 计算该数据在该列的分布 + data_dist = compute_data_distribution(data, col_idx) + + # 获取整体分布 + global_dist = global_distributions.get(col_name, {}) + + if not data_dist or not global_dist: + continue + + # 从全局分布中提取只包含数据中出现的token的子分布 + data_tokens = set(data_dist.keys()) + sub_global_dist = extract_subdistribution(global_dist, data_tokens) + + if not sub_global_dist: + continue + + # 归一化数据分布 + data_dist_sum = sum(data_dist.values()) + if data_dist_sum < 1e-10: + continue + normalized_data_dist = {token: prob / data_dist_sum + for token, prob in data_dist.items()} + + # 计算距离(两个分布都已归一化) + dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method) + distances.append(dist) + + # 返回平均距离 + return np.mean(distances) if distances else 0.0 + + +def _load_single_file(npz_file): + """ + 加载单个npz文件的辅助函数(用于多线程) + + Args: + npz_file: npz文件路径 + + Returns: + tuple: (data, file_path) 或 None(如果加载失败) + """ + try: + data = np.load(npz_file)['arr_0'] + if data.ndim == 2: + return (data, npz_file) + elif data.ndim == 1: + print(f"警告: {npz_file} 是一维数组,跳过") + return None + except Exception as e: + print(f"错误: 加载 {npz_file} 时出错: {e}") + return None + + +def get_data_file_paths(data_dir): + """ + 获取所有数据文件路径(不加载数据) + + Args: + data_dir: 数据目录路径 + + Returns: + list: 文件路径列表 + """ + data_dir = Path(data_dir) + npz_files = [] + + if data_dir.exists(): + npz_files = sorted(list(data_dir.rglob("*.npz"))) + + if not npz_files: + print(f"警告: 在 {data_dir} 中未找到.npz文件") + return [] + + print(f"找到 {len(npz_files)} 个.npz文件") + return npz_files + + +def load_data_with_paths(data_dir, num_threads=None, lazy=False): + """ + 加载所有数据并返回数据路径列表(多线程版本) + + Args: + data_dir: 数据目录路径 + num_threads: 线程数,None表示使用CPU核心数 + lazy: 如果为True,只返回文件路径,不加载数据 + + Returns: + tuple: (data_list, file_paths_list) 或 (None, file_paths_list) 如果lazy=True + """ + data_dir = Path(data_dir) + npz_files = [] + + if data_dir.exists(): + npz_files = sorted(list(data_dir.rglob("*.npz"))) + + if not npz_files: + print(f"警告: 在 {data_dir} 中未找到.npz文件") + return [], [] + + if lazy: + print(f"找到 {len(npz_files)} 个.npz文件(延迟加载模式)") + return None, npz_files + + print(f"找到 {len(npz_files)} 个.npz文件,开始加载...") + + if num_threads is None: + num_threads = min(mp.cpu_count(), len(npz_files)) + + all_data = [] + file_paths = [] + + # 使用多线程加载文件 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = {executor.submit(_load_single_file, npz_file): npz_file + for npz_file in npz_files} + + for future in tqdm(as_completed(futures), total=len(futures), desc="加载数据"): + result = future.result() + if result is not None: + data, file_path = result + all_data.append(data) + file_paths.append(file_path) + + # 保持原始顺序 + if file_paths: + sorted_pairs = 
sorted(zip(file_paths, all_data), key=lambda x: str(x[0])) + file_paths, all_data = zip(*sorted_pairs) + file_paths = list(file_paths) + all_data = list(all_data) + + return all_data, file_paths + + +def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True): + """ + 根据距离进行加权重采样 + + Args: + file_paths: 文件路径列表 + distances: 距离列表 + sample_ratio: 采样比例 + method: 距离计算方法(用于确定权重方向) + lazy: 如果为True,返回文件路径而不是数据 + + Returns: + tuple: (sampled_data_or_paths, sampled_paths, sampled_indices) + """ + n_samples = int(len(file_paths) * sample_ratio) + print(f"\n准备采样 {n_samples} 个数据 (占总数的 {sample_ratio*100:.1f}%)") + + # 将距离转换为权重 + # 距离越远,权重越大 + distances = np.array(distances) + + # 处理零距离或异常值 + min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10 + distances = np.maximum(distances, min_dist * 0.1) + + # 归一化距离到[0, 1],然后转换为权重 + # 使用指数函数增强距离差异 + normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10) + weights = np.exp(normalized_distances * 3) # 指数放大,使距离远的更容易被采样 + + # 归一化权重 + weights = weights / weights.sum() + + # 加权随机采样 + indices = np.arange(len(file_paths)) + sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights) + + sampled_paths = [file_paths[i] for i in sampled_indices] + + # 如果lazy=True,返回路径;否则加载数据 + if lazy: + sampled_data = sampled_paths # 返回路径,延迟加载 + else: + # 加载采样后的数据 + sampled_data = [] + for path in tqdm(sampled_paths, desc="加载采样数据"): + try: + data = np.load(path)['arr_0'] + sampled_data.append(data) + except Exception as e: + print(f"错误: 加载 {path} 时出错: {e}") + sampled_data.append(None) + + print(f"采样完成:") + print(f" 采样数据数量: {len(sampled_paths)}") + print(f" 平均距离: {distances[sampled_indices].mean():.6f}") + print(f" 最小距离: {distances[sampled_indices].min():.6f}") + print(f" 最大距离: {distances[sampled_indices].max():.6f}") + + return sampled_data, sampled_paths, sampled_indices + + +def _save_single_file(args_tuple): + """ + 保存单个文件的辅助函数(用于多线程) + 支持延迟加载:如果data是路径,则从文件加载 + + Args: + args_tuple: (data, original_path, output_dir) + + Returns: + tuple: (success, original_path) 或 (False, original_path, error_msg) + """ + data, original_path, output_dir = args_tuple + try: + # 如果data是路径,则加载它 + if isinstance(data, (str, Path)): + data = np.load(data)['arr_0'] + + # 保持相对路径结构 + relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3]) + output_path = output_dir / relative_path + output_path.parent.mkdir(parents=True, exist_ok=True) + + np.savez_compressed(output_path, data) + return (True, original_path) + except Exception as e: + error_msg = str(e) + print(f"错误: 保存 {original_path} 时出错: {error_msg}") + return (False, original_path, error_msg) + + +def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None): + """ + 保存采样后的数据(多线程版本) + + Args: + sampled_data: 采样后的数据列表 + sampled_paths: 采样后的文件路径列表 + output_dir: 输出目录 + num_threads: 线程数,None表示使用CPU核心数 + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n保存采样数据到: {output_dir}") + + if num_threads is None: + num_threads = min(mp.cpu_count(), len(sampled_data)) + + # 准备参数 + save_args = [(data, original_path, output_dir) + for data, original_path in zip(sampled_data, sampled_paths)] + + # 使用多线程保存文件 + success_count = 0 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # 提交所有任务 + futures = [executor.submit(_save_single_file, args) + for args in save_args] + + # 收集结果 + for future in tqdm(as_completed(futures), 
total=len(futures), desc="保存数据"): + try: + result = future.result(timeout=300) # 设置超时避免卡死 + if isinstance(result, tuple) and len(result) >= 2: + success = result[0] + if success: + success_count += 1 + except Exception as e: + print(f"错误: 获取保存结果时出错: {e}") + + print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件") + + +def main(): + """主函数""" + import argparse + + parser = argparse.ArgumentParser(description="基于token分布距离的重采样") + parser.add_argument("--json_path", type=str, + default="octuple_token_analysis_report.json", + help="token分析报告JSON文件路径") + parser.add_argument("--data_dir", type=str, + default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8", + help="数据目录路径") + parser.add_argument("--output_dir", type=str, + default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled", + help="输出目录路径") + parser.add_argument("--sample_ratio", type=float, default=0.3, + help="采样比例 (默认: 0.3)") + parser.add_argument("--distance_method", type=str, default="emd", + choices=["emd", "js", "kl"], + help="距离计算方法: 'emd' (推土机距离/EMD), 'js' (Jensen-Shannon) 或 'kl' (KL散度)") + parser.add_argument("--seed", type=int, default=42, + help="随机种子") + parser.add_argument("--num_threads", type=int, default=1, + help="线程数,None表示使用CPU核心数 (默认: None)") + + args = parser.parse_args() + + # 设置随机种子 + np.random.seed(args.seed) + + # 1. 加载整体分布 + global_distributions = load_distribution_from_json(args.json_path) + + # 2. 获取所有数据文件路径(延迟加载模式,避免一次性加载所有数据) + _, file_paths = load_data_with_paths(args.data_dir, lazy=True) + + if not file_paths: + print("错误: 未找到任何数据文件") + return + + print(f"\n共找到 {len(file_paths)} 个数据文件") + + # 3. 计算每个数据与整体分布的距离(多线程版本,延迟加载) + print("\n计算每个数据与整体分布的距离(延迟加载模式)...") + + def _compute_distance_wrapper(args_tuple): + """计算距离的包装函数(用于多线程,支持延迟加载)""" + idx, file_path, global_dists, method = args_tuple + try: + distance = compute_data_distance(file_path, global_dists, method=method) + return (idx, distance, None) + except Exception as e: + return (idx, 0.0, str(e)) + + if args.num_threads is None: + num_threads = min(mp.cpu_count(), len(file_paths)) + else: + num_threads = args.num_threads + + # 准备参数(使用文件路径而不是数据,包含索引) + distance_args = [(i, file_path, global_distributions, args.distance_method) + for i, file_path in enumerate(file_paths)] + + # 使用多线程计算距离(按需加载数据) + # 初始化结果列表,保持顺序 + distances = [0.0] * len(file_paths) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # 提交所有任务 + futures = [executor.submit(_compute_distance_wrapper, args) + for args in distance_args] + + # 收集结果,使用 tqdm 显示进度 + for future in tqdm(as_completed(futures), total=len(futures), desc="计算距离"): + try: + idx, distance, error = future.result(timeout=300) # 设置超时避免卡死 + distances[idx] = distance + if error: + print(f"警告: 计算距离时出错 (索引 {idx}): {error}") + except Exception as e: + print(f"错误: 获取结果时出错: {e}") + # 如果无法获取结果,保持默认值 0.0 + + distances = np.array(distances) + print(f"\n距离统计:") + print(f" 平均距离: {distances.mean():.6f}") + print(f" 最小距离: {distances.min():.6f}") + print(f" 最大距离: {distances.max():.6f}") + print(f" 标准差: {distances.std():.6f}") + + # 4. 根据距离进行加权采样(延迟加载模式) + sampled_data, sampled_paths, sampled_indices = weighted_resample( + file_paths, distances, + sample_ratio=args.sample_ratio, + method=args.distance_method, + lazy=True # 使用延迟加载,避免重复加载数据 + ) + + # 5. 保存采样结果(多线程,延迟加载) + save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads) + + # 6. 
保存采样索引(可选,用于后续分析) + indices_file = Path(args.output_dir) / "sampled_indices.npy" + np.save(indices_file, sampled_indices) + print(f"\n采样索引已保存到: {indices_file}") + + # 保存采样信息 + info = { + "total_samples": len(file_paths), + "sampled_samples": len(sampled_data), + "sample_ratio": args.sample_ratio, + "distance_method": args.distance_method, + "distance_stats": { + "mean": float(distances.mean()), + "min": float(distances.min()), + "max": float(distances.max()), + "std": float(distances.std()) + }, + "sampled_distance_stats": { + "mean": float(distances[sampled_indices].mean()), + "min": float(distances[sampled_indices].min()), + "max": float(distances[sampled_indices].max()), + "std": float(distances[sampled_indices].std()) + } + } + + info_file = Path(args.output_dir) / "resample_info.json" + with open(info_file, 'w', encoding='utf-8') as f: + json.dump(info, f, indent=2, ensure_ascii=False) + print(f"采样信息已保存到: {info_file}") + + +if __name__ == "__main__": + main() + diff --git a/data_representation/resampleV2.py b/data_representation/resampleV2.py new file mode 100644 index 0000000..c3ca13a --- /dev/null +++ b/data_representation/resampleV2.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +基于position和duration token的重采样脚本V2 +对于每首歌: +1. 如果包含的position和duration不在总数据集前3个,则必定采样 +2. 对于包含的,以某个固定的百分比采样 +3. 两个条件满足一个即可 +""" + +import os +import numpy as np +from pathlib import Path +from collections import Counter +from tqdm import tqdm +import json +from concurrent.futures import ThreadPoolExecutor, as_completed +import multiprocessing as mp + +# Octuple的列名定义 +COLUMN_NAMES = [ + "pitch", # 0: Pitch/PitchDrum + "position", # 1: Position + "bar", # 2: Bar + "velocity", # 3: Velocity + "duration", # 4: Duration + "program", # 5: Program + "tempo", # 6: Tempo + "timesig" # 7: TimeSignature +] + + +def load_top_tokens_from_json(json_path, column_name, top_k=3): + """ + 从JSON文件中加载指定列的前top_k个最常见的token + + Args: + json_path: JSON文件路径 + column_name: 列名(如'position'或'duration') + top_k: 返回前k个最常见的token + + Returns: + set: 前top_k个最常见的token集合 + """ + print(f"读取分布文件: {json_path}") + with open(json_path, 'r', encoding='utf-8') as f: + report = json.load(f) + + columns = report.get('columns', {}) + + if column_name not in columns: + print(f"警告: 列 {column_name} 不在报告中") + return set() + + col_data = columns[column_name] + token_counts = col_data.get('token_counts', {}) + + # 按出现次数排序,获取前top_k个 + sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True) + top_tokens = {int(token_str) for token_str, _ in sorted_tokens[:top_k]} + + print(f" 列 {column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}") + + return top_tokens + + +def get_data_tokens(data, col_idx): + """ + 获取单个数据在指定列的所有唯一token + + Args: + data: numpy数组 (num_tokens, num_columns) 或文件路径 + col_idx: 列索引 + + Returns: + set: 唯一token集合 + """ + # 如果data是路径,则加载它 + if isinstance(data, (str, Path)): + try: + data = np.load(data)['arr_0'] + except Exception as e: + raise RuntimeError(f"加载文件 {data} 时出错: {e}") + + if data.size == 0: + return set() + + tokens = data[:, col_idx] + unique_tokens = set(int(token) for token in np.unique(tokens)) + + return unique_tokens + + +def should_sample_song(data, top_position_tokens, top_duration_tokens, + contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None): + """ + 判断一首歌是否应该被采样 + + Args: + data: numpy数组 (num_tokens, num_columns) 或文件路径 + top_position_tokens: position列的前3个最常见token集合 + top_duration_tokens: duration列的前3个最常见token集合 + contain_sample_ratio: 对于包含前3个token的歌曲,采样比例 + 
not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率) + rng: 随机数生成器,如果为None则使用全局的np.random + + Returns: + tuple: (是否应该采样, 是否在前3个) - 在前3个指position和duration都在前3个 + """ + # 获取position和duration列的唯一token + position_idx = COLUMN_NAMES.index("position") + duration_idx = COLUMN_NAMES.index("duration") + + position_tokens = get_data_tokens(data, position_idx) + duration_tokens = get_data_tokens(data, duration_idx) + + # 判断是否在前3个 + position_in_top3 = bool(position_tokens & top_position_tokens) + duration_in_top3 = bool(duration_tokens & top_duration_tokens) + in_top3 = position_in_top3 and duration_in_top3 + + if rng is None: + rng = np.random + + # 条件1: 如果position或duration不包含前3个token,以更高概率采样 + if not position_in_top3 or not duration_in_top3: + should_sample = rng.random() < not_contain_sample_ratio + return should_sample, False + + # 条件2: 如果包含前3个token,则以固定百分比采样 + should_sample = rng.random() < contain_sample_ratio + return should_sample, True + + +def _load_single_file(npz_file): + """ + 加载单个npz文件的辅助函数(用于多线程) + + Args: + npz_file: npz文件路径 + + Returns: + tuple: (data, file_path) 或 None(如果加载失败) + """ + try: + data = np.load(npz_file)['arr_0'] + if data.ndim == 2: + return (data, npz_file) + elif data.ndim == 1: + print(f"警告: {npz_file} 是一维数组,跳过") + return None + except Exception as e: + print(f"错误: 加载 {npz_file} 时出错: {e}") + return None + + +def get_data_file_paths(data_dir): + """ + 获取所有数据文件路径(不加载数据) + + Args: + data_dir: 数据目录路径 + + Returns: + list: 文件路径列表 + """ + data_dir = Path(data_dir) + npz_files = [] + + if data_dir.exists(): + npz_files = sorted(list(data_dir.rglob("*.npz"))) + + if not npz_files: + print(f"警告: 在 {data_dir} 中未找到.npz文件") + return [] + + print(f"找到 {len(npz_files)} 个.npz文件") + return npz_files + + +def resample_songs(file_paths, top_position_tokens, top_duration_tokens, + contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, + num_threads=None, seed=42): + """ + 根据新逻辑进行重采样 + + Args: + file_paths: 文件路径列表 + top_position_tokens: position列的前3个最常见token集合 + top_duration_tokens: duration列的前3个最常见token集合 + contain_sample_ratio: 对于包含前3个token的歌曲,采样比例 + not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率) + num_threads: 线程数,None表示使用CPU核心数 + seed: 随机种子 + + Returns: + tuple: (sampled_paths, sampled_indices, stats) + """ + import threading + + # 为每个线程创建独立的随机数生成器 + thread_local = threading.local() + + def get_thread_rng(): + """获取当前线程的随机数生成器""" + if not hasattr(thread_local, 'rng'): + # 使用线程ID和种子创建独立的随机数生成器 + thread_id = threading.current_thread().ident + thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000) + return thread_local.rng + + if num_threads is None: + num_threads = min(mp.cpu_count(), len(file_paths)) + + print(f"\n开始重采样,使用 {num_threads} 个线程...") + print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%") + print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%") + + def _should_sample_wrapper(args_tuple): + """判断是否采样的包装函数(用于多线程)""" + file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple + try: + # 使用线程本地的随机数生成器 + thread_rng = get_thread_rng() + should_sample, in_top3 = should_sample_song( + file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng + ) + return (file_path, should_sample, in_top3, None) + except Exception as e: + return (file_path, False, False, str(e)) + + # 准备参数 + sample_args = [(file_path, top_position_tokens, top_duration_tokens, + contain_sample_ratio, not_contain_sample_ratio) + for file_path in file_paths] + + # 使用多线程判断每首歌是否应该采样 + sampled_paths = [] + sampled_indices = [] + stats = { + 
'not_in_top3_count': 0, # 不在前3个的歌曲数量 + 'not_in_top3_sampled': 0, # 不在前3个且被采样的歌曲数量 + 'in_top3_count': 0, # 在前3个的歌曲数量 + 'in_top3_sampled': 0 # 在前3个且被采样的歌曲数量 + } + + # 限制并发任务数量,避免一次性提交过多任务 + batch_size = min(1000, len(file_paths)) + results = {} + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + # 分批提交任务 + for batch_start in range(0, len(sample_args), batch_size): + batch_end = min(batch_start + batch_size, len(sample_args)) + batch_args = sample_args[batch_start:batch_end] + + futures = {executor.submit(_should_sample_wrapper, args): args[0] + for args in batch_args} + + # 收集结果 + for future in tqdm(as_completed(futures), total=len(futures), + desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]", + leave=False): + try: + file_path, should_sample, in_top3, error = future.result(timeout=60) + results[file_path] = (should_sample, in_top3, error) + if error: + print(f"警告: 处理 {file_path} 时出错: {error}") + except Exception as e: + print(f"错误: 获取结果时出错: {e}") + + # 按原始顺序处理结果,并统计 + for idx, file_path in enumerate(file_paths): + if file_path not in results: + continue + + should_sample, in_top3, error = results[file_path] + if error: + continue + + # 统计信息 + if in_top3: + stats['in_top3_count'] += 1 + if should_sample: + stats['in_top3_sampled'] += 1 + else: + stats['not_in_top3_count'] += 1 + if should_sample: + stats['not_in_top3_sampled'] += 1 + + if should_sample: + sampled_paths.append(file_path) + sampled_indices.append(idx) + + print(f"\n采样完成:") + print(f" 总歌曲数: {len(file_paths)}") + print(f" 采样歌曲数: {len(sampled_paths)}") + print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%") + print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}") + print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}") + if stats['not_in_top3_count'] > 0: + print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%") + print(f" 在前3个的歌曲数: {stats['in_top3_count']}") + print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}") + if stats['in_top3_count'] > 0: + print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%") + + return sampled_paths, sampled_indices, stats + + +def _save_single_file(args_tuple): + """ + 保存单个文件的辅助函数(用于多线程) + 支持延迟加载:如果data是路径,则从文件加载 + + Args: + args_tuple: (data, original_path, output_dir) + + Returns: + tuple: (success, original_path) 或 (False, original_path, error_msg) + """ + data, original_path, output_dir = args_tuple + try: + # 如果data是路径,则加载它 + if isinstance(data, (str, Path)): + data = np.load(data)['arr_0'] + + # 保持相对路径结构 + relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3]) + output_path = output_dir / relative_path + output_path.parent.mkdir(parents=True, exist_ok=True) + + np.savez_compressed(output_path, data) + return (True, original_path) + except Exception as e: + error_msg = str(e) + print(f"错误: 保存 {original_path} 时出错: {error_msg}") + return (False, original_path, error_msg) + + +def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None): + """ + 保存采样后的数据(多线程版本) + + Args: + sampled_data: 采样后的数据列表 + sampled_paths: 采样后的文件路径列表 + output_dir: 输出目录 + num_threads: 线程数,None表示使用CPU核心数 + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n保存采样数据到: {output_dir}") + + if num_threads is None: + num_threads = min(mp.cpu_count(), len(sampled_data)) + + # 准备参数 + save_args = [(data, original_path, output_dir) + for data, original_path in zip(sampled_data, sampled_paths)] + + # 使用多线程保存文件 + success_count = 0 + with 
ThreadPoolExecutor(max_workers=num_threads) as executor: + # 提交所有任务 + futures = [executor.submit(_save_single_file, args) + for args in save_args] + + # 收集结果 + for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"): + try: + result = future.result(timeout=300) # 设置超时避免卡死 + if isinstance(result, tuple) and len(result) >= 2: + success = result[0] + if success: + success_count += 1 + except Exception as e: + print(f"错误: 获取保存结果时出错: {e}") + + print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件") + + +def main(): + """主函数""" + import argparse + + parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2") + parser.add_argument("--json_path", type=str, + default="octuple_token_analysis_report.json", + help="token分析报告JSON文件路径") + parser.add_argument("--data_dir", type=str, + default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8", + help="数据目录路径") + parser.add_argument("--output_dir", type=str, + default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2", + help="输出目录路径") + parser.add_argument("--contain_sample_ratio", type=float, default=0.1, + help="对于包含前3个token的歌曲,采样比例 (默认: 0.1)") + parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9, + help="对于不包含前3个token的歌曲,采样比例 (默认: 0.9)") + parser.add_argument("--top_k", type=int, default=3, + help="使用前k个最常见的token (默认: 3)") + parser.add_argument("--seed", type=int, default=42, + help="随机种子") + parser.add_argument("--num_threads", type=int, default=None, + help="线程数,None表示使用CPU核心数 (默认: None)") + + args = parser.parse_args() + + # 1. 加载position和duration的前top_k个最常见token + top_position_tokens = load_top_tokens_from_json( + args.json_path, "position", top_k=args.top_k + ) + top_duration_tokens = load_top_tokens_from_json( + args.json_path, "duration", top_k=args.top_k + ) + + if not top_position_tokens or not top_duration_tokens: + print("错误: 无法加载前top_k个token") + return + + # 2. 获取所有数据文件路径 + file_paths = get_data_file_paths(args.data_dir) + + if not file_paths: + print("错误: 未找到任何数据文件") + return + + print(f"\n共找到 {len(file_paths)} 个数据文件") + + # 3. 根据新逻辑进行重采样 + sampled_paths, sampled_indices, stats = resample_songs( + file_paths, + top_position_tokens, + top_duration_tokens, + contain_sample_ratio=args.contain_sample_ratio, + not_contain_sample_ratio=args.not_contain_sample_ratio, + num_threads=args.num_threads, + seed=args.seed + ) + + # 4. 保存采样结果(延迟加载) + sampled_data = sampled_paths # 使用路径,延迟加载 + save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads) + + # 5. 保存采样索引(可选,用于后续分析) + indices_file = Path(args.output_dir) / "sampled_indices.npy" + np.save(indices_file, np.array(sampled_indices)) + print(f"\n采样索引已保存到: {indices_file}") + + # 6. 
保存采样信息 + info = { + "total_samples": len(file_paths), + "sampled_samples": len(sampled_paths), + "contain_sample_ratio": args.contain_sample_ratio, + "not_contain_sample_ratio": args.not_contain_sample_ratio, + "top_k": args.top_k, + "top_position_tokens": sorted(list(top_position_tokens)), + "top_duration_tokens": sorted(list(top_duration_tokens)), + "stats": stats + } + + info_file = Path(args.output_dir) / "resample_info.json" + with open(info_file, 'w', encoding='utf-8') as f: + json.dump(info, f, indent=2, ensure_ascii=False) + print(f"采样信息已保存到: {info_file}") + + +if __name__ == "__main__": + main() + diff --git a/data_representation/test.py b/data_representation/test.py index 08dafa7..dd5d576 100644 --- a/data_representation/test.py +++ b/data_representation/test.py @@ -1,14 +1,282 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +统计octuple分词结果中每一列每个token的出现次数,并生成分析报告 +""" + +import os import numpy as np +from pathlib import Path +from collections import defaultdict, Counter +from tqdm import tqdm +import json -# 读取 npz 文件 -data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True) +# Octuple的列名定义 +COLUMN_NAMES = [ + "pitch", # 0: Pitch/PitchDrum + "position", # 1: Position + "bar", # 2: Bar + "velocity", # 3: Velocity + "duration", # 4: Duration + "program", # 5: Program + "tempo", # 6: Tempo + "timesig" # 7: TimeSignature +] -# 查看保存的键 -print(data.files) -# 输出:['filename', 'sequence'] -# 访问数据 -sequence = data["arr_0"] +def load_octuple_data(data_dir): + """ + 加载所有octuple分词后的.npz文件 + + Args: + data_dir: 数据目录路径,可以是单个目录或包含多个子目录的根目录 + + Returns: + list: 所有加载的numpy数组列表 + """ + data_dir = Path(data_dir) + npz_files = [] + + # 如果目录存在,查找所有.npz文件 + if data_dir.exists(): + npz_files = list(data_dir.rglob("*.npz")) + + if not npz_files: + print(f"警告: 在 {data_dir} 中未找到.npz文件") + return [] + + print(f"找到 {len(npz_files)} 个.npz文件,开始加载...") + + all_data = [] + for npz_file in tqdm(npz_files, desc="加载数据"): + try: + data = np.load(npz_file)['arr_0'] + # 确保数据是二维数组 (num_tokens, num_columns) + if data.ndim == 2: + all_data.append(data) + elif data.ndim == 1: + # 如果是一维,可能需要reshape,但octuple应该是二维的 + print(f"警告: {npz_file} 是一维数组,跳过") + except Exception as e: + print(f"错误: 加载 {npz_file} 时出错: {e}") + continue + + return all_data + + +def count_tokens_by_column(all_data): + """ + 统计每一列每个token的出现次数 + + Args: + all_data: 所有数据的列表,每个元素是一个numpy数组 (num_tokens, num_columns) + + Returns: + dict: {column_index: Counter({token_value: count})} + """ + column_counters = defaultdict(Counter) + + print("统计token出现次数...") + for data in tqdm(all_data, desc="处理数据"): + if data.size == 0: + continue + + num_columns = data.shape[1] if data.ndim == 2 else 1 + + for col_idx in range(num_columns): + if data.ndim == 2: + tokens = data[:, col_idx] + else: + tokens = data + + # 统计该列中每个token的出现次数 + unique, counts = np.unique(tokens, return_counts=True) + for token, count in zip(unique, counts): + column_counters[col_idx][int(token)] += int(count) + + return dict(column_counters) + + +def generate_report(column_counters, output_file=None): + """ + 生成分析报告 + + Args: + column_counters: 统计结果字典 + output_file: 输出文件路径(可选) + """ + report_lines = [] + report_lines.append("=" * 80) + report_lines.append("OCTUPLE分词结果统计分析报告") + report_lines.append("=" * 80) + report_lines.append("") + + # 总体统计 + total_tokens = sum(sum(counter.values()) for counter in column_counters.values()) + report_lines.append(f"总token数: {total_tokens:,}") + report_lines.append(f"分析的列数: {len(column_counters)}") 
+ report_lines.append("") + + # 每一列的详细统计 + for col_idx in sorted(column_counters.keys()): + counter = column_counters[col_idx] + col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}" + + report_lines.append("-" * 80) + report_lines.append(f"列 {col_idx}: {col_name}") + report_lines.append("-" * 80) + + total_in_column = sum(counter.values()) + unique_tokens = len(counter) + min_token = min(counter.keys()) if counter else 0 + max_token = max(counter.keys()) if counter else 0 + + report_lines.append(f" 总token数: {total_in_column:,}") + report_lines.append(f" 唯一token数: {unique_tokens:,}") + report_lines.append(f" Token值范围: [{min_token}, {max_token}]") + report_lines.append(f" 平均出现次数: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else " 平均出现次数: N/A") + report_lines.append("") + + # Top 20 最常见的token + report_lines.append(f" Top 20 最常见的token:") + top_tokens = counter.most_common(20) + for rank, (token, count) in enumerate(top_tokens, 1): + percentage = (count / total_in_column * 100) if total_in_column > 0 else 0 + report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)") + report_lines.append("") + + # Top 20 最不常见的token(出现次数>0的) + report_lines.append(f" Top 20 最不常见的token (出现次数>0):") + bottom_tokens = counter.most_common()[-20:] + bottom_tokens.reverse() + for rank, (token, count) in enumerate(bottom_tokens, 1): + percentage = (count / total_in_column * 100) if total_in_column > 0 else 0 + report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)") + report_lines.append("") + + # 分布统计 + counts_list = list(counter.values()) + if counts_list: + report_lines.append(f" 分布统计:") + report_lines.append(f" 最小出现次数: {min(counts_list):,}") + report_lines.append(f" 最大出现次数: {max(counts_list):,}") + report_lines.append(f" 中位数出现次数: {np.median(counts_list):,.0f}") + report_lines.append(f" 标准差: {np.std(counts_list):,.2f}") + report_lines.append("") + + # 跨列分析 + report_lines.append("=" * 80) + report_lines.append("跨列分析") + report_lines.append("=" * 80) + report_lines.append("") + + for col_idx in sorted(column_counters.keys()): + counter = column_counters[col_idx] + col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}" + total_in_column = sum(counter.values()) + percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0 + report_lines.append(f" {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)") + + report_lines.append("") + report_lines.append("=" * 80) + report_lines.append("报告生成完成") + report_lines.append("=" * 80) + + # 输出报告 + report_text = "\n".join(report_lines) + print("\n" + report_text) + + # 保存到文件 + if output_file: + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + f.write(report_text) + print(f"\n报告已保存到: {output_path}") + + # 同时保存JSON格式的详细数据 + if output_file: + json_output = output_path.with_suffix('.json') + json_data = { + 'summary': { + 'total_tokens': total_tokens, + 'num_columns': len(column_counters) + }, + 'columns': {} + } + + for col_idx in sorted(column_counters.keys()): + counter = column_counters[col_idx] + col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}" + json_data['columns'][col_name] = { + 'total_tokens': sum(counter.values()), + 'unique_tokens': len(counter), + 'token_counts': dict(counter), + 'top_20': dict(counter.most_common(20)), + 'bottom_20': 
dict(counter.most_common()[-20:]) + } + + with open(json_output, 'w', encoding='utf-8') as f: + json.dump(json_data, f, indent=2, ensure_ascii=False) + print(f"详细数据已保存到: {json_output}") + + +def main(): + """主函数""" + # 默认数据目录 - 可以根据需要修改 + default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi" + + # 可以指定具体的数据目录,例如: + data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2" + # 或者使用默认目录扫描所有oct8目录 + + # import sys + # if len(sys.argv) > 1: + # data_dir = sys.argv[1] + # else: + # # 自动查找所有oct8目录 + # base_dir = Path(default_data_dir) + # oct8_dirs = list(base_dir.rglob("oct8")) + # if oct8_dirs: + # print(f"找到以下oct8目录:") + # for i, d in enumerate(oct8_dirs, 1): + # print(f" {i}. {d}") + # if len(oct8_dirs) == 1: + # data_dir = str(oct8_dirs[0]) + # print(f"\n使用目录: {data_dir}") + # else: + # # 使用第一个找到的目录,或者合并所有目录 + # print(f"\n使用第一个目录: {oct8_dirs[0]}") + # print("如需分析其他目录,请指定路径作为参数") + # data_dir = str(oct8_dirs[0]) + # else: + # data_dir = default_data_dir + # print(f"未找到oct8目录,使用默认目录: {data_dir}") + + # 加载数据 + all_data = load_octuple_data(data_dir) + + if not all_data: + print("错误: 未加载到任何数据") + return + + # 检查数据格式 + if all_data: + sample = all_data[0] + print(f"\n数据格式检查:") + print(f" 样本形状: {sample.shape}") + print(f" 样本数据类型: {sample.dtype}") + print(f" 列数: {sample.shape[1] if sample.ndim == 2 else 1}") + print() + + # 统计token出现次数 + column_counters = count_tokens_by_column(all_data) + + # 生成报告 + output_file = "octuple_token_analysis_report_part.txt" + generate_report(column_counters, output_file) + + +if __name__ == "__main__": + main() -print("token 序列长度:", len(sequence)) -print("前 20 个 token:", sequence[:20]) \ No newline at end of file diff --git a/dllm/.gitignore b/dllm/.gitignore new file mode 100644 index 0000000..2a16347 --- /dev/null +++ b/dllm/.gitignore @@ -0,0 +1,139 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.idea +.Python +build/ +develop-eggs/ +dist/ +downloads/ +applications/DeepSpeed-Chat/data +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Others +/.vscode/ +/tmp/ +/data/ +/wandb/ +/logs/ +/models*/ diff --git a/dllm/.gitmodules b/dllm/.gitmodules new file mode 100644 index 0000000..39e556f --- /dev/null +++ b/dllm/.gitmodules @@ -0,0 +1,4 @@ +[submodule "lm-evaluation-harness"] + path = lm-evaluation-harness + url = https://github.com/ZHZisZZ/lm-evaluation-harness + branch = dllm \ No newline at end of file diff --git a/dllm/LICENSE b/dllm/LICENSE new file mode 100644 index 0000000..960d805 --- /dev/null +++ b/dllm/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Zhanhui Zhou + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/dllm/README.md b/dllm/README.md new file mode 100644 index 0000000..9c7eb8f --- /dev/null +++ b/dllm/README.md @@ -0,0 +1,283 @@ +

+<!-- centered README header: "dLLM" | Simple Diffusion Language Modeling | dLLM logo image -->
+ + +## Overview +**dLLM** is a library that unifies the training and evaluation of **diffusion language models**, bringing transparency and reproducibility to the entire development pipeline: + + + +- dLLM provides scalable training pipelines (inspired by [`transformers`](https://github.com/huggingface/transformers/blob/main/src/transformers) [Trainer](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)), with support for [LoRA](https://github.com/huggingface/peft), [DeepSpeed](https://github.com/deepspeedai/DeepSpeed) and [FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and beyond. + +- dLLM provides unified evaluation pipelines (inspired by [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness)) that abstracts away inference details and making customization simple. + +- Built on these components, dLLM provide the minimal **pretraining / finetuning / evaluation** recipes for open-weight models (e.g., [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)), and implementations of training algorithms (e.g., [Edit Flows](https://arxiv.org/abs/2506.09018)). + + + + +## News +**[2025/11]** We released a collection of BERTs finetuned for instruction-following: [`ModernBERT-{large,base}-chat-v0`](https://huggingface.co/collections/dllm-collection/bert-chat). This proof-of-concept shows that BERT’s internal knowledge can be leveraged for generative tasks via masked instruction tuning. See [![blog](https://img.shields.io/badge/W&B-white?logo=weightsandbiases) BERT Chat Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg) for detailed recipes, experimental results and lessons learned; See [`examples/bert`](/examples/bert) for training / inference / evaluation instructions. + + +## Table of Contents +- [Features](#features) +- [Setup](#setup) +- [Files overview](#files-overview) +- [Training](#training) +- [Inference](#inference) +- [Evaluation](#evaluation) +- [Citation](#citation) + + +## Features + +- [`examples/llada`](/examples/llada): Pretraining, finetuning and evaluating LLaDA [LLaDA](https://arxiv.org/abs/2502.09992) / [LLaDA-MoE](https://arxiv.org/abs/2509.24389). +- [`examples/dream`](/examples/dream): Pretraining, finetuning and evaluating Dream [Dream](https://arxiv.org/abs/2508.15487). +- [`examples/bert`](/examples/bert): Finetuning any [BERT](https://arxiv.org/abs/1810.04805) to be lightweight Chatbots. +
+  <!-- collapsible demo: "🎬 Click to show BERT Chat Demo" (chat GIF); caption: "Chat with ModernBERT-large-chat-v0. See Inference for details." -->
+- [`examples/editflow`](/examples/editflow): Educational reference for training [EditFlow](https://arxiv.org/abs/2506.09018) models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT Chat) with *edit operations*—insertion, deletion, and substitution—and how to pretrain or finetune EditFlow models from scratch on public data. + +
+  <!-- collapsible demo: "🎬 Click to show EditFlow Demo" (EditFlow demo GIF); caption: "EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation." -->
+- More upcoming. + + +## Setup +### Installation +```bash +# create and activate conda environment +conda create -n dllm python=3.10 -y +conda activate dllm + +# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work) +conda install cuda=12.4 -c nvidia +pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \ + --index-url https://download.pytorch.org/whl/cu124 + +# install dllm package +pip install -e . +``` +### (optional) Evaluation setup + +```bash +# initialize `lm-evaluation-harness` submodule +git submodule update --init --recursive + +# install submodule in editable mode with IFEval & Math dependencies +pip install -e "lm-evaluation-harness[ifeval,math]" +``` + +### (optional) Slurm setup +For [Slurm](https://slurm.schedmd.com/) users, update [`scripts/train.slurm.sh`](/scripts/train.slurm.sh) for your cluster: +```diff +- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster +- #SBATCH --quotatype=spot # Note: adjust this for your cluster ++ #SBATCH --partition=YOUR_PARTITION ++ #SBATCH --quotatype=YOUR_QUOTATYPE +``` +Next, create a directory for your job logs: +```shell +mkdir logs +``` +This folder will store the log files generated by your sbatch jobs. + +## Files overview +``` +# modules for training / sampling +dllm +├── core # Core reusable modules shared across `dllm/pipelines` +│ ├── generation +│ ├── schedulers +│ └── trainers +├── data +├── pipelines # Application-specific training & inference pipelines +| ├── bert +│ ├── dream +│ ├── editflow +│ └── llada +│ ├── models # Model architecture and configs +│ ├── generator.py # Generation utilities +│ ├── trainer.py # Core training logic +│ └── eval.py # Evaluation entry point +├── tools +└── utils + +# entry points for training / sampling +examples +├── bert +├── dream +├── editflow +└── llada + ├── chat.py # Interactive inference example + ├── generate.py # Inference example + ├── pt.py # Pretraining example + ├── README.md # Documentation (you are here) + ├── sft.py # Supervised finetuning example + └── eval.sh # Evalution script +``` + +## Training + +A typical training entry script looks like (for example, [`examples/llada/sft.py`](/examples/llada/sft.py)) looks like this: +```python +import transformers + +import dllm + +model_args, data_args, training_args = parser.parse_args_into_dataclasses() +# ----- Model ------------------------------------------------------------------ +model = dllm.utils.get_model(model_args=model_args) +# ----- Tokenizer -------------------------------------------------------------- +tokenizer = dllm.utils.get_tokenizer(model_args=model_args) +# ----- Dataset ---------------------------------------------------------------- +dataset = "..." + +# ----- Training -------------------------------------------------------------- +trainer = dllm.core.trainers.MDLMTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + args=training_args, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + label_pad_token_id=tokenizer.pad_token_id, + ), +) +trainer.train() +``` + +You can launch training job locally with `accelerate`, or submit it to a [Slurm](https://slurm.schedmd.com/) cluster using `sbatch`. 
+```shell +# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA) +accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/llada/sft.py \ + --num_train_epochs 4 \ + --load_in_4bit True --lora True +``` +```shell +# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs) +sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" \ + --num_train_epochs 4 + +# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs) +sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" \ + --num_train_epochs 4 +``` +See [Features](#features) for specific training recipes. + + +> Here are some useful tips for training: +> 1. Use a subset of data: +> `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"` +> 2. Concatenate datasets: +> `--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"` +> 3. Train with LoRA and 4bit quantization: +> `--load_in_4bit True --lora True` +> 4. Train with different distributed training methods: +> `--accelerate_config "ddp,zero-{1,2,3},fsdp"` + +## Inference + +We provide unified [generators](/dllm/core/generation/generator.py) that abstracts away inference details. +A typical inference entry script (for example, [`examples/llada/generate.py`](/examples/llada/generate.py)) looks like this: +```python +import dllm +from dllm import llada + +model = dllm.utils.get_model(model_args=script_args).eval() +tokenizer = dllm.utils.get_tokenizer(model_args=script_args) +# for other models, change your generator and keep others unchanged +generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer) + +messages = [ + [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}], + [{"role": "user", "content": "Please write an educational python function."}], +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, +) + +outputs = generator.generate(inputs, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) +``` + +You can also try interactive chat script (for example, [`examples/llada/chat.py`](/examples/llada/chat.py)) for visualized multi-turn dialogue: + +

+<!-- centered chat demo GIF -->
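+If you need the intermediate denoising steps rather than only the final text, the generators can also return them alongside the sequences. A minimal sketch (assuming the generator fills `GeneratorOutput.histories` when called with `return_dict_in_generate=True`; see `dllm/core/generation/generator.py` and `visualizer.py`):
+
+```python
+from dllm.core.generation.visualizer import TerminalVisualizer
+
+# Replay the recorded decoding trajectory in the terminal.
+# Batched histories (one [B, T] tensor per step) are split and shown per sequence.
+visualizer = TerminalVisualizer(tokenizer=tokenizer)
+visualizer.visualize(outputs.histories, fps=16, title="dLLM demo")
+```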

+ + +## Evaluation +> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation. + +For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [`MMLU_Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro), run: +```shell +accelerate launch --num_processes 4 \ + dllm/pipelines/llada/eval.py \ + --tasks "mmlu_pro" \ + --model "llada" \ + --apply_chat_template \ + --num_fewshot 0 \ + --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0" +``` + +We also provide scripts to automatically evaluate [LLaDA](https://arxiv.org/abs/2502.09992), [Dream](https://arxiv.org/abs/2508.15487), and [BERT-Chat](https://huggingface.co/collections/dllm-collection/bert-chat) on all benchmarks. +For example, you can launch [`examples/llada/eval.sh`](/examples/llada/eval.sh) directly using the following commands: +```shell +bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True +bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False +``` + + +## Citation +``` +@misc{dllm, + author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song}, + title = {dLLM: Simple Diffusion Language Modeling}, + year = {2025}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/ZHZisZZ/dllm}}, +} +``` diff --git a/dllm/assets/JetBrainsMono-VariableFont_wght.ttf b/dllm/assets/JetBrainsMono-VariableFont_wght.ttf new file mode 100644 index 0000000..4c96e79 Binary files /dev/null and b/dllm/assets/JetBrainsMono-VariableFont_wght.ttf differ diff --git a/dllm/assets/chat.gif b/dllm/assets/chat.gif new file mode 100644 index 0000000..690474c Binary files /dev/null and b/dllm/assets/chat.gif differ diff --git a/dllm/assets/logo.gif b/dllm/assets/logo.gif new file mode 100644 index 0000000..794db89 Binary files /dev/null and b/dllm/assets/logo.gif differ diff --git a/dllm/assets/logo.png b/dllm/assets/logo.png new file mode 100644 index 0000000..a4051e9 Binary files /dev/null and b/dllm/assets/logo.png differ diff --git a/dllm/assets/logo.py b/dllm/assets/logo.py new file mode 100644 index 0000000..75f8db9 --- /dev/null +++ b/dllm/assets/logo.py @@ -0,0 +1,119 @@ +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import os + +# ---------- Configuration (smaller size) ---------- +W, H = 480, 210 # lower resolution +TOTAL_DURATION = 3.0 +FPS = 15 # lower fps +TEXT = "dLLM" +# TEXT_COLOR = (235, 235, 235) +TEXT_COLOR = (0, 0, 0) +OUTPUT = "logo.gif" +LAST_FRAME_PNG = "logo.png" + +DIFFUSION_PORTION = 0.3 # fewer diffusion frames +SEED = 8 + + +# ---------- Auto font size ---------- +def load_font_auto_size(text, w, h, target_width_ratio=0.95, target_height_ratio=0.95): + lo, hi = 10, 2000 + best_font, best_size = None, lo + while lo <= hi: + mid = (lo + hi) // 2 + try: + font = ImageFont.truetype( + "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=mid + ) + except: + font = ImageFont.load_default() + dummy = Image.new("L", (w, h), 0) + d = ImageDraw.Draw(dummy) + bbox = d.textbbox((0, 0), text, font=font) + tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1] + + width_ok = tw <= w * target_width_ratio + height_ok = th <= h * target_height_ratio + + if width_ok and height_ok: + best_font, best_size = font, mid + lo = mid + 1 + else: + hi = mid - 1 + + return best_font if best_font is not None else font + + +# ---------- Text rendering 
---------- +def render_text_mask(w, h, text, font): + img = Image.new("L", (w, h), 0) + d = ImageDraw.Draw(img) + bbox = d.textbbox((0, 0), text, font=font) + tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1] + x = (w - tw) // 2 - bbox[0] + y = (h - th) // 2 - bbox[1] + d.text((x, y), text, font=font, fill=255) + return np.asarray(img, np.float32) / 255.0 + + +# ---------- Initialization ---------- +font = load_font_auto_size(TEXT, W, H) +mask = render_text_mask(W, H, TEXT, font) + +num_frames = int(TOTAL_DURATION * FPS) +diffusion_frames = max(1, int(num_frames * DIFFUSION_PORTION)) +hold_ms = int((TOTAL_DURATION - diffusion_frames / FPS) * 1000) + +rng = np.random.default_rng(SEED) +frames = [] + +# ---------- Diffusion stage ---------- +for i in range(diffusion_frames): + t = i / max(1, diffusion_frames - 1) + progress = t**0.9 + noise_sigma = (1.0 - progress) ** 2.2 + + noise = rng.standard_normal((H, W, 1)).astype(np.float32) + noise_img = 1.0 - noise_sigma * 0.5 * np.abs(noise) + np.clip(noise_img, 0.0, 1.0, out=noise_img) + + alpha = progress**2.0 + alpha_map = (mask * alpha).astype(np.float32)[..., None] + + text_rgb = np.zeros((H, W, 3), dtype=np.float32) + for c in range(3): + text_rgb[..., c] = (mask > 0).astype(np.float32) * (TEXT_COLOR[c] / 255.0) + + frame = (1.0 - alpha_map) * noise_img + alpha_map * text_rgb + frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8) + frames.append(Image.fromarray(frame, mode="RGB")) + +# ---------- Last frame ---------- +final_frame = frames[-1] + +# ---------- Save last frame as PNG ---------- +final_frame.save(LAST_FRAME_PNG) +print(f"🖼️ Last frame saved as: {LAST_FRAME_PNG}") + +# ---------- Quantization (reduce size) ---------- +pal_frames = [f.convert("P", palette=Image.ADAPTIVE, colors=64) for f in frames] +pal_final = final_frame.convert("P", palette=Image.ADAPTIVE, colors=64) + +# ---------- Save GIF ---------- +normal_ms = int(1000 / FPS) +durations = [normal_ms] * len(pal_frames) + [hold_ms] + +pal_frames[0].save( + OUTPUT, + save_all=True, + append_images=pal_frames[1:] + [pal_final], + duration=durations, + loop=0, + optimize=True, +) + +print(f"✅ GIF saved: {OUTPUT}") +print( + f"Frames (diffusion only): {len(pal_frames)} at {FPS} FPS, final hold {hold_ms} ms, resolution {W}x{H}" +) diff --git a/dllm/dllm/__init__.py b/dllm/dllm/__init__.py new file mode 100644 index 0000000..98020f0 --- /dev/null +++ b/dllm/dllm/__init__.py @@ -0,0 +1 @@ +from . import core, data, pipelines, utils diff --git a/dllm/dllm/core/__init__.py b/dllm/dllm/core/__init__.py new file mode 100644 index 0000000..329094a --- /dev/null +++ b/dllm/dllm/core/__init__.py @@ -0,0 +1 @@ +from dllm.core import trainers, schedulers, generation diff --git a/dllm/dllm/core/generation/__init__.py b/dllm/dllm/core/generation/__init__.py new file mode 100644 index 0000000..72bf90e --- /dev/null +++ b/dllm/dllm/core/generation/__init__.py @@ -0,0 +1 @@ +from . 
import generator, visualizer diff --git a/dllm/dllm/core/generation/generator.py b/dllm/dllm/core/generation/generator.py new file mode 100644 index 0000000..9ef8ce3 --- /dev/null +++ b/dllm/dllm/core/generation/generator.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch +from transformers import PreTrainedTokenizer, PreTrainedModel + +from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler + + +@dataclass +class GeneratorOutput: + sequences: torch.Tensor + histories: list[torch.Tensor] | None = None + + +@dataclass +class GeneratorConfig: + return_dict_in_generate: bool = False + + +@dataclass +class BaseGenerator(ABC): + model: PreTrainedModel + tokenizer: PreTrainedTokenizer + scheduler: BaseAlphaScheduler | None = None + + def __post_init__(self): + if self.scheduler is None: + self.scheduler = LinearAlphaScheduler() + + @abstractmethod + @torch.no_grad() + def generate( + self, + prompts: list[torch.Tensor, list], + config: GeneratorConfig | None = None, + **kwargs, + ) -> GeneratorOutput: + raise NotImplementedError + + @abstractmethod + @torch.no_grad() + def infill( + self, + inputs: list[torch.Tensor, list], + config: GeneratorConfig | None = None, + **kwargs, + ) -> GeneratorOutput: + raise NotImplementedError diff --git a/dllm/dllm/core/generation/visualizer.py b/dllm/dllm/core/generation/visualizer.py new file mode 100644 index 0000000..15b602a --- /dev/null +++ b/dllm/dllm/core/generation/visualizer.py @@ -0,0 +1,427 @@ +from __future__ import annotations +import os +import re +import sys +import time +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import Sequence, Optional + +import torch +from tqdm import tqdm + +from transformers import PreTrainedTokenizer + + +@dataclass +class BaseVisualizer(ABC): + tokenizer: PreTrainedTokenizer + + @abstractmethod + def visualize(self, history: list[torch.Tensor, list], **kwargs): + raise NotImplementedError + + +@dataclass +class VideoVisualizer(BaseVisualizer): + + def visualize( + self, + history: list[torch.Tensor, list], + output_path: str = "visualization.gif", + **kwargs, + ): + raise NotImplementedError + + +@dataclass +class TerminalVisualizer(BaseVisualizer): + + # Configuration (adjust as needed) + HEADER_SIZE = 3 # Fixed number of lines for the header (0 if show_header is False) + PROGRESS_SIZE = 3 # Fixed number of lines for the progress bar + PANEL_PADDING_TOP = 1 # Top padding of the Panel (padding=(top, side)) + PANEL_PADDING_BOTTOM = 1 # Bottom padding of the Panel + PANEL_PADDING_SIDE = 1 # Number of characters used for left and right padding + PANEL_BORDER = 2 # Number of columns taken by the Panel border (usually 2) + MIN_TOTAL_HEIGHT = 10 # Minimum terminal height (in lines) + MAX_TOTAL_HEIGHT = 60 # Maximum terminal height to prevent overflowing the terminal + DEFAULT_TERM_WIDTH = 120 # Default terminal width (in columns) + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") # Regex to match ANSI escape codes + + def visualize( + self, + history: list[torch.Tensor], # list of tokens per step: [T] or [B,T] + fps: int = 16, + rich: bool = True, + title: str = "dllm", + max_chars: int = None, + every_n_steps: int = 1, + show_header: bool = True, + skip_special_tokens: bool = False, + ) -> None: + """ + Visualize a masked-diffusion decoding trajectory stored in `history`. + If items have batch dimension [B, T], visualize each sequence separately. 
+ """ + try: + # detect batch size + first_step = history[0] + if first_step.dim() > 1 and first_step.shape[0] > 1: + B = first_step.shape[0] + for b_idx in range(B): + # build per-sequence history + seq_history = [step[b_idx].unsqueeze(0) for step in history] + self.visualize_one_history( + seq_history, + fps, + rich, + title=f"{title} (Batch {b_idx})", + max_chars=max_chars, + every_n_steps=every_n_steps, + show_header=show_header, + skip_special_tokens=skip_special_tokens, + ) + else: + # no batch, just visualize normally + self.visualize_one_history( + history, + fps, + rich, + title, + max_chars, + every_n_steps, + show_header, + skip_special_tokens, + ) + except Exception as e: + print(f"(Visualization skipped due to error: {e})") + + def visualize_one_history( + self, + history: list[torch.Tensor], # list of tokens per step: [T] or [B,T] + fps: int = 16, + rich: bool = True, + title: str = "dllm", + max_chars: int = None, + every_n_steps: int = 1, # re-render frequency (perf knob) + show_header: bool = True, + skip_special_tokens: bool = False, # NEW ARGUMENT + ) -> None: + """ + Visualize a masked-diffusion decoding trajectory stored in `history`. + + Args: + history: Sequence of token tensors for each step. Each item is [T] or [B,T]. + fps: Frames per second for the live UI (Rich) or sleep cadence for tqdm fallback. + title: Header title. + max_chars: Cap on rendered characters to keep terminal snappy. + every_n_steps: Only redraw text every N steps (progress still updates every step). + show_header: Show the magenta header bar (Rich path). + skip_special_tokens: Whether to skip special/pad/eos tokens when rendering (default: False). + Notes: + - Masked positions are detected via `self.tokenizer.mask_token_id`. + - Special tokens are determined via `self.tokenizer.all_special_ids`. + - All layout, styling, and progress are encapsulated here. + """ + # --------- imports & env checks ---------- + try: + from rich.console import Console + from rich.live import Live + from rich.text import Text + from rich.panel import Panel + from rich.progress import ( + Progress, + BarColumn, + TextColumn, + TimeRemainingColumn, + MofNCompleteColumn, + SpinnerColumn, + ) + from rich.layout import Layout + + _RICH_IMPORTED = True + except Exception: + _RICH_IMPORTED = False + + try: + from tqdm import tqdm + + _TQDM_IMPORTED = True + except Exception: + _TQDM_IMPORTED = False + + if self.tokenizer is None: + raise ValueError( + "TerminalVisualizer.tokenizer must be set to a valid tokenizer." 
+ ) + + tokenizer = self.tokenizer + specials: set[int] = set(getattr(tokenizer, "all_special_ids", []) or []) + self._specials = specials # store for helpers + self._mask_token_id: Optional[int] = getattr(tokenizer, "mask_token_id", None) + self._pad_token_id: Optional[int] = getattr(tokenizer, "pad_token_id", None) + self._eos_token_id: Optional[int] = getattr(tokenizer, "eos_token_id", None) + + # --------- helpers inside class scope ---------- + # (keep everything inside this class as requested) + + # throttle settings + sleep_s = 0.0 if fps <= 0 else 1.0 / float(max(1, fps)) + total_steps = len(history) + every_n_steps = max(1, int(every_n_steps)) + + # decode final text up-front (used after render) + final_text = self._detok(history[-1], skip_special_tokens=skip_special_tokens) + final_text = self._truncate(final_text, max_chars) + + # ------------------ new: estimate height from final_text ------------------ + import textwrap + import shutil + + def strip_ansi(s: str) -> str: + return self.ansi_escape.sub("", s) if s else "" + + def estimate_height_from_text(text: str, console_width: int) -> int: + """ + Estimate how many terminal rows the panel with `text` will need given console_width. + Uses class constants for paddings/borders and header/progress sizes. + """ + plain = strip_ansi(text or "") + # inner width = console width minus left/right panel paddings & border + inner_width = max( + 10, console_width - 2 * self.PANEL_PADDING_SIDE - self.PANEL_BORDER + ) + lines = 0 + # preserve existing newlines: wrap each paragraph separately + for para in plain.splitlines() or [""]: + if para.strip() == "": + lines += 1 + continue + wrapped = textwrap.wrap( + para, + width=inner_width, + replace_whitespace=False, + drop_whitespace=False, + ) + lines += max(1, len(wrapped)) + text_block_lines = ( + lines + self.PANEL_PADDING_TOP + self.PANEL_PADDING_BOTTOM + ) + extra = 2 # for panel title / subtitle / small margin + header_h = self.HEADER_SIZE if show_header else 0 + total = header_h + text_block_lines + self.PROGRESS_SIZE + extra + # clamp + total = max(self.MIN_TOTAL_HEIGHT, min(total, self.MAX_TOTAL_HEIGHT)) + return int(total) + + # try to detect terminal width; fallback to 100 + try: + term_width = shutil.get_terminal_size().columns + if not isinstance(term_width, int) or term_width <= 0: + term_width = self.DEFAULT_TERM_WIDTH + except Exception: + term_width = self.DEFAULT_TERM_WIDTH + + est_height = estimate_height_from_text(final_text, console_width=term_width) + # ------------------ end new ---------------------------------------------- + + # choose rich or tqdm + use_rich = bool(rich and _RICH_IMPORTED) + + if not use_rich or not _RICH_IMPORTED: + # ---------- tqdm fallback ---------- + if not _TQDM_IMPORTED: + for i, toks in enumerate(history, start=1): + if sleep_s > 0: + time.sleep(sleep_s) + print("\n✨ Generation complete!\n") + print(final_text) + return + + pbar = tqdm(total=total_steps, desc="Diffusion", leave=True) + for i, toks in enumerate(history, start=1): + pbar.update(1) + pbar.set_postfix( + { + "masks": self._count_masks(toks), + "pct": f"{int(100 * i / max(total_steps, 1))}%", + } + ) + if sleep_s > 0: + time.sleep(sleep_s) + pbar.close() + print("\n✨ Generation complete!\n") + if final_text: + print(final_text) + return + + # ---------- rich live UI ---------- + # replaced fixed height=100 with the estimated height from history[-1] + console = Console( + force_terminal=True, + color_system="truecolor", + width=term_width, + height=est_height, + ) + layout = 
Layout() + layout.split_column( + ( + Layout(name="header", size=3) + if show_header + else Layout(name="header", size=0) + ), + Layout(name="text", ratio=1), + Layout(name="progress", size=3), + ) + + progress = Progress( + SpinnerColumn(), + TextColumn("[bold blue]Diffusion"), + BarColumn(), + MofNCompleteColumn(), + TextColumn("•"), + TextColumn("[cyan]Masks: {task.fields[masks]}"), + TextColumn("•"), + TextColumn("[magenta]{task.fields[pct]:>4s}"), + TimeRemainingColumn(), + expand=True, + ) + + init_masks = self._count_masks(history[0]) if history else 0 + task_id = progress.add_task( + "Generating", total=total_steps, masks=init_masks, pct="0%" + ) + + with Live(layout, console=console, refresh_per_second=max(1, fps)): + for step_idx, toks in enumerate(history, start=1): + if show_header: + header = Text(title, style="bold magenta", justify="center") + layout["header"].update(Panel(header, border_style="bright_blue")) + + # progress bar + masks_remaining = self._count_masks(toks) + pct = f"{int(100 * step_idx / max(total_steps, 1))}%" + progress.update(task_id, advance=1, masks=masks_remaining, pct=pct) + + # text panel: decode whole sequence (avoids Ġ/Ċ artifacts) + if ( + every_n_steps <= 1 + or (step_idx % every_n_steps == 0) + or step_idx in (1, total_steps) + ): + text_str = self._detok( + toks, skip_special_tokens=skip_special_tokens + ) + text_str = self._truncate(text_str, max_chars) + text_rich = Text.from_ansi(text_str) if text_str else Text("") + layout["text"].update( + Panel( + ( + text_rich + if text_rich.plain + else Text("[dim]— no tokens —[/dim]") + ), + title="[bold]Generated Text", + subtitle=f"[dim]Step {step_idx}/{total_steps}[/dim]", + border_style="cyan", + padding=(1, 1), + ) + ) + + layout["progress"].update(Panel(progress)) + if sleep_s > 0: + time.sleep(sleep_s) + + console.print("\n[bold green]✨ Generation complete![/bold green]\n") + # console.print( + # Panel( + # final_text if final_text else "[dim]— no decodable text —[/dim]", + # title="[bold]Final Generated Text", + # border_style="green", + # padding=(1, 2), + # ) + # ) + + # ======================== helpers (kept inside class) ======================== + + def _has_tty(self) -> bool: + return sys.stdout.isatty() and os.environ.get("TERM", "") not in ("", "dumb") + + def _first_item(self, x: torch.Tensor) -> torch.Tensor: + return x[0] if x.dim() > 1 else x + + def _count_masks(self, toks: torch.Tensor) -> int: + if getattr(self, "_mask_token_id", None) is None: + return 0 + t = self._first_item(toks) + return int((t == self._mask_token_id).sum().item()) + + def _detok(self, ids_or_tensor, *, skip_special_tokens: bool) -> str: + """ + Robust detokenize for list[int] / torch.Tensor([T]) / torch.Tensor([B,T]). + Decode the whole sequence to avoid byte-level artifacts like Ġ/Ċ. 
+ """ + tokenizer = self.tokenizer + # normalize to python list[int] + if isinstance(ids_or_tensor, torch.Tensor): + t = self._first_item(ids_or_tensor).long() + ids = t.tolist() + elif isinstance(ids_or_tensor, (list, tuple)): + ids = list(ids_or_tensor) + else: + # unknown type + return "" + + # Optionally drop specials/pad/eos *before* decode if desired + if skip_special_tokens: + keep = [] + specials = getattr(self, "_specials", set()) + pad_id = getattr(self, "_pad_token_id", None) + eos_id = getattr(self, "_eos_token_id", None) + for tid in ids: + if tid in specials: + continue + if pad_id is not None and tid == pad_id: + continue + if eos_id is not None and tid == eos_id: + continue + keep.append(tid) + ids = keep + + # Prefer tokenizer.decode (handles Ġ/Ċ, merges properly) + text = "" + try: + if hasattr(tokenizer, "decode"): + text = tokenizer.decode( + ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + else: + # fallback: tokens -> string + toks = tokenizer.convert_ids_to_tokens(ids) + if hasattr(tokenizer, "convert_tokens_to_string"): + text = tokenizer.convert_tokens_to_string(toks) + else: + text = " ".join(map(str, toks)) + except Exception: + # extremely defensive fallback + try: + text = tokenizer.decode(ids, skip_special_tokens=True) + except Exception: + text = "" + + # sanitize control chars for terminal + if text: + text = text.replace("\r", "") + return text + + def _truncate(self, s: str, max_chars: Optional[int]) -> str: + if max_chars is None or (isinstance(max_chars, int) and max_chars < 0): + return s + return s[:max_chars] + + +if __name__ == "__main__": + pass diff --git a/dllm/dllm/core/schedulers/__init__.py b/dllm/dllm/core/schedulers/__init__.py new file mode 100644 index 0000000..6838401 --- /dev/null +++ b/dllm/dllm/core/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .alpha import * +from .kappa import * diff --git a/dllm/dllm/core/schedulers/alpha.py b/dllm/dllm/core/schedulers/alpha.py new file mode 100644 index 0000000..c8a2372 --- /dev/null +++ b/dllm/dllm/core/schedulers/alpha.py @@ -0,0 +1,132 @@ +from __future__ import annotations +import dataclasses +import math +from typing import ClassVar, Dict, Type, Any, Union + +import torch + +Number = Union[float, torch.Tensor] + + +# ---------------- Registry-enabled Base ---------------- # +@dataclasses.dataclass +class BaseAlphaScheduler: + __registry__: ClassVar[dict[str, type[BaseAlphaScheduler]]] = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + BaseAlphaScheduler.__registry__[cls.__name__] = cls + BaseAlphaScheduler.__registry__[cls.__name__.lower()] = cls + + # Make instances callable (sched(i) -> alpha(i)) + def __call__(self, t: Number) -> Number: + return self.alpha(t) + + # ---- common API ---- + def alpha(self, i: Number) -> Number: + i_t = torch.as_tensor( + i, + dtype=torch.float32, + device=i.device if isinstance(i, torch.Tensor) else None, + ) + if not torch.all((0.0 <= i_t) & (i_t <= 1.0)): + raise ValueError(f"i={i} not in [0,1]") + out = self._alpha(i_t) + return out.item() if isinstance(i, float) else out + + def alpha_derivative(self, i: Number) -> Number: + i_t = torch.as_tensor( + i, + dtype=torch.float32, + device=i.device if isinstance(i, torch.Tensor) else None, + ) + if not torch.all((0.0 <= i_t) & (i_t <= 1.0)): + raise ValueError(f"i={i} not in [0,1]") + out = self._alpha_derivative(i_t) + return out.item() if isinstance(i, float) else out + + def reverse_mask_prob(self, s: Number, t: Number) -> Number: + t_t = 
torch.as_tensor( + t, + dtype=torch.float32, + device=t.device if isinstance(t, torch.Tensor) else None, + ) + s_t = torch.as_tensor( + s, + dtype=torch.float32, + device=s.device if isinstance(s, torch.Tensor) else None, + ) + if not torch.all((0.0 <= s_t) & (s_t < 1.0) & (0.0 < t_t) & (t_t <= 1.0)): + raise ValueError(f"(t={t}, s={s}) out of range") + if not torch.all(s_t < t_t): + raise ValueError(f"Require s < t elementwise, but got (t={t}, s={s})") + out = (1 - self(s_t)) / (1 - self(t_t)) + return out.item() if isinstance(t, float) and isinstance(s, float) else out + + def weight(self, i: Number) -> Number: + # w(t) = - α'(t) / (1 - α(t)) + return - self.alpha_derivative(i) / (1 - self.alpha(i) + 1e-6) + + # ---- hooks implemented by subclasses ---- + def _alpha(self, i: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +# ---------------- Implementations ---------------- # + + +@dataclasses.dataclass +class LinearAlphaScheduler(BaseAlphaScheduler): + def _alpha(self, i: torch.Tensor) -> torch.Tensor: + return 1 - i + + def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor: + return -torch.ones_like(i) + + +@dataclasses.dataclass +class CosineAlphaScheduler(BaseAlphaScheduler): + def _alpha(self, i: torch.Tensor) -> torch.Tensor: + return 1 - torch.cos((math.pi / 2) * (1 - i)) + + def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor: + return -(math.pi / 2) * torch.sin((math.pi / 2) * (1 - i)) + + +# ---------------- Factory helpers ---------------- # + + +def get_alpha_scheduler_class(name: str) -> type[BaseAlphaScheduler]: + """Return the scheduler class by name (case-insensitive).""" + cls = BaseAlphaScheduler.__registry__.get( + name + ) or BaseAlphaScheduler.__registry__.get(name.lower()) + if cls is None: + available = sorted(k for k in BaseAlphaScheduler.__registry__ if k[0].isupper()) + raise ValueError(f"Unknown scheduler '{name}'. 
Available: {available}") + return cls + + +def make_alpha_scheduler(name: str, **kwargs: Any) -> BaseAlphaScheduler: + """Instantiate a scheduler by name with optional kwargs.""" + cls = get_alpha_scheduler_class(name) + return cls(**kwargs) + + +# ---------------- Example usage ---------------- # + +if __name__ == "__main__": + lin_sched = make_alpha_scheduler("LinearalphaScheduler") + print("Linear α(0.5):", lin_sched.alpha(0.5)) + print("Linear w(0.5):", lin_sched.weight(0.5)) + print("Linear α([.25,.5,.75]):", lin_sched.alpha(torch.tensor([0.25, 0.5, 0.75]))) + print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75]))) + print("==========================================") + cos_sched = make_alpha_scheduler("CosinealphaScheduler") + print("Cosine α(0.5):", cos_sched.alpha(0.5)) + print("Cosine w(0.5):", cos_sched.weight(0.5)) + print("Cosine α([.25,.5,.75]):", cos_sched.alpha(torch.tensor([0.25, 0.5, 0.75]))) + print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75]))) diff --git a/dllm/dllm/core/schedulers/kappa.py b/dllm/dllm/core/schedulers/kappa.py new file mode 100644 index 0000000..db18cf5 --- /dev/null +++ b/dllm/dllm/core/schedulers/kappa.py @@ -0,0 +1,128 @@ +from __future__ import annotations +import dataclasses +import math +from typing import ClassVar, Dict, Type, Any, Union + +import torch + +Number = Union[float, torch.Tensor] + + +# ---------------- Registry-enabled Base ---------------- # +@dataclasses.dataclass +class BaseKappaScheduler: + __registry__: ClassVar[dict[str, type[BaseKappaScheduler]]] = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + BaseKappaScheduler.__registry__[cls.__name__] = cls + BaseKappaScheduler.__registry__[cls.__name__.lower()] = cls + + # Make instances callable (sched(t) -> kappa(t)) + def __call__(self, t: Number) -> Number: + return self.kappa(t) + + # ---- common API ---- + def kappa(self, t: Number) -> Number: + t_tensor = torch.as_tensor( + t, + dtype=torch.float32, + device=t.device if isinstance(t, torch.Tensor) else None, + ) + if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)): + raise ValueError(f"t={t} not in [0,1]") + out = self._kappa(t_tensor) + return out.item() if isinstance(t, float) else out + + def kappa_derivative(self, t: Number) -> Number: + t_tensor = torch.as_tensor( + t, + dtype=torch.float32, + device=t.device if isinstance(t, torch.Tensor) else None, + ) + if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)): + raise ValueError(f"t={t} not in [0,1]") + out = self._kappa_derivative(t_tensor) + return out.item() if isinstance(t, float) else out + + def weight(self, t: Number) -> Number: + # w(t) = κ'(t) / (1 - κ(t)) + return self.kappa_derivative(t) / (1 - self.kappa(t) + 1e-6) + + # ---- hooks implemented by subclasses ---- + def _kappa(self, t: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +# ---------------- Implementations ---------------- # + + +@dataclasses.dataclass +class CubicKappaScheduler(BaseKappaScheduler): + a: float = 1.0 + b: float = 1.0 + + def _kappa(self, t: torch.Tensor) -> torch.Tensor: + # κ(t) = (a+1) t^3 - (a+b+1) t^2 + (b+1) t + return (self.a + 1) * (t**3) - (self.a + self.b + 1) * (t**2) + (self.b + 1) * t + + def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor: + # κ'(t) = 3(a+1) t^2 - 2(a+b+1) t + (b+1) + return 3 * (self.a + 1) * (t**2) - 2 * (self.a + self.b + 1) * t + (self.b 
+ 1) + + +@dataclasses.dataclass +class LinearKappaScheduler(CubicKappaScheduler): + # Special case: κ(t) = t corresponds to a=-1, b=0 + a: float = -1.0 + b: float = 0.0 + + +@dataclasses.dataclass +class CosineKappaScheduler(BaseKappaScheduler): + def _kappa(self, t: torch.Tensor) -> torch.Tensor: + # κ(t) = 1 - cos((π/2) * t) + return 1.0 - torch.cos(0.5 * math.pi * t) + + def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor: + # κ'(t) = (π/2) * sin((π/2) * t) + return 0.5 * math.pi * torch.sin(0.5 * math.pi * t) + + +# ---------------- Factory helpers ---------------- # + + +def get_kappa_scheduler_class(name: str) -> type[BaseKappaScheduler]: + """Return the scheduler class by name (case-insensitive).""" + cls = BaseKappaScheduler.__registry__.get( + name + ) or BaseKappaScheduler.__registry__.get(name.lower()) + if cls is None: + available = sorted(k for k in BaseKappaScheduler.__registry__ if k[0].isupper()) + raise ValueError(f"Unknown scheduler '{name}'. Available: {available}") + return cls + + +def make_kappa_scheduler(name: str, **kwargs: Any) -> BaseKappaScheduler: + """Instantiate a scheduler by name with optional kwargs.""" + cls = get_kappa_scheduler_class(name) + return cls(**kwargs) + + +# ---------------- Example usage ---------------- # + +if __name__ == "__main__": + lin_sched = make_kappa_scheduler("LinearKappaScheduler") + print("Linear κ(0.5):", lin_sched.kappa(0.5)) + print("Linear w(0.5):", lin_sched.weight(0.5)) + print("Linear κ([.25,.5,.75]):", lin_sched.kappa(torch.tensor([0.25, 0.5, 0.75]))) + print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75]))) + print("==========================================") + cos_sched = make_kappa_scheduler("CosineKappaScheduler") + print("Cosine κ(0.5):", cos_sched.kappa(0.5)) + print("Cosine w(0.5):", cos_sched.weight(0.5)) + print("Cosine κ([.25,.5,.75]):", cos_sched.kappa(torch.tensor([0.25, 0.5, 0.75]))) + print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75]))) diff --git a/dllm/dllm/core/trainers/__init__.py b/dllm/dllm/core/trainers/__init__.py new file mode 100644 index 0000000..b9252d7 --- /dev/null +++ b/dllm/dllm/core/trainers/__init__.py @@ -0,0 +1 @@ +from dllm.core.trainers.mdlm import MDLMTrainer diff --git a/dllm/dllm/core/trainers/mdlm.py b/dllm/dllm/core/trainers/mdlm.py new file mode 100644 index 0000000..b8edd42 --- /dev/null +++ b/dllm/dllm/core/trainers/mdlm.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from typing import Any + +from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler + + +class MDLMTrainer(transformers.Trainer): + """ + Masked Diffusion Language Model Trainer. 
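+
+ For each batch, a timestep t is sampled per sequence, tokens are masked independently with probability 1 - alpha(t), and a cross-entropy loss is computed on the masked positions only, optionally re-weighted by the scheduler (see `compute_loss`).
+
+ Minimal usage sketch (argument names follow `transformers.Trainer`; the surrounding `model`, `training_args`, `dataset`, and `tokenizer` objects are assumed to exist):
+
+ trainer = MDLMTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset["train"],
+ processing_class=tokenizer, # must define mask_token_id
+ scheduler=LinearAlphaScheduler(),
+ )
+ trainer.train()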
+ """ + + def __init__( + self, + *args, + scheduler: BaseAlphaScheduler | None = None, + time_epsilon: float = 1e-3, + loss_weight_type: str = "scheduler", # "ones" + **kwargs, + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler or LinearAlphaScheduler() + if not (0.0 < time_epsilon < 1.0): + raise ValueError("time_epsilon must be in (0, 1)") + self.time_epsilon = time_epsilon + self.loss_weight_type = loss_weight_type + + def _preprocess_inputs(self, inputs): + pass + + def _postprocess_outputs(self, outputs): + pass + + def _compute_loss_weights( + self, + t: torch.Tensor, + inputs: dict[str, Any], + *args, + **kwargs, + ) -> torch.Tensor: + """Compute loss weights given timestep t and other arguments.""" + b, l = inputs["input_ids"].shape + if self.loss_weight_type == "scheduler": + loss_weights = self.scheduler.weight(t).unsqueeze(1).repeat(1, l) # b, 1 + elif self.loss_weight_type == "ones": + loss_weights = torch.ones_like(inputs["input_ids"]) + else: + raise NotImplementedError + return loss_weights + + @torch.no_grad() + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + if prediction_loss_only: + return (loss.detach(), None, None) + + logits = getattr(outputs, "logits", outputs) + if isinstance(logits, torch.Tensor): + logits = logits.detach().contiguous() + + labels = inputs.get("labels") + if isinstance(labels, torch.Tensor): + labels = labels.detach().contiguous() + + return (loss.detach(), logits, labels) + + def compute_loss( + self, + model: transformers.PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs: bool = False, + **kwargs, + ): + assert self.processing_class.padding_side == "right" + self._preprocess_inputs(inputs) + input_ids, labels, attention_mask = ( + inputs["input_ids"], + inputs["labels"], + inputs.get("attention_mask", None), + ) + b, l = input_ids.shape + + # === 1. Sample diffusion timesteps === + # Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0. + # The scheduler defines the masking rate α(t); we convert it to a masking probability p_mask = 1 - α(t). + t = self.time_epsilon + (1 - self.time_epsilon) * torch.rand( + b, device=input_ids.device + ) + p_mask = 1 - self.scheduler(t).unsqueeze(1).expand(b, l) + + # === 2. Apply stochastic masking === + # Tokens are masked independently according to p_mask(t). + # Positions with label = -100 are excluded (ignored in loss). + masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & ( + labels != -100 + ) + # Replace masked tokens with the special [MASK] token. + noised_input_ids = torch.where( + masked_indices, self.processing_class.mask_token_id, input_ids + ) + + # === 3. Forward pass through the model === + # The model predicts clean tokens given noised inputs. + outputs = model(input_ids=noised_input_ids, attention_mask=attention_mask) + self._postprocess_outputs(outputs) + logits = outputs.logits + + # === 4. Handle degenerate cases (no tokens masked) === + # If no positions were masked, return a zero loss to keep gradients valid. + # This step is necessary for Deepspeed Zero-{2,3} + if not masked_indices.any(): + return ( + (logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0 + ) + + # === 5. Compute per-token loss weights === + # Depending on the configuration, weights may depend on timestep t + # (e.g., scheduler-based) or be uniform (ones). 
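+ # For example, the default LinearAlphaScheduler has alpha(t) = 1 - t, so the
+ # scheduler-based weight w(t) = -alpha'(t) / (1 - alpha(t)) reduces to roughly 1 / t,
+ # up-weighting samples drawn at small t (lightly masked sequences).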
+ loss_weights = self._compute_loss_weights( + t=t, inputs=inputs, masked_indices=masked_indices + ) + + # === 6. Compute weighted cross-entropy === + # Only masked tokens contribute to the loss. + assert (input_ids[masked_indices] == labels[masked_indices]).all() + token_loss = F.cross_entropy( + logits[masked_indices], input_ids[masked_indices], reduction="none" + ) + token_loss = token_loss * loss_weights[masked_indices] + + # === 7. Normalize loss per effective token length === + # Normalize each sequence’s contribution by its number of valid tokens, + # then average over the batch for stability across variable-length inputs. + effective_lengths = torch.sum(labels != -100, dim=1, keepdim=True).expand(b, l) + loss = torch.sum(token_loss / effective_lengths[masked_indices]) / b + + # === 8. Return final loss (and optionally model outputs) === + return (loss, outputs) if return_outputs else loss diff --git a/dllm/dllm/data/__init__.py b/dllm/dllm/data/__init__.py new file mode 100644 index 0000000..3d994ba --- /dev/null +++ b/dllm/dllm/data/__init__.py @@ -0,0 +1 @@ +from .utils import load_sft_dataset, load_pt_dataset diff --git a/dllm/dllm/data/alpaca.py b/dllm/dllm/data/alpaca.py new file mode 100644 index 0000000..8dfe466 --- /dev/null +++ b/dllm/dllm/data/alpaca.py @@ -0,0 +1,63 @@ +from typing import Optional +from datasets import load_dataset, DatasetDict + + +def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str: + """Construct a clean text prompt from Alpaca fields. + + We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:") + in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes + those tags from the prompt it returns. + """ + instruction = (instruction or "").strip() + input_text = (input_text or "").strip() + + if input_text: + # Keep instruction and input separated by a blank line for readability. + return f"{instruction}\n\n{input_text}" + else: + return instruction + + +def load_dataset_alpaca(dataset_name_or_path: str) -> DatasetDict: + """Load the Alpaca dataset (tatsu-lab/alpaca) and expose unified fields. + + Returns a `DatasetDict` where each split contains: + - prompt: Combined instruction (+ optional input), with clean formatting + - response: The target output (model answer) + + Parameters + ---------- + dataset_name_or_path : str + Usually "tatsu-lab/alpaca" or a local path. 
+ """ + dataset = load_dataset(dataset_name_or_path) + + def map_fn(example): + prompt = _build_alpaca_prompt( + example.get("instruction", ""), example.get("input", "") + ) + response = (example.get("output", "") or "").strip() + return { + "messages": [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, + ] + } + + dataset = dataset.map( + map_fn, remove_columns=dataset["train"].column_names, num_proc=4 + ) + # make train test split + dataset = dataset["train"].train_test_split(test_size=0.1, seed=42) + return dataset + + +if __name__ == "__main__": + from dllm.utils import resolve_with_base_env + + dataset_name_or_path = resolve_with_base_env( + "tatsu-lab/alpaca", "BASE_DATASETS_DIR" + ) + dataset = load_dataset_alpaca(dataset_name_or_path) + breakpoint() diff --git a/dllm/dllm/data/opc.py b/dllm/dllm/data/opc.py new file mode 100644 index 0000000..633b7bb --- /dev/null +++ b/dllm/dllm/data/opc.py @@ -0,0 +1,133 @@ +from typing import Optional, Text, List, Dict +from datasets import ( + load_dataset, + get_dataset_config_names, + concatenate_datasets, + DatasetDict, + Dataset, + IterableDatasetDict, +) +from dllm.data.utils import ( + _merge_datasetdicts, + _merge_iterabledatasetdicts, + _ensure_datasetdict, + _ensure_iterabledatasetdict, + _ensure_datasetdict, +) + + +def load_dataset_opc_sft( + dataset_name_or_path: str, name: str | None = None, lang: str | None = None +) -> DatasetDict: + """ + Load OpenCoder OPC SFT dataset(s) and produce a DatasetDict with a train/test split. + - If `name` is provided: load that specific config. + - If `name` is None: load *all* available configs and concatenate them. + """ + + def _map_to_messages(ds: Dataset) -> Dataset: + def map_fn(example): + return { + "messages": [ + {"role": "user", "content": example["instruction"]}, + {"role": "assistant", "content": example["output"]}, + ] + } + + # Remove all original columns after mapping + remove_cols = ds.column_names + return ds.map(map_fn, remove_columns=remove_cols, num_proc=4) + + def _load_one_config(dataset_name_or_path: str, cfg_name: str) -> Dataset: + ds = load_dataset(dataset_name_or_path, cfg_name, split="train") + return _map_to_messages(ds) + + if name is not None: + train_ds = _load_one_config(dataset_name_or_path, name) + else: + # Enumerate and load all configs, then concatenate + cfgs: list[str] = get_dataset_config_names(dataset_name_or_path) + if not cfgs: + raise ValueError(f"No configs found for dataset: {dataset_name_or_path}") + parts = [_load_one_config(dataset_name_or_path, c) for c in cfgs] + train_ds = concatenate_datasets(parts) + + # Final split + ds_dict = train_ds.train_test_split(test_size=0.1, seed=42) + if lang is not None: + ds_dict = ds_dict.filter(lambda row: lang in row["messages"][1]["content"]) + + return DatasetDict(ds_dict) + + +def load_dataset_opc_annealing( + dataset_name_or_path: str, + name: str | None = None, + lang: str | None = None, + streaming: bool = True, +) -> DatasetDict: + def _load_one_config(_name): + ds = load_dataset( + dataset_name_or_path, _name, split="train", streaming=streaming + ) + if lang: + if _name in ["synthetic_code_snippet", "algorithmic_corpus"]: + ds = ds.filter(lambda row: row["lang"] == lang) + elif _name in ["synthetic_qa"]: + ds = ds.filter(lambda row: row["program_lang"] == lang) + else: + raise NotImplementedError + # return IterableDatasetDict({"train": ds}) + if streaming: + return _ensure_iterabledatasetdict(ds) + return _ensure_datasetdict(ds) + + if name is not None: + return 
_load_one_config(name) + + if streaming: + parts = [ + _load_one_config(name) + for name in get_dataset_config_names(dataset_name_or_path) + ] + merged = parts[0] + for p in parts[1:]: + merged = _merge_iterabledatasetdicts(merged, p) + return merged + else: + parts = [ + _load_one_config(name) + for name in get_dataset_config_names(dataset_name_or_path) + ] + if len(parts) == 1: + return _ensure_datasetdict(parts[0]) + merged = parts[0] + for p in parts[1:]: + merged = _merge_datasetdicts(merged, p) + return _ensure_datasetdict(merged) + + +if __name__ == "__main__": + from dllm.utils import resolve_with_base_env + + dataset_name_or_path = resolve_with_base_env( + "OpenCoder-LLM/opc-sft-stage1", "BASE_DATASETS_DIR" + ) + # If you want a specific config: + dataset_edu = load_dataset_opc_sft(dataset_name_or_path, "realuser_instruct") + # Otherwise, all configs concatenated: + dataset_all = load_dataset_opc_sft(dataset_name_or_path, None) + dataset_all_python = load_dataset_opc_sft(dataset_name_or_path, None, "python") + breakpoint() + + # streaming = True + # dataset_name_or_path = resolve_with_base_env( + # "OpenCoder-LLM/opc-annealing-corpus", "BASE_DATASETS_DIR" + # ) + # # If you want a specific config: + # dataset_alg_all = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus") + # dataset_alg_python = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus", "python") + # # Otherwise, all configs concatenated: + # dataset_all_python = load_dataset_opc_annealing(dataset_name_or_path, None, "python") + # dataset_all_all = load_dataset_opc_annealing(dataset_name_or_path, None) + # breakpoint() diff --git a/dllm/dllm/data/ultrachat.py b/dllm/dllm/data/ultrachat.py new file mode 100644 index 0000000..badc6ed --- /dev/null +++ b/dllm/dllm/data/ultrachat.py @@ -0,0 +1,108 @@ +from typing import Optional, List, Dict +from datasets import load_dataset, DatasetDict + + +def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None: + """ + Given a list of chat messages like: + [{"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}, + ...] + return a dict with the first user/assistant exchange as: + {"prompt": , "response": } + If no valid first turn exists, return None. + """ + if not isinstance(messages, list) or len(messages) < 2: + return None + + # Find the first user message and the first assistant *after* that user msg + # (Most entries start as [user, assistant, ...], but we guard anyway.) + user_idx = None + for i, m in enumerate(messages): + if ( + isinstance(m, dict) + and m.get("role") == "user" + and isinstance(m.get("content"), str) + ): + user_idx = i + break + if user_idx is None: + return None + + # Find first assistant after that user + for j in range(user_idx + 1, len(messages)): + m = messages[j] + if ( + isinstance(m, dict) + and m.get("role") == "assistant" + and isinstance(m.get("content"), str) + ): + user_text = messages[user_idx]["content"].strip() + assistant_text = m["content"].strip() + if user_text and assistant_text: + return {"prompt": user_text, "response": assistant_text} + return None + return None + + +def load_dataset_ultrachat(dataset_name_or_path: str) -> DatasetDict: + """ + Load the UltraChat 200k dataset (HuggingFaceH4/ultrachat_200k) and keep only the *first turn* + (first user message and the assistant reply). 
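+ Any later turns in each conversation are discarded.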
+ + Returns a `DatasetDict` where each split contains: + - prompt: first user message content + - response: first assistant reply content + + Parameters + ---------- + dataset_name_or_path : str + Typically "HuggingFaceH4/ultrachat_200k" or a local path. + data_dir : Optional[str] + Optional subdirectory (for local paths). + """ + dataset = load_dataset(dataset_name_or_path) + + # We only keep examples that have a valid first (user, assistant) turn. + def has_first_turn(example): + messages = example.get("messages") + return _extract_first_turn(messages) is not None + + dataset = dataset.filter(has_first_turn, num_proc=4) + + def map_fn(example): + first = _extract_first_turn(example["messages"]) + # Fallbacks for robustness (shouldn't be hit after filter, but just in case) + if first is None: + first = {"prompt": (example.get("prompt") or "").strip(), "response": ""} + return {"prompt": first["prompt"], "response": first["response"]} + + # Remove original columns for a clean schema (infer from any available split) + cols_to_remove = None + for split_name in dataset.keys(): + cols_to_remove = dataset[split_name].column_names + break + + dataset = dataset.map(map_fn, remove_columns=cols_to_remove, num_proc=4) + dataset = DatasetDict( + { + new: dataset[old] + for old, new in { + "train_sft": "train", + "test_sft": "test", + }.items() + if old in dataset + } + ) + return dataset + + +if __name__ == "__main__": + # Mirrors the style from your previous loaders: resolve path via env helper if available. + from dllm.utils import resolve_with_base_env + + dataset_name_or_path = resolve_with_base_env( + "HuggingFaceH4/ultrachat_200k", "BASE_DATASETS_DIR" + ) + dataset = load_dataset_ultrachat(dataset_name_or_path) + breakpoint() diff --git a/dllm/dllm/data/utils.py b/dllm/dllm/data/utils.py new file mode 100644 index 0000000..9659cb7 --- /dev/null +++ b/dllm/dllm/data/utils.py @@ -0,0 +1,377 @@ +from datasets import ( + Dataset, + DatasetDict, + IterableDatasetDict, + IterableDataset, + load_dataset, + load_from_disk, +) + +from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger + + +logger = get_default_logger(__name__) + + +def load_sft_dataset( + dataset_args: str, load_preprocessed_data: bool = False +) -> DatasetDict: + """ + Examples of dataset_args: + - "tatsu-lab/alpaca" + - "OpenCoder-LLM/opc-sft-stage2[name:educational_instruct]" + - "tatsu-lab/alpaca[train:5000]" + - "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]" + """ + from dllm.data.alpaca import load_dataset_alpaca + from dllm.data.opc import load_dataset_opc_sft + + specs = [p.strip() for p in dataset_args.split("|") if p.strip()] + all_parts = [] + + for raw in specs: + dataset_name_or_path, kvs = parse_spec(raw) + + dataset_name_or_path = resolve_with_base_env( + dataset_name_or_path, "BASE_DATASETS_DIR" + ) + + if load_preprocessed_data: + logger.info("Load preprocessed data from disk.") + ds = load_from_disk(dataset_name_or_path) + # Implement your customized dataset here + elif _match(dataset_name_or_path, "tatsu-lab/alpaca"): + ds = load_dataset_alpaca(dataset_name_or_path) + elif _match(dataset_name_or_path, "allenai/tulu-3-sft-mixture"): + ds = load_dataset(dataset_name_or_path) + ds = ds["train"].train_test_split(test_size=0.1, seed=42) + elif _match(dataset_name_or_path, "HuggingFaceTB/smoltalk"): + name = kvs.pop("name", "all") + ds = load_dataset(dataset_name_or_path, name=name) + elif _match(dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage1") or _match( + 
dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage2" + ): + name = kvs.pop("name", None) + lang = kvs.pop("lang", None) + ds = load_dataset_opc_sft(dataset_name_or_path, name=name, lang=lang) + elif _match(dataset_name_or_path, "HuggingFaceH4/ultrachat_200k"): + ds = load_dataset(dataset_name_or_path) + ds = DatasetDict({"train": ds["train_sft"], "test": ds["test_sft"]}) + else: + ds = load_dataset(dataset_name_or_path) + + # Normalize to DatasetDict and apply per-split limits + ds = _ensure_datasetdict(ds) + ds = _truncate_dataset(ds, kvs) + all_parts.append(ds) + + # If only one part, return as DatasetDict + if len(all_parts) == 1: + return _ensure_datasetdict(all_parts[0]) + + # Merge all parts into a single DatasetDict + merged = all_parts[0] + for part in all_parts[1:]: + merged = _merge_datasetdicts(merged, part) + return _ensure_datasetdict(merged) + + +def load_pt_dataset( + dataset_args: str, streaming: bool = True, load_preprocessed_data: bool = False +) -> DatasetDict | IterableDatasetDict: + """ + Examples of dataset_args: + - "mlfoundations/dclm-baseline-1.0" + - "OpenCoder-LLM/opc-fineweb-code-corpus" + - "OpenCoder-LLM/opc-fineweb-math-corpus" + - "OpenCoder-LLM/opc-annealing-corpus[lang:python]" + - "wikitext[name:wikitext-103-v1}]" + """ + from dllm.data.opc import load_dataset_opc_annealing + + specs = [p.strip() for p in dataset_args.split("|") if p.strip()] + if not specs: + raise ValueError("Empty dataset_args for load_pt_dataset.") + + # ---------- Shared loader (only differs by streaming flag) ---------- + def _load_base_dataset( + raw: str, *, streaming: bool + ) -> tuple[DatasetDict | IterableDatasetDict, dict, str]: + """ + Returns: (base, kvs, dataset_name_or_path) + - Pops 'name' from kvs when applicable (e.g., wikitext). + - Applies identical matching logic for both streaming/non-streaming. + """ + dataset_name_or_path, kvs = parse_spec(raw) + dataset_name_or_path = resolve_with_base_env( + dataset_name_or_path, "BASE_DATASETS_DIR" + ) + name = kvs.pop("name", None) + + if load_preprocessed_data: + base = load_from_disk(dataset_name_or_path) + elif _match(dataset_name_or_path, ["OpenCoder-LLM/opc-annealing-corpus"]): + lang = kvs.pop("lang", None) + base = load_dataset_opc_annealing( + dataset_name_or_path, name=name, lang=lang, streaming=streaming + ) + else: + base = load_dataset(dataset_name_or_path, name=name, streaming=streaming) + + return base, kvs, dataset_name_or_path + + # ---------- Streaming path ---------- + def _load_one_streaming_spec(raw: str) -> IterableDatasetDict: + base, kvs, dataset_name_or_path = _load_base_dataset(raw, streaming=True) + + split_names = list(base.keys()) + single_split = len(split_names) == 1 + single_split_name = split_names[0] if single_split else None + + n_train = kvs.get("train") + n_test = kvs.get("test") + + if (n_train is not None) or (n_test is not None): + if (n_train is not None) and (n_test is not None): + if single_split: + stream = base[single_split_name] + head = stream.take(n_train + n_test) + test = head.take(n_test) + train = head.skip(n_test).take(n_train) + return IterableDatasetDict({"train": train, "test": test}) + else: + if "train" not in base or "test" not in base: + raise ValueError( + f"{dataset_name_or_path}: require 'train' and 'test' splits for train+test limits." 
+ ) + train = base["train"].take(n_train) + test = base["test"].take(n_test) + return IterableDatasetDict({"train": train, "test": test}) + + if n_train is not None: + if single_split: + train = base[single_split_name].take(n_train) + else: + if "train" not in base: + raise ValueError( + f"{dataset_name_or_path}: missing 'train' split for train limit." + ) + train = base["train"].take(n_train) + return IterableDatasetDict({"train": train}) + + if n_test is not None: + if single_split: + test = base[single_split_name].take(n_test) + else: + if "test" not in base: + raise ValueError( + f"{dataset_name_or_path}: missing 'test' split for test limit." + ) + test = base["test"].take(n_test) + return IterableDatasetDict({"test": test}) + + return base # already an IterableDatasetDict + + # ---------- Non-streaming path (mirror load_sft_dataset; NO shuffle) ---------- + def _load_one_nonstreaming_spec(raw: str) -> DatasetDict: + base, kvs, _ = _load_base_dataset(raw, streaming=False) + ds = _ensure_datasetdict(base) # normalize + ds = _truncate_dataset(ds, kvs) # apply limits (train/test/...) + return ds + + # ---------- Load & Merge ---------- + if streaming: + logger.info("Loading dataset in streaming mode.") + parts = [_load_one_streaming_spec(raw) for raw in specs] + merged = parts[0] + for p in parts[1:]: + merged = _merge_iterabledatasetdicts(merged, p) + # repeat streaming dataset infinitely + merged = IterableDatasetDict( + {k: (v.repeat(None) if k == "train" else v) for k, v in merged.items()} + ) + return merged + else: + logger.info("Loading dataset in non-streaming mode.") + parts = [_load_one_nonstreaming_spec(raw) for raw in specs] + if len(parts) == 1: + return _ensure_datasetdict(parts[0]) + merged = parts[0] + for p in parts[1:]: + merged = _merge_datasetdicts(merged, p) + return _ensure_datasetdict(merged) + + +def _truncate_split(split_data, n: int): + if n is None: + return split_data + try: + if hasattr(split_data, "select"): + # Hugging Face Dataset path + total = getattr(split_data, "num_rows", None) + if total is None: + # some Dataset types expose len(...) + total = len(split_data) + idx = list(range(min(n, total))) + return split_data.select(idx) + except Exception: + pass + try: + return split_data[:n] + except Exception: + # Last resort: iterate + return type(split_data)(item for i, item in enumerate(split_data) if i < n) + + +def _truncate_dataset(ds, limits: dict): + """ + Ensure and return a DatasetDict, truncating splits mentioned in `limits`. + """ + ds = _ensure_datasetdict(ds) # normalize first + out = {} + for split, data in ds.items(): + n = limits.get(split, None) + out[split] = _truncate_split(data, n) if n is not None else data + return DatasetDict(out) + + +def _concat_splits(a, b): + """ + Concatenate two split objects (prefer 🤗 datasets). + """ + if a is b: + return a + if a is None: + return b + if b is None: + return a + + # Prefer datasets' concatenate_datasets when both are Datasets + try: + from datasets import concatenate_datasets + + if isinstance(a, Dataset) and isinstance(b, Dataset): + return concatenate_datasets([a, b]) + except Exception: + pass + + # Fallbacks + try: + return a + b + except Exception: + pass + try: + return type(a)(list(a) + list(b)) + except Exception: + pass + + raise TypeError( + f"Cannot concatenate split objects of types {type(a)} and {type(b)}" + ) + + +def _merge_datasetdicts(d1, d2): + """ + Merge two DatasetDict-like mappings by concatenating splits present in either. + Always returns a DatasetDict. 
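+ Splits present in only one input are passed through unchanged; splits present in both are concatenated via `_concat_splits`.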
+ """ + d1 = _ensure_datasetdict(d1) + d2 = _ensure_datasetdict(d2) + all_splits = set(d1.keys()) | set(d2.keys()) + out = {} + for split in all_splits: + a = d1.get(split, None) + b = d2.get(split, None) + if a is None: + out[split] = b + elif b is None: + out[split] = a + else: + out[split] = _concat_splits(a, b) + return DatasetDict(out) + + +def _ensure_datasetdict(ds): + """ + Normalize various loader outputs into a DatasetDict. + - If loader returns a DatasetDict, return as is. + - If loader returns a mapping (e.g., dict of splits), wrap into DatasetDict. + - If loader returns a single Dataset/list/etc., assume it's 'train'. + """ + if isinstance(ds, DatasetDict): + return ds + if isinstance(ds, dict): + # Try to convert each split value to a Dataset if they aren't already. + # If they are already Datasets, DatasetDict will accept them directly. + return DatasetDict(ds) + # Single split -> assume train + return DatasetDict({"train": ds}) + + +def _match(name: str, needle) -> bool: + """ + Returns True if `name` matches any of the provided needles. + Accepts a single string or a list/tuple of strings. + Match condition: name endswith(needle) or needle in name. + """ + if isinstance(needle, (list, tuple)): + return any(name.endswith(n) or n in name for n in needle) + return name.endswith(needle) or needle in name + + +def _concat_iterable_datasets(parts: list[IterableDataset]) -> IterableDataset: + """ + Concatenate IterableDatasets sequentially without materialization. + Preserves streaming nature; supports downstream .take()/.skip()/.shuffle(). + """ + if not parts: + raise ValueError("No IterableDatasets to concatenate.") + # Try to reuse features from the first dataset when available + features = getattr(parts[0], "features", None) + + def _gen(): + for ds in parts: + yield from ds + + return IterableDataset.from_generator(_gen, features=features) + + +def _ensure_iterabledatasetdict(obj) -> IterableDatasetDict: + if isinstance(obj, IterableDatasetDict): + return obj + if isinstance(obj, dict): + return IterableDatasetDict(obj) + # Single stream -> assume train + return IterableDatasetDict({"train": obj}) + + +def _merge_iterabledatasetdicts( + d1: IterableDatasetDict, d2: IterableDatasetDict +) -> IterableDatasetDict: + """ + Merge by concatenating any overlapping splits (streaming-safe). + """ + d1 = _ensure_iterabledatasetdict(d1) + d2 = _ensure_iterabledatasetdict(d2) + all_splits = set(d1.keys()) | set(d2.keys()) + out = {} + for split in all_splits: + a = d1.get(split, None) + b = d2.get(split, None) + if a is None: + out[split] = b + elif b is None: + out[split] = a + else: + out[split] = _concat_iterable_datasets([a, b]) + return IterableDatasetDict(out) + + +def _truncate_stream(ds: IterableDataset, n: int | None) -> IterableDataset: + if n is None: + return ds + return ds.take(n) + + +if __name__ == "__main__": + breakpoint() diff --git a/dllm/dllm/pipelines/__init__.py b/dllm/dllm/pipelines/__init__.py new file mode 100644 index 0000000..a483d6d --- /dev/null +++ b/dllm/dllm/pipelines/__init__.py @@ -0,0 +1 @@ +from . 
import llada, dream, rnd, editflow diff --git a/dllm/dllm/pipelines/bert/eval.py b/dllm/dllm/pipelines/bert/eval.py new file mode 100644 index 0000000..36920af --- /dev/null +++ b/dllm/dllm/pipelines/bert/eval.py @@ -0,0 +1,362 @@ +""" +accelerate launch \ + --num_processes 2 \ + dllm/pipelines/bert/eval.py \ + --tasks gsm8k \ + --batch_size 1 \ + --model bert \ + --device cuda \ + --num_fewshot 8 \ + --model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" +""" + +from types import SimpleNamespace +from dataclasses import dataclass + +import accelerate +import torch +import torch.nn.functional as F +from datasets import Dataset +from tqdm import tqdm +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from lm_eval.models.utils import get_dtype + +import dllm +from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig + + +@dataclass +class BERTEvalConfig(LLaDAGeneratorConfig): + max_new_tokens: int = 128 + max_length: int = 512 + steps: int = 128 + block_length: int = 128 + + pretrained: str = "" + dtype: str | torch.dtype = "auto" + batch_size: int = 32 + mc_num: int = 128 + is_check_greedy: bool = True + device: str = "cuda" + + +@register_model("bert") +class BERTEvalHarness(LM): + def __init__( + self, + config: BERTEvalConfig | None = None, + **kwargs, + ): + super().__init__() + + # Initialize config if not provided + if config is None: + config = BERTEvalConfig() + + # Pull args from config, allow kwargs to override + pretrained = kwargs.get("pretrained", config.pretrained) + dtype = kwargs.get("dtype", config.dtype) + batch_size = kwargs.get("batch_size", config.batch_size) + mc_num = kwargs.get("mc_num", config.mc_num) + is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy) + device = kwargs.get("device", config.device) + cfg = kwargs.get("cfg", config.cfg_scale) + steps = kwargs.get("steps", config.steps) + max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens) + block_length = kwargs.get("block_length", config.block_length) + max_length = kwargs.get("max_length", config.max_length) + remasking = kwargs.get("remasking", config.remasking) + + accelerator = accelerate.Accelerator() + + # Get GLOBAL rank from torch.distributed (not accelerator) + if torch.distributed.is_initialized(): + self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15) + self._world_size = ( + torch.distributed.get_world_size() + ) # ← GLOBAL world size (16) + else: + self._rank = 0 + self._world_size = 1 + + # Use accelerator for device placement + self.model = dllm.utils.get_model( + SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype)) + ) + self.model.eval() + + if accelerator.num_processes > 1: + # Let accelerator handle device placement + self.model = accelerator.prepare(self.model) + self.device = ( + accelerator.device + ) # ← Accelerator figures out local device correctly + self.accelerator = accelerator + else: + # Single GPU + self.model = self.model.to(device) + self.device = torch.device(device) + self.accelerator = None + + self.tokenizer = dllm.utils.get_tokenizer( + SimpleNamespace(model_name_or_path=pretrained, model=self.model) + ) + + # generation params + self.mask_id = self.tokenizer.mask_token_id + self.batch_size = int(batch_size) + self.max_length = int(max_length) + self.max_new_tokens = int(max_new_tokens) 
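+ # Together with max_new_tokens above, the block_length/steps/cfg/remasking knobs below
+ # are forwarded to LLaDAGenerator.generate in generate_until.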
+ self.block_length = int(block_length) + self.steps = int(steps) + self.cfg = float(cfg) + self.remasking = remasking + self.is_check_greedy = is_check_greedy + + # loglikelihood params + self.mc_num = int(mc_num) + assert mc_num % self.batch_size == 0 + self.sampling_eps = 0.0 + + def apply_chat_template( + self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + return chat_templated + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def _forward_process( + self, batch: torch.Tensor, prompt_index: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + b, l = batch.shape + + target_len = (l - prompt_index.sum()).item() + k = torch.randint(1, target_len + 1, (), device=batch.device) + + x = torch.round( + torch.linspace( + float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device + ) + ).long() + x = ((x - 1) % target_len) + 1 + assert x.min() >= 1 and x.max() <= target_len + + indices = torch.arange(target_len, device=batch.device).repeat(b, 1) + is_mask = indices < x.unsqueeze(1) + + for i in range(b): + is_mask[i] = is_mask[i][torch.randperm(target_len)] + + is_mask = torch.cat( + ( + torch.zeros( + b, prompt_index.sum(), dtype=torch.bool, device=batch.device + ), + is_mask, + ), + dim=1, + ) + + noisy_batch = torch.where(is_mask, self.mask_id, batch) + + return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l) + + @torch.no_grad() + def get_logits( + self, batch: torch.Tensor, prompt_index: torch.Tensor + ) -> torch.Tensor: + if self.cfg > 0.0: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + logits = self.model(batch).logits + + if self.cfg > 0.0: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (self.cfg + 1) * (logits - un_logits) + return logits[:, : batch.shape[1]] + + @torch.no_grad() + def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float: + seq = torch.concatenate([prefix, target])[None, :] + seq = seq.repeat((self.batch_size, 1)).to(self.device) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + + loss_acc = [] + for _ in range(self.mc_num // self.batch_size): + perturbed_seq, p_mask = self._forward_process(seq, prompt_index) + + mask_indices = perturbed_seq == self.mask_id + + logits = self.get_logits(perturbed_seq, prompt_index) + + loss = ( + F.cross_entropy( + logits[mask_indices], seq[mask_indices], reduction="none" + ) + / p_mask[mask_indices] + ) + loss = loss.sum() / self.batch_size + loss_acc.append(loss.item()) + + return -sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def suffix_greedy_prediction( + self, prefix: torch.Tensor, target: torch.Tensor + ) -> bool: + if not self.is_check_greedy: + return False + + seq = torch.full( + (1, len(prefix) + len(target)), self.mask_id, device=self.device + ) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + prefix, target = 
prefix.to(self.device), target.to(self.device) + seq[0, : len(prefix)] = prefix + + for i in range(len(target)): + mask_index = seq == self.mask_id + logits = self.get_logits(seq, prompt_index)[mask_index] + x0 = torch.argmax(logits, dim=-1) + + p = torch.softmax(logits.to(torch.float32), dim=-1) + confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze( + dim=-1 + ) + _, index = torch.sort(confidence, descending=True) + x0[index[1:]] = self.mask_id + seq[mask_index] = x0.clone() + correct = target == seq[0, len(prefix) :] + correct = torch.all(correct) + return correct + + def _encode_pair( + self, context: str, continuation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer(context + continuation)["input_ids"] + context_enc = self.tokenizer(context)["input_ids"] + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] + + assert max(prompt_len) <= 4096 + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + + ll = self.get_loglikelihood(prefix, target) + + is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) + + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + torch.cuda.empty_cache() + return out + + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: + raise NotImplementedError + + def generate_until(self, requests: list[Instance]): + def _tokenize(e): + return { + "question": self.tokenizer(e["question"])["input_ids"], + "question_text": e["question"], + "until": e["until"], + } + + ds = [ + {"question": req.args[0], "until": req.args[1]["until"]} for req in requests + ] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer) + + for elem in tqdm(ds, desc="Generating..."): + prompt = [elem["question"][1:-1].to(self.device)] + stop_tokens = elem["until"] + generated_ids = generator.generate( + inputs=prompt, + steps=self.steps, + max_new_tokens=self.max_new_tokens, + block_length=self.block_length, + temperature=0.0, + cfg_scale=self.cfg, + remasking=self.remasking, + ) + generated_answer = self.tokenizer.decode( + generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False + ) + for stop_seq in stop_tokens: + if stop_seq in generated_answer: + generated_answer = generated_answer.split(stop_seq)[0] + + # remove special tokens + generated_answer_ids = self.tokenizer(generated_answer)["input_ids"] + generated_answer = self.tokenizer.decode( + generated_answer_ids, skip_special_tokens=True + ) + out.append(generated_answer) + if self.accelerator is not None: + self.accelerator.wait_for_everyone() + + return out + + +if __name__ == "__main__": +
cli_evaluate() diff --git a/dllm/dllm/pipelines/dream/__init__.py b/dllm/dllm/pipelines/dream/__init__.py new file mode 100644 index 0000000..54818ee --- /dev/null +++ b/dllm/dllm/pipelines/dream/__init__.py @@ -0,0 +1,6 @@ +from . import generator, models, trainer, utils +from .models.modeling_dream import DreamModel +from .models.configuration_dream import DreamConfig +from .models.tokenization_dream import DreamTokenizer +from .generator import DreamGeneratorConfig, DreamGenerator +from .trainer import DreamTrainer diff --git a/dllm/dllm/pipelines/dream/eval.py b/dllm/dllm/pipelines/dream/eval.py new file mode 100644 index 0000000..03066bb --- /dev/null +++ b/dllm/dllm/pipelines/dream/eval.py @@ -0,0 +1,533 @@ +""" +accelerate launch \ + --num_processes 2 \ + dllm/pipelines/dream/eval.py \ + --tasks gsm8k \ + --batch_size 1 \ + --model dream \ + --device cuda + --num_fewshot 0 \ + --model_args "pretrained=Dream-org/Dream-v0-Base-7B,mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true" +""" + +import logging +from types import SimpleNamespace +from dataclasses import dataclass + +import accelerate +import torch +import torch.nn.functional as F +from datasets import Dataset +from tqdm import tqdm +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from lm_eval.models.utils import get_dtype + +import dllm +from dllm.pipelines.dream import DreamGenerator, DreamGeneratorConfig + +eval_logger = logging.getLogger(__name__) + + +@dataclass +class DreamEvalConfig(DreamGeneratorConfig): + top_p: float | None = None + top_k: float | None = None + max_new_tokens: int = 128 + max_length: int = 2048 + steps: int = 128 + temperature: float = 0.0 + alg: str = "entropy" + + pretrained: str = "" + batch_size: int = 1 + device: str = "cuda" + dtype: str | torch.dtype = "auto" + add_bos_token: bool = False + nll_type: str = "mc" + log_type: str = "ftb" + mc_num: int = 128 + classifier_free_guidance: float = 1.0 + sampling_eps: float = 1e-3 + escape_until: bool = False + + +@register_model("dream") +class DreamEvalHarness(LM): + def __init__( + self, + config: DreamEvalConfig | None = None, + **kwargs, + ) -> None: + super().__init__() + + # Initialize config if not provided + if config is None: + config = DreamEvalConfig() + + # Pull args from config, allow kwargs to override + pretrained = kwargs.get("pretrained", config.pretrained) + batch_size = kwargs.get("batch_size", config.batch_size) + device = kwargs.get("device", config.device) + dtype = kwargs.get("dtype", config.dtype) + max_length = kwargs.get("max_length", config.max_length) + add_bos_token = kwargs.get("add_bos_token", config.add_bos_token) + nll_type = kwargs.get("nll_type", config.nll_type) + log_type = kwargs.get("log_type", config.log_type) + mc_num = kwargs.get("mc_num", config.mc_num) + max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens) + classifier_free_guidance = kwargs.get( + "classifier_free_guidance", config.classifier_free_guidance + ) + sampling_eps = kwargs.get("sampling_eps", config.sampling_eps) + steps = kwargs.get("steps", config.steps) + temperature = kwargs.get("temperature", config.temperature) + top_p = kwargs.get("top_p", config.top_p) + top_k = kwargs.get("top_k", config.top_k) + alg = kwargs.get("alg", config.alg) + alg_temp = kwargs.get("alg_temp", config.alg_temp) + escape_until = kwargs.get("escape_until", 
config.escape_until) + + accelerator = accelerate.Accelerator() + + # Get GLOBAL rank from torch.distributed (not accelerator) + if torch.distributed.is_initialized(): + self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15) + self._world_size = ( + torch.distributed.get_world_size() + ) # ← GLOBAL world size (16) + else: + self._rank = 0 + self._world_size = 1 + + # Use accelerator for device placement + self.model = dllm.utils.get_model( + SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype)) + ) + self.model.eval() + + if accelerator.num_processes > 1: + # Let accelerator handle device placement + self.model = accelerator.prepare(self.model) + self.device = ( + accelerator.device + ) # ← Accelerator figures out local device correctly + self.accelerator = accelerator + else: + # Single GPU + self.model = self.model.to(device) + self.device = torch.device(device) + self.accelerator = None + + self.tokenizer = dllm.utils.get_tokenizer( + SimpleNamespace(model_name_or_path=pretrained, model=self.model) + ) + + # generation params + self.mask_id = self.tokenizer.mask_token_id + self.max_length = max_length + self.add_bos_token = add_bos_token + self.batch_size = int(batch_size) + self.max_new_tokens = max_new_tokens + self.steps = steps + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.alg = alg + self.alg_temp = alg_temp + self.escape_until = escape_until + + # loglikelihood params + self.nll_type = nll_type + self.log_type = log_type + self.mc_num = mc_num + self.classifier_free_guidance = classifier_free_guidance + self.sampling_eps = sampling_eps + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def tok_decode( + self, tokens: torch.Tensor | list[int], skip_special_tokens: bool = True + ) -> str: + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def tok_encode(self, text: str, add_special_tokens: bool = True) -> torch.Tensor: + return self.tokenizer( + text, return_tensors="pt", add_special_tokens=add_special_tokens + ).input_ids + + def apply_chat_template( + self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. 
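+
+        Example (illustrative only; the exact rendered prompt depends on the tokenizer's
+        chat template, and `harness` stands for a `DreamEvalHarness` instance):
+
+            history = [
+                {"role": "user", "content": "What is 2 + 2?"},
+                {"role": "assistant", "content": "4"},
+                {"role": "user", "content": "And 3 + 3?"},
+            ]
+            prompt_text = harness.apply_chat_template(history, add_generation_prompt=True)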
+ """ + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + return chat_templated + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def generate_until( + self, requests: list[Instance], disable_tqdm: bool = False + ) -> list[str]: + res = [] + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + generator = DreamGenerator(model=self.model, tokenizer=self.tokenizer) + for batch_idx in range(0, len(requests), self.batch_size): + batch_requests = requests[batch_idx : batch_idx + self.batch_size] + contexts, gen_args = zip(*[req.arguments for req in batch_requests]) + + # ====== BEGIN merged _generate_batch logic ====== + prompts = list(contexts) + if self.add_bos_token: + prompts = [self.tokenizer.bos_token + p for p in prompts] + + # tokenize + prompt_ids = [ + self.tokenizer( + p, return_tensors="pt", padding=False + ).input_ids.squeeze() + for p in prompts + ] + prompt_lens = [len(p_id) for p_id in prompt_ids] + + if max(prompt_lens) > self.max_length - self.max_new_tokens: + cutoff_len = self.max_length - self.max_new_tokens + eval_logger.warning( + f"Prompt length {max(prompt_lens)} exceeds {cutoff_len}, cutoff on the left side" + ) + # ✅ Correct: trim from the left side (keep the last cutoff_len tokens) + prompt_ids = [p_id[-cutoff_len:] for p_id in prompt_ids] + + # generation + generation_ids = generator.generate( + max_new_tokens=self.max_new_tokens, + inputs=prompt_ids, + steps=self.steps, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + alg=self.alg, + alg_temp=self.alg_temp, + output_history=False, + return_dict_in_generate=False, + ) + # decode and cleanup + cleaned_generation_ids = [ + ( + seq[seq.ne(self.tokenizer.eos_token_id).float().argmax().long() :] + if (seq != self.tokenizer.eos_token_id).any() + else seq[-1:] + ) + for seq in generation_ids + ] + truncated_generation_ids = [ + seq[prompt_lens[i] :] for i, seq in enumerate(cleaned_generation_ids) + ] + responses = [ + g.lstrip("<|endoftext|>").split(self.tokenizer.eos_token, 1)[0] + for g in self.tokenizer.batch_decode(truncated_generation_ids) + ] + + # ====== END merged _generate_batch logic ====== + + # handle "until" truncation + if not self.escape_until: + for i, r in enumerate(responses): + for s in gen_args[0]["until"]: + r = r.split(s)[0] + responses[i] = r + + res.extend(responses) + pbar.update(len(contexts)) + + return res + + def _forward_process( + self, batch: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + b, l = batch.shape + # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1 + u0 = torch.rand(1, device=batch.device, dtype=torch.float32) + indices = torch.arange(b, device=batch.device).float() + t = (u0 + indices / b) % 1 + + p_mask = (1 - self.sampling_eps) * t + self.sampling_eps + + p_mask = p_mask[:, None].repeat(1, l) + + mask_indices = torch.rand((b, l), device=batch.device) < p_mask + # always unmask bos and eos + mask_indices[:, 0] = False + mask_indices[:, -1] = False + + noisy_batch = torch.where(mask_indices, self.mask_id, batch) + return noisy_batch, p_mask + + @torch.no_grad() + def get_logits( + self, batch: torch.Tensor, prompt_index: torch.Tensor + ) -> torch.Tensor: + """ + prompt_index : 1D bool tensor, length=batch.shape[1] + """ + if self.classifier_free_guidance 
> 1.0: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + input = batch + + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + logits = self.model(input).logits + # since bos always unmask, the first logits will not be used + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + if self.classifier_free_guidance > 1.0: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + self.cfg * (logits - un_logits) + return logits[:, : batch.shape[1]] + + @torch.no_grad() + def _eval_target_nll_mc( + self, prefix: torch.Tensor | None, target: torch.Tensor + ) -> float: + if prefix is None: + seq = target[None, :] + else: + seq = torch.concatenate([prefix, target])[None, :] + seq = seq.repeat((self.batch_size, 1)).to(self.device) + + if self.log_type == "ftb": + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + else: + prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix) + + loss_acc = [] + for _ in range(max(self.mc_num // self.batch_size, 1)): + perturbed_seq = seq.clone() + # eval_logger.info("before noising") + perturbed_seq_, p_mask = self._forward_process(seq) + # eval_logger.info("end noising") + if self.log_type == "ftb": + perturbed_seq[:, -len(target) :] = perturbed_seq_[:, -len(target) :] + elif self.log_type == "btf": + perturbed_seq[:, : len(prefix)] = perturbed_seq_[:, : len(prefix)] + elif self.log_type == "union": + perturbed_seq = perturbed_seq_ + else: + raise NotImplementedError(self.log_type) + + mask_indices = perturbed_seq == self.mask_id + logits = self.get_logits(perturbed_seq, prompt_index) + loss = ( + F.cross_entropy( + logits[mask_indices], seq[mask_indices], reduction="none" + ) + / p_mask[mask_indices] + ) + loss = loss.sum() / self.batch_size + loss_acc.append(loss.item()) + + return sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def _eval_target_nll_ar(self, prefix: torch.Tensor, target: torch.Tensor) -> float: + prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2 + assert self.log_type in ["ftb", "btf"] + assert self.nll_type in ["ar_ftb", "ar_btf"] + + if self.log_type == "ftb": + prompt_index = ( + torch.arange(prefix.shape[1] + target.shape[1], device=self.device) + < prefix.shape[1] + ) + else: + prompt_index = ( + torch.arange(prefix.shape[1] + target.shape[1], device=self.device) + >= prefix.shape[1] + ) + + if self.log_type == "ftb": + perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2 + else: + perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1 + + mask_index = torch.ones( + (perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool + ) + if self.nll_type == "ar_ftb": + mask_index = torch.triu(mask_index) + else: + mask_index = torch.tril(mask_index) + perturbed_[mask_index] = self.mask_id + if self.log_type == "ftb": + perturbed_seq = torch.cat( + [prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1 + ) + else: + perturbed_seq = torch.cat( + [perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1 + ) + + logits_ = [] + num = ( + len(perturbed_seq) // self.batch_size + if len(perturbed_seq) % self.batch_size == 0 + else len(perturbed_seq) // self.batch_size + 1 + ) + for i in range(num): + end = ( + (i + 1) * self.batch_size + if (i + 1) * self.batch_size < len(perturbed_seq) + else len(perturbed_seq) + ) + 
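+            # One row of `perturbed_seq` per scored position (target tokens for log_type
+            # "ftb", prefix tokens for "btf"); the full matrix may not fit in a single
+            # forward pass, so it is scored in chunks of `self.batch_size`, with `end`
+            # clamping the final chunk to the rows that remain.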
perturbed_seq_ = perturbed_seq[i * self.batch_size : end] + perturbed_seq_ = perturbed_seq_.to(self.device) + if len(perturbed_seq_.shape) == 1: + perturbed_seq_ = perturbed_seq_.unsqueeze(0) + logits = self.get_logits(perturbed_seq_, prompt_index) + logits_.append(logits.cpu()) + logits = torch.cat(logits_, dim=0) + + temp_index = torch.ones( + (perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool + ) + if self.nll_type == "ar_ftb": + temp_index = torch.triu(temp_index, diagonal=1) + else: + temp_index = torch.tril(temp_index, diagonal=-1) + mask_index[temp_index] = False + if self.log_type == "ftb": + logits_index = torch.cat( + [ + torch.zeros( + (perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool + ), + mask_index, + ], + dim=-1, + ) + else: + logits_index = torch.cat( + [ + mask_index, + torch.zeros( + (perturbed_.shape[1], target.shape[1]), dtype=torch.bool + ), + ], + dim=-1, + ) + + if self.log_type == "ftb": + loss = ( + F.cross_entropy(logits[logits_index], target[0], reduction="sum") + .cpu() + .item() + ) + else: + loss = ( + F.cross_entropy(logits[logits_index], prefix[0], reduction="sum") + .cpu() + .item() + ) + return loss + + def _encode_pair( + self, context: str, continuation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.add_bos_token: + context = self.tokenizer.bos_token + context + + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer.encode(context + continuation) + [ + self.tokenizer.eos_token_id + ] + context_enc = self.tokenizer.encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + # by default truncate on the left + cutoff_length = max(len(whole_enc) - self.max_length, 0) + if cutoff_length > 0: + eval_logger.warning( + f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side" + ) + context_remain = context_enc_len - cutoff_length + if context_remain > 0: + context_enc = context_enc[-context_remain:] + else: + eval_logger.warning(f"All context (prompt) is truncated.") + context_enc = "" + continuation_enc = whole_enc[-self.max_length :] + return context_enc, continuation_enc + + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [] + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + # likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py + if self.nll_type == "mc": + ll = -self._eval_target_nll_mc(prefix, target) + if self.log_type == "union": + ll = ll / (len(target) + len(prefix)) + elif self.nll_type == "ar_ftb" or self.nll_type == "ar_btf": + ll = -self._eval_target_nll_ar(prefix, target) + else: + raise NotImplementedError(self.nll_type) + + # TODO: greedy decoding + is_target_greedy_dec = False + + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + return out + + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: + raise NotImplementedError + + +if __name__ == 
"__main__": + cli_evaluate() diff --git a/dllm/dllm/pipelines/dream/generator.py b/dllm/dllm/pipelines/dream/generator.py new file mode 100644 index 0000000..d6475cc --- /dev/null +++ b/dllm/dllm/pipelines/dream/generator.py @@ -0,0 +1,426 @@ +""" +reference: https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/generation_utils.py +""" + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import torch.distributions as dists + +from dllm.utils.generation_utils import get_num_transfer_tokens +from dllm.pipelines.dream.utils import top_p_logits, top_k_logits +from dllm.core.generation.generator import ( + GeneratorOutput, + GeneratorConfig, + BaseGenerator, +) + + +def sample_tokens( + logits: torch.Tensor, + temperature: float = 0.0, + top_p: float | None = None, + top_k: int | None = None, + margin_confidence: bool = False, + neg_entropy: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + + probs = torch.softmax(logits, dim=-1) + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except Exception: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +@dataclass +class DreamGeneratorConfig(GeneratorConfig): + max_new_tokens: int = 20 + max_length: int = ( + None # The max_length is set as input_ids.shape[1] + 20: generation_config.max_length = generation_config.max_length + input_ids_length + ) + steps: int = 512 + eps: float = 1e-3 + alg: str = "origin" + alg_temp: float = 0.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 50 + stochastic_transfer: bool = False + + +@dataclass +class DreamGenerator(BaseGenerator): + @torch.no_grad() + def generate( + self, + inputs: list[torch.Tensor, list], + config: DreamGeneratorConfig | None = None, + generation_tokens_hook_func=lambda step, x, logits: x, + generation_logits_hook_func=lambda step, x, logits: logits, + **kwargs, + ) -> GeneratorOutput | torch.Tensor: + """ + Diffusion-style masked decoding for *generation from inputs*. 
+ (docstring unchanged) + """ + if config is None: + config = DreamGeneratorConfig() + + # ----- pull args from config, allow kwargs to override ----- + max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens) + max_length = kwargs.get("max_length", config.max_length) + steps = kwargs.get("steps", config.steps) + eps = kwargs.get("eps", config.eps) + alg = kwargs.get("alg", config.alg) + alg_temp = kwargs.get("eps", config.alg_temp) + temperature = kwargs.get("temperature", config.temperature) + top_p = kwargs.get("top_p", config.top_p) + top_k = kwargs.get("top_k", config.top_k) + stochastic_transfer = kwargs.get( + "stochastic_transfer", config.stochastic_transfer + ) + # generation_tokens_hook_func = kwargs.get("generation_tokens_hook_func", config.generation_tokens_hook_func) + # generation_logits_hook_func = kwargs.get("generation_logits_hook_func", config.generation_logits_hook_func) + return_dict_in_generate = kwargs.get( + "return_dict_in_generate", config.return_dict_in_generate + ) + + # --- Initialization --- + mask_token_id = self.tokenizer.mask_token_id + eos_token_id = self.tokenizer.eos_token_id + + if isinstance(inputs[0], list): + inputs = [ + torch.as_tensor(p, dtype=torch.long, device=self.model.device) + for p in inputs + ] + prompt_lens = [p.shape[0] for p in inputs] + if max_new_tokens: + max_length = max_new_tokens + max(prompt_lens) + else: + max_new_tokens = max_length - max(prompt_lens) + + B = len(inputs) + T = max_length + x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device) + + seq_length = [] + for i, p in enumerate(inputs): + total_len = prompt_lens[i] + max_new_tokens + seq_length.append(total_len) + start = T - total_len + x[i, start : start + prompt_lens[i]] = p + x[i, start + prompt_lens[i] : T] = mask_token_id + + attention_mask = torch.zeros( + (B, T), dtype=torch.float32, device=self.model.device + ) + for j, L in enumerate(seq_length): + if L > 0: + attention_mask[j, -L:] = 1.0 # Mandate to be left-padding + + if attention_mask is not None and torch.any(attention_mask == 0.0): + pos_id = attention_mask.long().cumsum(-1) - 1 + pos_id.masked_fill_(attention_mask == 0, 1) + else: + pos_id = None + + mask_index = x == mask_token_id + num_transfer_tokens_list = get_num_transfer_tokens( + mask_index=mask_index, + steps=steps, + scheduler=self.scheduler, + stochastic=stochastic_transfer, + ) + effective_steps = num_transfer_tokens_list.size(1) + + # --- Iterative refinement --- + x = generation_tokens_hook_func(None, x, None) + histories = [x.clone()] if return_dict_in_generate else None + for i in range(effective_steps): + mask_index = x == mask_token_id + + logits = self.model(x, attention_mask, pos_id).logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + logits = generation_logits_hook_func(i, x, logits) + + mask_logits = logits[mask_index] + + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, + temperature=temperature, + top_p=top_p, + top_k=top_k, + margin_confidence=True, + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, + temperature=temperature, + top_p=top_p, + top_k=top_k, + neg_entropy=True, + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + + full_confidence = torch.full_like( + x, -torch.inf, device=self.model.device, dtype=logits.dtype + ) + full_confidence[mask_index] = confidence + + for j in 
range(full_confidence.shape[0]): + number_transfer_tokens = num_transfer_tokens_list[j, i] + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk( + full_confidence[j], number_transfer_tokens + ) + else: + fc = full_confidence[j] / alg_temp + fc = F.softmax(fc, dim=-1) + transfer_index = torch.multinomial( + fc, num_samples=number_transfer_tokens + ) + + x_ = torch.full_like(x, mask_token_id, device=self.model.device) + x_[mask_index] = x0.clone() + x[j, transfer_index] = x_[j, transfer_index] + + x = generation_tokens_hook_func(i, x, logits) + if histories is not None: + histories.append(x.clone()) + + if not return_dict_in_generate: + return x + else: + return GeneratorOutput(sequences=x, histories=histories) + + @torch.no_grad() + def infill( + self, + inputs: list[torch.Tensor, list], + config, + generation_tokens_hook_func=lambda step, x, logits: x, + generation_logits_hook_func=lambda step, x, logits: logits, + **kwargs, + ) -> GeneratorOutput | torch.Tensor: + """ + Fill in-place the tokenizer's `` tokens contained in `inputs`. + The whole (right-aligned) canvas is denoised iteratively: at each step, a scheduler + decides how many masked positions to commit, and a confidence rule (`alg`) + selects *which* positions to reveal (MaskGIT-style). Non-mask tokens are never changed. + + High-level: + 1) Build a right-aligned canvas per sample (left side padded with EOS). + 2) Compute a per-sample transfer schedule via `scheduler` and `steps`. + 3) At each step: forward pass → AR-shift logits → score masked positions + via `alg` → choose indices to commit (top-k or soft sampling) → write tokens. + + Notes: + - Right padding uses EOS (serves as pad here). + - Only `[MASK]` positions are updated; original tokens remain intact. + - Logits are AR-shifted to preserve next-token prediction alignment. + + Args: + model: + Mask predictor; returns logits of shape [B, T, V] when called as + `model(x, attention_mask, pos_id)`. + tokenizer: + Must provide `mask_token_id` and `eos_token_id`. + inputs: + List of 1D LongTensors (token ids). Each may contain `` tokens + to be filled; other tokens are treated as fixed context. + scheduler (BaseAlphaScheduler): + Controls how many masks to commit per step (deterministic or stochastic). + generation_tokens_hook_func / generation_logits_hook_func: + Optional hooks to intercept tokens/logits at each step. + output_history (bool): + If True, save intermediate canvases at each step. + return_dict_in_generate (bool): + If True, return `DreamModelOutput(sequences, history)`, else only `[B, T]`. + steps (int): + Total reverse-diffusion steps (quality–speed trade-off). + alg (str): + Confidence rule to rank masked positions: + - "maskgit_plus": softmax probs + - "topk_margin": top1 - top2 margin + - "entropy": negative entropy + alg_temp (float): + Temperature for *confidence-based index sampling* (when > 0, soft selection). + temperature / top_p / top_k: + Token sampling hyperparameters within `sample_tokens`. + stochastic_transfer (bool): + If True, sample the number of transfers per step (Binomial); else use expectation. + + Returns: + DreamModelOutput | torch.LongTensor: + If `return_dict_in_generate=True`, returns + - sequences: `[B, T]` final tokens + - history: optional list of intermediate canvases + Otherwise returns only `[B, T]`. 
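+
+        Example (a sketch only; assumes `model`/`tokenizer` are a loaded Dream checkpoint
+        whose tokenizer defines a mask token, and the prompt/hole sizes are illustrative):
+
+            generator = DreamGenerator(model=model, tokenizer=tokenizer)
+            ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.squeeze(0)
+            holes = torch.full((4,), tokenizer.mask_token_id, dtype=torch.long)
+            canvas = torch.cat([ids, holes])  # fixed context followed by 4 [MASK] slots
+            filled = generator.infill(
+                inputs=[canvas],
+                config=DreamGeneratorConfig(steps=16, alg="entropy"),
+                return_dict_in_generate=False,
+            )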
+ """ + # ----- pull args from config, allow kwargs to override ----- + steps = kwargs.get("steps", config.steps) + eps = kwargs.get("eps", config.eps) + alg = kwargs.get("alg", config.alg) + alg_temp = kwargs.get("eps", config.alg_temp) + temperature = kwargs.get("temperature", config.temperature) + top_p = kwargs.get("top_p", config.top_p) + top_k = kwargs.get("top_k", config.top_k) + stochastic_transfer = kwargs.get( + "stochastic_transfer", config.stochastic_transfer + ) + # generation_tokens_hook_func = kwargs.get("stochastic_transfer", config.generation_tokens_hook_func) + # generation_logits_hook_func = kwargs.get("stochastic_transfer", config.generation_logits_hook_func) + return_dict_in_generate = kwargs.get( + "return_dict_in_generate", config.return_dict_in_generate + ) + + # --- Initialization --- + mask_token_id = self.tokenizer.mask_token_id + eos_token_id = self.tokenizer.eos_token_id + + if isinstance(inputs[0], list): + inputs = [ + torch.as_tensor(p, dtype=torch.long, device=self.model.device) + for p in inputs + ] + + B = len(inputs) + seq_lens = [t.shape[0] for t in inputs] + T = max(seq_lens) + + # Build right-aligned canvas; left side padded with EOS (used as pad) + x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device) + for i, t in enumerate(inputs): + L = seq_lens[i] + x[i, -L:] = t + + # Build 1D attention mask (valid tokens on the right) + attention_mask = torch.zeros((B, T), dtype=torch.bool, device=self.model.device) + for j, L in enumerate(seq_lens): + if L > 0: + attention_mask[j, -L:] = True + + # Expand to pairwise attention if left padding is present + if torch.any(attention_mask == 0.0): + pos_id = attention_mask.long().cumsum(-1) - 1 + pos_id.masked_fill_(attention_mask == 0, 1) + else: + pos_id = None + attention_mask = "full" + + # Precompute per-sample transfer schedule (how many to commit per step) + mask_index = x == mask_token_id + num_transfer_tokens_list = get_num_transfer_tokens( + mask_index=mask_index, + steps=steps, + scheduler=self.scheduler, + stochastic=stochastic_transfer, + ) + effective_steps = num_transfer_tokens_list.size(1) + + # Optional initial token hook + x = generation_tokens_hook_func(None, x, None) + histories = [x.clone()] if return_dict_in_generate else None + for i in range(effective_steps): + mask_index = x == mask_token_id + + # Forward pass, then AR-shift to predict token at position i+1 + logits = self.model(x, attention_mask, pos_id).logits + logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + logits = generation_logits_hook_func(i, x, logits) + + # Logits restricted to current `[MASK]` positions + mask_logits = logits[mask_index] + + # Confidence scoring for masked positions + if alg == "maskgit_plus": + confidence, x0 = sample_tokens( + mask_logits, temperature=temperature, top_p=top_p, top_k=top_k + ) + elif alg == "topk_margin": + confidence, x0 = sample_tokens( + mask_logits, + temperature=temperature, + top_p=top_p, + top_k=top_k, + margin_confidence=True, + ) + elif alg == "entropy": + confidence, x0 = sample_tokens( + mask_logits, + temperature=temperature, + top_p=top_p, + top_k=top_k, + neg_entropy=True, + ) + else: + raise RuntimeError(f"Unknown alg: {alg}") + + # Scatter per-position confidence back to full canvas + full_confidence = torch.full_like( + x, -torch.inf, device=self.model.device, dtype=logits.dtype + ) + full_confidence[mask_index] = confidence + + # Commit the scheduled number of tokens per sample + for j in range(B): + number_transfer_tokens = 
num_transfer_tokens_list[j, i] + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk( + full_confidence[j], number_transfer_tokens + ) + else: + fc = full_confidence[j] / alg_temp + fc = F.softmax(fc, dim=-1) + transfer_index = torch.multinomial( + fc, num_samples=number_transfer_tokens + ) + + # Candidate tokens at masked positions only + x_ = torch.full_like(x, mask_token_id, device=self.model.device) + x_[mask_index] = x0.clone() + x[j, transfer_index] = x_[j, transfer_index] + + # Optional token hook + history logging + x = generation_tokens_hook_func(i, x, logits) + if histories is not None: + histories.append(x.clone()) + + if not return_dict_in_generate: + return x + else: + return GeneratorOutput(sequences=x, histories=histories) diff --git a/dllm/dllm/pipelines/dream/models/__init__.py b/dllm/dllm/pipelines/dream/models/__init__.py new file mode 100644 index 0000000..d7093cb --- /dev/null +++ b/dllm/dllm/pipelines/dream/models/__init__.py @@ -0,0 +1,13 @@ +from .configuration_dream import DreamConfig +from .modeling_dream import DreamModel + +# Register with HuggingFace Auto classes for local usage +try: + from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM + + AutoConfig.register("Dream", DreamConfig) + AutoModel.register(DreamConfig, DreamModel) + AutoModelForMaskedLM.register(DreamConfig, DreamModel) +except ImportError: + # transformers not available or Auto classes not imported + pass diff --git a/dllm/dllm/pipelines/dream/models/configuration_dream.py b/dllm/dllm/pipelines/dream/models/configuration_dream.py new file mode 100644 index 0000000..3497b1d --- /dev/null +++ b/dllm/dllm/pipelines/dream/models/configuration_dream.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
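+#
+# Construction sketch (illustrative; the values shown are the defaults defined in
+# `DreamConfig.__init__` and can be overridden as keyword arguments):
+#
+#     from dllm.pipelines.dream import DreamConfig
+#
+#     config = DreamConfig(
+#         vocab_size=151936,
+#         hidden_size=4096,
+#         num_hidden_layers=32,
+#         mask_token_id=151666,  # token id reserved for [MASK] in diffusion decoding
+#     )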
+"""Dream model configuration""" +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class DreamConfig(PretrainedConfig): + model_type = "Dream" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=False, # cache not used in diffusion + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + mask_token_id=151666, + pad_token_id=151643, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.mask_token_id = mask_token_id + self.pad_token_id = pad_token_id diff --git a/dllm/dllm/pipelines/dream/models/generation_utils.py b/dllm/dllm/pipelines/dream/models/generation_utils.py new file mode 100644 index 0000000..156ddca --- /dev/null +++ b/dllm/dllm/pipelines/dream/models/generation_utils.py @@ -0,0 +1,465 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
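+#
+# Module usage sketch (illustrative; `model` is assumed to be a loaded Dream checkpoint
+# exposing `DreamGenerationMixin.diffusion_generate`, and `tok` its tokenizer):
+#
+#     cfg = DreamGenerationConfig(max_new_tokens=64, steps=64, alg="entropy", temperature=0.2)
+#     input_ids = tok("Question: 2 + 2 = ?\nAnswer:", return_tensors="pt").input_ids
+#     out = model.diffusion_generate(input_ids, generation_config=cfg)  # [B, max_length]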
+ +import warnings +import copy +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.distributions as dists +from torch.nn import functional as F +from transformers import __version__ +from transformers.generation.configuration_utils import ( + GenerationConfig +) +from transformers.utils import ( + ModelOutput, + is_torchdynamo_compiling, + logging, +) + +logger = logging.get_logger(__name__) + + +def top_p_logits(logits, top_p=None): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) + mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) + return logits + +def top_k_logits(logits, top_k=None): + top_k = min(top_k, logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) + return logits + + +def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False): + + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + probs = torch.softmax(logits, dim=-1) + + if temperature > 0: + try: + x0 = dists.Categorical(probs=probs).sample() + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 + + +@dataclass +class DreamModelOutput(ModelOutput): + sequences: torch.LongTensor = None + history: Optional[Tuple[torch.FloatTensor]] = None + + +class DreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + self.temperature: float = kwargs.pop("temperature", 0.0) + self.top_p: Optional[float] = kwargs.pop("top_p", None) + self.top_k: Optional[int] = kwargs.pop("top_k", None) + self.max_length = kwargs.pop("max_length", 20) + self.max_new_tokens = kwargs.pop("max_new_tokens", None) + # diffusion specific params + self.eps: float = kwargs.pop("eps", 1e-3) + self.steps: int = kwargs.pop("steps", 512) + self.alg: str = kwargs.pop("alg", 'origin') + self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None) + + # Parameters that define the output variables of `generate` + self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1) + self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False) + self.output_history: bool = kwargs.pop("output_history", False) + + # Special tokens that 
can be used at generation time + self.mask_token_id = kwargs.pop("mask_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + + # Wild card + self.generation_kwargs = kwargs.pop("generation_kwargs", {}) + + # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub + # interface. + self._from_model_config = kwargs.pop("_from_model_config", False) + self._commit_hash = kwargs.pop("_commit_hash", None) + self.transformers_version = kwargs.pop("transformers_version", __version__) + + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + # Validate the values of the attributes + self.validate(is_init=True) + + def validate(self, is_init=False, **kwargs): + pass + +class DreamGenerationMixin: + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, attention_mask + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.repeat_interleave(expand_size, dim=0) + return input_ids, attention_mask + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # Can't throw warnings/exceptions during compilation + if is_torchdynamo_compiling(): + return + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "input_ids" + raise ValueError( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_length` or, better yet, setting `max_new_tokens`." 
+ ) + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + input_ids_length, + ): + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + elif has_default_max_length: + if generation_config.max_length == DreamGenerationConfig().max_length: + generation_config.max_length = generation_config.max_length + input_ids_length + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + if max_position_embeddings is not None: + generation_config.max_length = min(generation_config.max_length, max_position_embeddings) + + return generation_config + + def _prepare_generation_config( + self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict + ) -> DreamGenerationConfig: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. + """ + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False + if generation_config is None: + generation_config = DreamGenerationConfig.from_model_config(self.config) + using_model_generation_config = True + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): + generation_config = copy.deepcopy(generation_config) + _kwargs = generation_config.update(**kwargs) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.mask_token_id is None: + generation_config.mask_token_id = self.generation_config.mask_token_id + + return generation_config + + def _prepare_special_tokens( + self, + generation_config: DreamGenerationConfig, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + + Note that `generation_config` is changed in place and stops being serializable after this method is called. + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. 
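+
+        For example (illustrative), when calling this helper outside `generate`:
+
+            config_copy = copy.deepcopy(generation_config)
+            model._prepare_special_tokens(config_copy, device=model.device)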
+ """ + + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): + if token is None: + return token + + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) + + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_tensor is None and eos_token_tensor is not None: + pad_token_tensor = eos_token_tensor[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") + + # Update generation config with the updated special tokens tensors + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._mask_token_tensor = mask_token_tensor + + @torch.no_grad() + def diffusion_generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[DreamGenerationConfig] = None, + **kwargs, + ) -> Union[DreamModelOutput, torch.LongTensor]: + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + generation_config = self._prepare_generation_config(generation_config, **kwargs) + generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x) + generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits) + + # 2. Define model inputs + assert inputs is not None + input_ids = inputs + device = input_ids.device + attention_mask = kwargs.pop("attention_mask", None) + self._prepare_special_tokens(generation_config, device=device) + + # 3. Prepare `max_length`. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + input_ids_length=input_ids_length, + ) + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 4. Check input_ids + if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." 
+ " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + if ( + hasattr(generation_config, "pad_token_id") and + torch.any(input_ids == generation_config.pad_token_id) and + attention_mask is None + ): + warnings.warn( + "Padding was detected but no attention mask is passed here. For correct " + "generation results, please set `attention_mask` when batch-padding inputs.", + UserWarning, + ) + + input_ids, attention_mask = self._expand_inputs_for_generation( + expand_size=generation_config.num_return_sequences, + input_ids=input_ids, + attention_mask=attention_mask + ) + + result = self._sample( + input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + generation_tokens_hook_func=generation_tokens_hook_func, + generation_logits_hook_func=generation_logits_hook_func + ) + return result + + def _sample( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.LongTensor], + generation_config: DreamGenerationConfig, + generation_tokens_hook_func, + generation_logits_hook_func + ) -> Union[DreamModelOutput, torch.LongTensor]: + # init values + + output_history = generation_config.output_history + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + mask_token_id = generation_config.mask_token_id + steps = generation_config.steps + eps = generation_config.eps + alg = generation_config.alg + alg_temp = generation_config.alg_temp + temperature = generation_config.temperature + top_p = generation_config.top_p + top_k = generation_config.top_k + + histories = [] if (return_dict_in_generate and output_history) else None + + # pad input_ids to max_length + x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) + + if attention_mask is not None and torch.any(attention_mask == 0.0): + # we do not mask the [MASK] tokens so value = 1.0 + attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) + tok_idx = attention_mask.long().cumsum(-1) - 1 + tok_idx.masked_fill_(attention_mask == 0, 1) + # attention_mask is of shape [B, N] + # broadcast to [B, 1, N, N] + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + else: + tok_idx = None + attention_mask = "full" + + timesteps = torch.linspace(1, eps, steps + 1, device=x.device) + + # this allows user-defined token control of the intermediate steps + x = generation_tokens_hook_func(None, x, None) + for i in range(steps): + + mask_index = (x == mask_token_id) + logits = self(x, attention_mask, tok_idx).logits + logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) + # this allows user-defined logits control of the intermediate steps + logits = generation_logits_hook_func(i, x, logits) + + mask_logits = logits[mask_index] + t = timesteps[i] + s = timesteps[i + 1] + + if alg == 'origin': + p_transfer = 1 - s / t if i < steps - 1 else 1 + x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id + transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer + _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) + x[mask_index] = x0.clone() + else: + if alg == 'maskgit_plus': + confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, 
top_k=top_k) + elif alg == 'topk_margin': + confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True) + elif alg == 'entropy': + confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True) + else: + raise RuntimeError(f"Unknown alg: {alg}") + num_mask_token = mask_index.sum() / mask_index.shape[0] + number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token) + full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype) + full_confidence[mask_index] = confidence + if number_transfer_tokens > 0: + if alg_temp is None or alg_temp == 0: + _, transfer_index = torch.topk(full_confidence, number_transfer_tokens) + else: + full_confidence = full_confidence / alg_temp + full_confidence = F.softmax(full_confidence, dim=-1) + transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens) + x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id + x_[mask_index] = x0.clone() + row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index) + x[row_indices,transfer_index] = x_[row_indices,transfer_index] + + # this allows user-defined token control of the intermediate steps + x = generation_tokens_hook_func(i, x, logits) + + if histories is not None: + histories.append(x.clone()) + + if return_dict_in_generate: + return DreamModelOutput( + sequences=x, + history=histories, + ) + else: + return x diff --git a/dllm/dllm/pipelines/dream/models/modeling_dream.py b/dllm/dllm/pipelines/dream/models/modeling_dream.py new file mode 100644 index 0000000..4d9af95 --- /dev/null +++ b/dllm/dllm/pipelines/dream/models/modeling_dream.py @@ -0,0 +1,850 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT and Qwen implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
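+#
+# Loading sketch (illustrative; relies on the Auto-class registration performed in
+# `dllm/pipelines/dream/models/__init__.py` and the checkpoint name used elsewhere
+# in this diff):
+#
+#     import dllm.pipelines.dream.models  # registers DreamConfig/DreamModel with the Auto classes
+#     from transformers import AutoModel
+#
+#     model = AutoModel.from_pretrained("Dream-org/Dream-v0-Base-7B")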
+"""PyTorch Dream model.""" + +import math +from typing import List, Optional, Tuple, Union +import os +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from transformers import PretrainedConfig +from .configuration_dream import DreamConfig +from .generation_utils import DreamGenerationMixin, DreamGenerationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "Dream-7B" +_CONFIG_FOR_DOC = "DreamConfig" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream +class DreamRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DreamRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream +class DreamRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[DreamConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def reset_parameters(self): + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream +class DreamMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class DreamAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = False + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = DreamRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class DreamSdpaAttention(DreamAttention): + """ + Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from DreamAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # causal_mask = attention_mask + # if attention_mask is not None: # no matter the length, we just slice it + # causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ # is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=False, # hard coded + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class DreamDecoderLayer(nn.Module): + def __init__(self, config: DreamConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = DreamSdpaAttention(config, layer_idx) + + self.mlp = DreamMLP(config) + self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +class DreamPreTrainedModel(PreTrainedModel): + config_class = DreamConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DreamDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + weights_only: bool = True, + **kwargs, + ): + _model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + # NOTE(Lin): we need to override the generation config + # because the generation config loaded in `from_pretrained` + # does not include all the attributes of DreamGenerationConfig + resume_download = kwargs.get("resume_download", None) + proxies = kwargs.get("proxies", None) + subfolder = kwargs.get("subfolder", "") + from_auto_class = kwargs.get("_from_auto", False) + from_pipeline = kwargs.get("_from_pipeline", None) + _model.generation_config = DreamGenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + ) + return _model + +class DreamBaseModel(DreamPreTrainedModel): + """ + Transformer decoder 
consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`] + + Args: + config: DreamConfig + """ + + def __init__(self, config: DreamConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DreamRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DreamModel(DreamGenerationMixin, DreamPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DreamBaseModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def reset_rope_parameters(self): + self.model.rotary_emb.reset_parameters() + for layer in self.model.layers: + layer.self_attn.rotary_emb.reset_parameters() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + 
) -> Union[Tuple, MaskedLMOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if isinstance(attention_mask, str) and attention_mask == "full" or attention_mask == None: + # whether attention_mask is full + pass + + elif isinstance(attention_mask, torch.Tensor): + if not torch.any(attention_mask == 0.0): + attention_mask = 'full' + elif attention_mask.dim() == 2: + # [B, L] → [B, 1, L, L] + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + attention_mask = attention_mask.to(torch.bool) + + elif attention_mask.dim() in (3, 4): + # already extended/broadcasted form + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask.to(torch.bool) + + else: + raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}") + + else: + raise TypeError(f"Unsupported attention_mask type: {type(attention_mask)}") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/dllm/dllm/pipelines/dream/models/tokenization_dream.py b/dllm/dllm/pipelines/dream/models/tokenization_dream.py new file mode 100644 index 0000000..7202ab4 --- /dev/null +++ b/dllm/dllm/pipelines/dream/models/tokenization_dream.py @@ -0,0 +1,346 @@ +# coding=utf-8 +# Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on Qwen's implementations in this library. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
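The attention-mask handling in `DreamModel.forward` above is what keeps the model bidirectional: a 2D padding mask is expanded into a square boolean mask rather than a causal one, and the SDPA path hard-codes `is_causal=False`. A small sketch of just that expansion, with illustrative values:

# A [B, L] padding mask becomes a [B, 1, L, L] boolean mask where query i may attend
# to key j iff both positions are real (non-padding) tokens; no causal constraint.
import torch

pad_mask = torch.tensor([[1, 1, 1, 0],    # row 0: last token is padding
                         [1, 1, 1, 1]])   # row 1: no padding

mask_4d = torch.logical_and(
    pad_mask.unsqueeze(1).unsqueeze(-2),  # [B, 1, 1, L] -> keys
    pad_mask.unsqueeze(1).unsqueeze(-1),  # [B, 1, L, 1] -> queries
)                                         # [B, 1, L, L], dtype=torch.bool

# scaled_dot_product_attention treats True as "attend" and False as "masked out"
# for boolean masks, so this can be passed directly as attn_mask with is_causal=False.
print(mask_4d.shape)   # torch.Size([2, 1, 4, 4])
print(mask_4d[0, 0])   # row 3 and column 3 are all False for the padded item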
+"""Tokenization classes for Dream.""" + +import json +import os +import unicodedata +from functools import lru_cache +from typing import Optional, Tuple + +import regex as re + +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + + +MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768} + +PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" + + +@lru_cache() +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class DreamTokenizer(PreTrainedTokenizer): + """ + Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding. + + Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will + be encoded differently whether it is at the beginning of the sentence (without space) or not: + + ```python + >>> from transformers import AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True) + >>> tokenizer("Hello world")["input_ids"] + [9707, 1879] + + >>> tokenizer(" Hello world")["input_ids"] + [21927, 1879] + ``` + This is expected. + + You should not use GPT2Tokenizer instead, because of the different pretokenization rules. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str`, *optional*): + The beginning of sequence token. Not applicable for this tokenizer. 
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding, for example when batching sequences of different lengths. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. + split_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the special tokens should be split during the tokenization process. The default behavior is + to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") = + ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<', + '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + clean_up_tokenization_spaces=False, + split_special_tokens=False, + **kwargs, + ): + # Dream vocab does not contain control tokens; added tokens need to be special + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) + if isinstance(pad_token, str) + else pad_token + ) + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_merges = [] + with open(merges_file, encoding="utf-8") as merges_handle: + for i, line in enumerate(merges_handle): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + bpe_merges.append(tuple(line.split())) + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # NOTE: the cache can grow without bound and will get really large for long running processes + # (esp. for texts of language that do not use space between word, e.g. Chinese); technically + # not a memory leak but appears as one. + # GPT2Tokenizer has the same problem, so let's be consistent. + self.cache = {} + + self.pat = re.compile(PRETOKENIZE_REGEX) + + if kwargs.get("add_prefix_space", False): + logger.warning_once( + f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect." 
+ ) + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + unk_token=unk_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + split_special_tokens=split_special_tokens, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.encoder) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def decode( + self, + token_ids, + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers + # and cannot be configured elsewhere, but it should default to False for DreamTokenizer + return super().decode( + token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not 
os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + def prepare_for_tokenization(self, text, **kwargs): + text = unicodedata.normalize("NFC", text) + return (text, kwargs) + + +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING +from .configuration_dream import DreamConfig + +TOKENIZER_MAPPING.register(DreamConfig, (DreamTokenizer, None)) diff --git a/dllm/dllm/pipelines/dream/trainer.py b/dllm/dllm/pipelines/dream/trainer.py new file mode 100644 index 0000000..a28c12d --- /dev/null +++ b/dllm/dllm/pipelines/dream/trainer.py @@ -0,0 +1,84 @@ +from typing import Any + +import torch + +from dllm.core.trainers import MDLMTrainer + + +def cart_weight( + masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3 +) -> torch.Tensor: + """ + Optimized CART weight computation using matrix operations. + + Args: + masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions. + t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART. + p (float): Parameter of geometric distribution (0 < p <= 1). + + Returns: + torch.Tensor: (b, l) float tensor of weights. + """ + b, l = masked_indices.shape + device = masked_indices.device + + idx = torch.arange(l, device=device) + dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1 + dist_matrix = torch.clamp(dist_matrix, min=0) # (l, l) + geo_matrix = ( + torch.log(torch.tensor(p, device=device)) + + (dist_matrix - 1).clamp(min=0) * torch.log(torch.tensor(1 - p, device=device)) + ).exp() * 0.5 # Ensure numerical stability + geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0 + + valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked + weights = valid_mask @ geo_matrix.T # (b, l) + weights = weights * masked_indices.float() + return weights + + +class DreamTrainer(MDLMTrainer): + """ + DreamTrainer: specialization of MDLMTrainer for Dream training. 
+ """ + + def __init__( + self, + *args, + loss_weight_type: str = "cart[geo_p:0.3]", + **kwargs, + ): + super().__init__(*args, loss_weight_type=loss_weight_type, **kwargs) + + def _preprocess_inputs(self, inputs): + labels = inputs["labels"] + assert (labels[:, 0] == -100).all() + + def _postprocess_outputs(self, outputs): + logits = outputs.logits + outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) + + def _compute_loss_weights( + self, + t: torch.Tensor, + inputs: dict[str, Any], + masked_indices: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + if self.loss_weight_type.startswith("cart"): + # parse geo_p + import re + + match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type) + geo_p = float(match.group(1)) if match else 0.3 + loss_weights = cart_weight(masked_indices, t, p=geo_p) + else: + loss_weights = super()._compute_loss_weights( + t=t, + inputs=inputs, + masked_indices=masked_indices, + *args, + **kwargs, + ) + return loss_weights diff --git a/dllm/dllm/pipelines/dream/utils.py b/dllm/dllm/pipelines/dream/utils.py new file mode 100644 index 0000000..270ec11 --- /dev/null +++ b/dllm/dllm/pipelines/dream/utils.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F +import transformers + + +def top_p_logits(logits, top_p=None): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) + mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) + return logits + + +def top_k_logits(logits, top_k=None): + top_k = min(top_k, logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) + return logits + + +@dataclass +class DreamSFTCollator(transformers.DataCollatorForSeq2Seq): + """ + Randomly crop response length to reduce length bias during generation. + + Reference: https://github.com/DreamLM/Dream/blob/main/src/trainer/fsdp_sft_trainer.py + """ + + perbatch_cutoff: bool = True # Use prebatch truncation if True + resp_cutoff_ratio: float = 0.0 # Prob. of post-collation truncation + + # ------------------------------------------------------------------------- + # 1) Pre-collation truncation (per-sample) + # ------------------------------------------------------------------------- + def apply_perbatch_cutoff(self, features): + """ + Randomly pick a response length from batch (`kept_len`) and trim other responses. 
+ Before: + [<--promptA----><------responseA------>] + [<--promptB-><---responseB--->] + [<---promptC----><--respC-->] + After: + [<--promptA----><---respA--->] + [<--promptB-><--respB-->] + [<---promptC----><--respC-->] + kept_len = 10 → trim each response to ≤10 tokens (before padding) + """ + resp_lens = torch.tensor( + [len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long + ) + kept_len = int(np.random.choice(resp_lens)) + for f, r_len in zip(features, resp_lens): + remove_len = max(r_len - kept_len, 0) + if remove_len > 0: + # f["input_ids"] = f["input_ids"][:-remove_len] + # f["attention_mask"] = f["attention_mask"][:-remove_len] + # f["labels"] = f["labels"][:-remove_len] + for key in ["input_ids", "labels", "attention_mask"]: + if key in f: + f[key] = f[key][:-remove_len] + return features + + # ------------------------------------------------------------------------- + # 2) Post-collation truncation + # ------------------------------------------------------------------------- + def apply_resp_cutoff(self, batch, features): + """ + Uniformly chop tail *after padding*. All sequences truncated to new_seq_len. + Before: + [<--promptA----><-----respA----->] 40 + [<--promptB-><----pad---->] 40 + [<---promptC----><--respC-->] 40 + cutoff_len = 5 + After: + [<--promptA----><--respA--->] 35 + [<--promptB-><--pad->] 35 + [<---promptC----><--respC-->] 35 + """ + orig_seq_lens = [len(f["input_ids"]) for f in features] + resp_lens = torch.tensor( + [len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long + ) + min_resp_len = resp_lens.min().item() + if min_resp_len <= 1: + return batch + + cutoff_len = int(np.random.randint(1, min_resp_len)) + new_seq_len = max(orig_seq_lens) - cutoff_len + + for key in ["input_ids", "labels", "attention_mask"]: + if key in batch: + batch[key] = batch[key][:, :new_seq_len].contiguous() + return batch + + # ------------------------------------------------------------------------- + # 3) Main call: pick truncation mode + # ------------------------------------------------------------------------- + def __call__(self, features, return_tensors=None): + # optional pre-collation truncation + if self.perbatch_cutoff: + features = self.apply_perbatch_cutoff(features) + + # always collate only the needed fields + base = [ + {k: f[k] for k in ("input_ids", "labels", "attention_mask") if k in f} + for f in features + ] + batch = super().__call__(base, return_tensors=return_tensors) + + # optional post-collation truncation + if ( + not self.perbatch_cutoff + and self.resp_cutoff_ratio > 0 + and np.random.rand() < self.resp_cutoff_ratio + ): + batch = self.apply_resp_cutoff(batch, features) + + batch.pop("prompt_len", None) + return batch + + +@dataclass +class DreamPTCollator(transformers.DataCollatorForSeq2Seq): + random_length_ratio: float = 0.01 + + def __call__(self, features, return_tensors=None): + outputs = super().__call__(features, return_tensors=return_tensors) + input_ids, labels, attention_mask = ( + outputs["input_ids"], + outputs["labels"], + outputs["attention_mask"], + ) + bsz, seq_len = input_ids.shape + + # --- Random truncation for robustness --- + if torch.rand(1).item() < self.random_length_ratio: + random_len = torch.randint(1, seq_len + 1, (1,)).item() + input_ids = input_ids[:, :random_len] + labels = labels[:, :random_len] + attention_mask = attention_mask[:, :random_len] + + # --- Add BOS token to the beginning of input_ids --- + bos = torch.full( + (bsz, 1), + self.tokenizer.bos_token_id, + 
dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([bos, input_ids], dim=1) + + # --- Prepend zeros to labels instead of BOS --- + ignore_labels = self.label_pad_token_id * torch.ones( + (bsz, 1), dtype=labels.dtype, device=labels.device + ) + labels = torch.cat([ignore_labels, labels], dim=1) + + # --- Prepend ones to attention_mask --- + bos_attention = torch.ones( + (bsz, 1), dtype=attention_mask.dtype, device=attention_mask.device + ) + attention_mask = torch.cat([bos_attention, attention_mask], dim=1) + + # --- Update and return --- + outputs["input_ids"] = input_ids + outputs["labels"] = labels + outputs["attention_mask"] = attention_mask + # Check if attention_mask is all ones and set it to None + if torch.all(outputs["attention_mask"] == 1): + outputs.pop("attention_mask") + return outputs diff --git a/dllm/dllm/pipelines/editflow/__init__.py b/dllm/dllm/pipelines/editflow/__init__.py new file mode 100644 index 0000000..c7aa217 --- /dev/null +++ b/dllm/dllm/pipelines/editflow/__init__.py @@ -0,0 +1,14 @@ +from . import trainer, utils +from .models.dream.modelling_dream import ( + EditFlowDreamConfig, + EditFlowDreamModel, +) +from .models.llada.modelling_llada import ( + EditFlowLLaDAConfig, + EditFlowLLaDAModel, +) +from .models.bert.modelling_modernbert import ( + EditFlowModernBertConfig, + EditFlowModernBertModel, +) +from dllm.pipelines.editflow.trainer import EditFlowTrainer diff --git a/dllm/dllm/pipelines/editflow/models/bert/modelling_modernbert.py b/dllm/dllm/pipelines/editflow/models/bert/modelling_modernbert.py new file mode 100644 index 0000000..d029802 --- /dev/null +++ b/dllm/dllm/pipelines/editflow/models/bert/modelling_modernbert.py @@ -0,0 +1,89 @@ +import torch +from torch import nn + + +import transformers + + +class EditFlowModernBertConfig(transformers.ModernBertConfig): + model_type = "editflow-modernbert" # <- NEW model_type + + +class EditFlowModernBertModel(transformers.ModernBertForMaskedLM): + config_class = EditFlowModernBertConfig + modules_to_save = { + "rate_heads", + "sub_logits", + "ins_logits", + } # fully fintuned even using lora + + def __init__(self, config): + # fa2 has bugs when forward(output_hidden_states=True) + config._attn_implementation = "sdpa" + super().__init__(config) + in_lm, out_lm = self.decoder.in_features, self.decoder.out_features + use_bias = self.decoder.bias is not None + # Create new, independent heads (no deepcopy) + self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias) + self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias) + self.rate_heads = nn.Sequential(nn.Linear(in_lm, 3), nn.Softplus()) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor | None = None, + t: torch.Tensor | None = None, + **kwargs, + ): + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + h = output["hidden_states"][-1] # final hidden states + h = self.head(h) + # Position heads + sub_log = self.sub_logits(h) # [B, L, V] + ins_log = self.ins_logits(h) # [B, L, V] + + rates = self.rate_heads(h) + sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind( + -1 + ) # [B, L], [B, L], [B, L] + return dict( + sub_rate_hat=sub_rate_hat, # [B,L] + del_rate_hat=del_rate_hat, # [B,L] + ins_rate_hat=ins_rate_hat, # [B,L] + ins_logits=ins_log, # [B,L,V] + sub_logits=sub_log, # [B,L,V] + ) + + +from transformers.models.auto import AutoModel, AutoConfig + +# Register the model so that it is available 
for transformer pipelines, auto-loading, etc. +AutoConfig.register("editflow-modernbert", EditFlowModernBertConfig) +AutoModel.register(EditFlowModernBertConfig, EditFlowModernBertModel) + + +if __name__ == "__main__": + import dllm + import torch + from transformers import AutoConfig, AutoModel + + # Load a config from a local path (either a directory containing config.json, or the file itself) + config_path = dllm.utils.resolve_with_base_env( + "answerdotai/ModernBERT-base", "BASE_MODELS_DIR" + ) + config = EditFlowModernBertConfig.from_pretrained(config_path) + if hasattr(config, "auto_map"): + delattr(config, "auto_map") + if hasattr(config, "architectures"): + delattr(config, "architectures") + + torch.set_default_device("cuda") + model = EditFlowModernBertModel(config) + model.save_pretrained("models-tmp/editflow-modernbert") + auto_model = AutoModel.from_pretrained("models-tmp/editflow-modernbert") diff --git a/dllm/dllm/pipelines/editflow/models/dream/modelling_dream.py b/dllm/dllm/pipelines/editflow/models/dream/modelling_dream.py new file mode 100644 index 0000000..54e623c --- /dev/null +++ b/dllm/dllm/pipelines/editflow/models/dream/modelling_dream.py @@ -0,0 +1,97 @@ +import copy +from typing import Optional + +import torch +from torch import nn + +from dllm.pipelines import dream + + +class EditFlowDreamConfig(dream.DreamConfig): + model_type = "editflow-dream" # <- NEW model_type + + +class EditFlowDreamModel(dream.DreamModel): + config_class = EditFlowDreamConfig + modules_to_save = { + "rate_heads", + "sub_logits", + "ins_logits", + } # fully fintuned even using lora + + def __init__(self, config): + super().__init__(config) + in_lm, out_lm = self.lm_head.in_features, self.lm_head.out_features + use_bias = self.lm_head.bias is not None + # Create new, independent heads (no deepcopy) + self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias) + self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias) + self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus()) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor | None = None, + t: torch.Tensor | None = None, + **kwargs, + ): + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + h = output["hidden_states"][-1] # final hidden states + # Position heads + sub_log = self.sub_logits(h) # [B, L, V] + sub_log = torch.concatenate( + [torch.zeros_like(sub_log)[:, :1], sub_log[:, :-1]], dim=1 + ) # [B, L, V] + ins_log = self.ins_logits(h) # [B, L, V] + + rates = self.rate_heads(h) + sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind( + -1 + ) # [B, L], [B, L], [B, L] + sub_rate_hat = torch.concatenate( + [torch.zeros_like(sub_rate_hat[:, :1]), sub_rate_hat[:, :-1]], dim=1 + ) # [B, L] + del_rate_hat = torch.concatenate( + [torch.zeros_like(del_rate_hat[:, :1]), del_rate_hat[:, :-1]], dim=1 + ) # [B, L] + return dict( + sub_rate_hat=sub_rate_hat, # [B,L] + del_rate_hat=del_rate_hat, # [B,L] + ins_rate_hat=ins_rate_hat, # [B,L] + ins_logits=ins_log, # [B,L,V] + sub_logits=sub_log, # [B,L,V] + ) + + +from transformers.models.auto import AutoModel, AutoConfig + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. 
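One detail of the Dream variant's forward pass above is worth noting before the registration that follows: the substitution and deletion heads are shifted right by one position (mirroring the logit shift in DreamTrainer._postprocess_outputs), presumably because the causal-LM-initialized backbone produces the prediction for position i at the hidden state of position i-1; the insertion heads are left unshifted. A sketch of the shift with illustrative shapes:

# The one-position right shift applied to sub_logits / sub_rate_hat / del_rate_hat above:
# position 0 receives a zero placeholder and every other position takes the value
# produced at the position before it.
import torch

B, L, V = 2, 5, 7
sub_logits = torch.randn(B, L, V)

shifted = torch.cat([torch.zeros_like(sub_logits[:, :1]), sub_logits[:, :-1]], dim=1)
assert shifted.shape == (B, L, V)
assert torch.equal(shifted[:, 0], torch.zeros(B, V))     # placeholder at position 0
assert torch.equal(shifted[:, 1:], sub_logits[:, :-1])   # the rest moved right by one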
+AutoConfig.register("editflow-dream", EditFlowDreamConfig) +AutoModel.register(EditFlowDreamConfig, EditFlowDreamModel) + + +if __name__ == "__main__": + import dllm + import torch + from transformers import AutoConfig, AutoModel + + # Load a config from a local path (either a directory containing config.json, or the file itself) + config_path = dllm.utils.resolve_with_base_env( + "Dream-org/Dream-v0-Base-7B", "BASE_MODELS_DIR" + ) + config = EditFlowDreamConfig.from_pretrained(config_path) + if hasattr(config, "auto_map"): + delattr(config, "auto_map") + if hasattr(config, "architectures"): + delattr(config, "architectures") + + torch.set_default_device("cuda") + model = EditFlowDreamModel(config) + model.save_pretrained("models-tmp/editflow-dream") + auto_model = AutoModel.from_pretrained("models-tmp/editflow-dream") diff --git a/dllm/dllm/pipelines/editflow/models/llada/modelling_llada.py b/dllm/dllm/pipelines/editflow/models/llada/modelling_llada.py new file mode 100644 index 0000000..f47e7bc --- /dev/null +++ b/dllm/dllm/pipelines/editflow/models/llada/modelling_llada.py @@ -0,0 +1,91 @@ +import copy +from typing import Optional + +import torch +from torch import nn + +from dllm.pipelines import llada + + +class EditFlowLLaDAConfig(llada.LLaDAConfig): + model_type = "editflow-llada" # <- NEW model_type + + +class EditFlowLLaDAModel(llada.LLaDAModelLM): + config_class = EditFlowLLaDAConfig + modules_to_save = { + "rate_heads", + "sub_logits", + "ins_logits", + } # fully fintuned even using lora + + def __init__(self, config): + # TODO: time embedding + super().__init__(config) + ff = self.model.transformer.ff_out + in_f, out_f = ff.in_features, ff.out_features + use_bias = ff.bias is not None + # Create new, independent heads (no deepcopy) + self.sub_logits = nn.Linear(in_f, out_f, bias=use_bias) + self.ins_logits = nn.Linear(in_f, out_f, bias=use_bias) + self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus()) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: torch.Tensor | None = None, + t: torch.Tensor | None = None, + **kwargs, + ): + # TODO: time embedding + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + h = output["hidden_states"][-1] # final hidden states + # Position heads + sub_log = self.sub_logits(h) # [B, L, V] + ins_log = self.ins_logits(h) # [B, L, V] + + rates = self.rate_heads(h) + sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind( + -1 + ) # [B, L], [B, L], [B, L] + return dict( + sub_rate_hat=sub_rate_hat, # [B,L] + del_rate_hat=del_rate_hat, # [B,L] + ins_rate_hat=ins_rate_hat, # [B,L] + ins_logits=ins_log, # [B,L,V] + sub_logits=sub_log, # [B,L,V] + ) + + +from transformers.models.auto import AutoModel, AutoConfig + +# Register the model so that it is available for transformer pipelines, auto-loading, etc. 
+AutoConfig.register("editflow-llada", EditFlowLLaDAConfig) +AutoModel.register(EditFlowLLaDAConfig, EditFlowLLaDAModel) + + +if __name__ == "__main__": + import dllm + import torch + from transformers import AutoConfig, AutoModel + + # Load a config from a local path (either a directory containing config.json, or the file itself) + config_path = dllm.utils.resolve_with_base_env( + "GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR" + ) + config = EditFlowLLaDAConfig.from_pretrained(config_path) + if hasattr(config, "auto_map"): + delattr(config, "auto_map") + if hasattr(config, "architectures"): + delattr(config, "architectures") + + torch.set_default_device("cuda") + model = EditFlowLLaDAModel(config) + model.save_pretrained("models-tmp/editflow-llada") + auto_model = AutoModel.from_pretrained("models-tmp/editflow-llada") diff --git a/dllm/dllm/pipelines/editflow/trainer.py b/dllm/dllm/pipelines/editflow/trainer.py new file mode 100644 index 0000000..33c67c6 --- /dev/null +++ b/dllm/dllm/pipelines/editflow/trainer.py @@ -0,0 +1,407 @@ +from typing import Any, Dict, Union, List, Tuple, Optional +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import transformers + +from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler +from dllm.pipelines.editflow.utils import pad_1d + + +BLANK = -1 + + +def align_with_blanks( + x0: List[int], x1: List[int], sub_cost: int = 1, gap_cost: int = 1 +) -> Dict: + """ + Needleman–Wunsch global alignment of two integer sequences with: + match cost = 0, substitution cost = sub_cost, gap cost = gap_cost. + Returns aligned sequences (z0, z1) of equal length containing BLANK = ε where gaps occur. + """ + n, m = len(x0), len(x1) + # DP tables + dp = [[0] * (m + 1) for _ in range(n + 1)] + ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag', 'up', 'left' + + for i in range(1, n + 1): + dp[i][0] = i * gap_cost + ptr[i][0] = "up" + for j in range(1, m + 1): + dp[0][j] = j * gap_cost + ptr[0][j] = "left" + + for i in range(1, n + 1): + for j in range(1, m + 1): + cost_diag = dp[i - 1][j - 1] + (0 if x0[i - 1] == x1[j - 1] else sub_cost) + cost_up = dp[i - 1][j] + gap_cost + cost_left = dp[i][j - 1] + gap_cost + best = min(cost_diag, cost_up, cost_left) + dp[i][j] = best + if best == cost_diag: + ptr[i][j] = "diag" + elif best == cost_up: + ptr[i][j] = "up" + else: + ptr[i][j] = "left" + + # traceback + z0, z1 = [], [] + i, j = n, m + while i > 0 or j > 0: + p = ptr[i][j] + if p == "diag": + z0.append(x0[i - 1]) + z1.append(x1[j - 1]) + i -= 1 + j -= 1 + elif p == "up": + z0.append(x0[i - 1]) + z1.append(BLANK) + i -= 1 + else: # 'left' + z0.append(BLANK) + z1.append(x1[j - 1]) + j -= 1 + z0.reverse() + z1.reverse() + # return Alignment(z0=z0, z1=z1) + # return {"z0": z0, "z1": z1} + return dict(z0=z0, z1=z1) + + +# def align_with_blanks( +# x0: list[int], x1: list[int], sub_cost: int = 1, gap_cost: int = 1 +# ) -> dict: +# """ +# Needleman–Wunsch with a secondary objective that defers gaps to the end: +# - 'up' (gap in z1) is penalized if j < m +# - 'left' (gap in z0) is penalized if i < n +# This pushes blanks (-1) to the *right* whether x0 > x1 or x0 < x1. 
+# """ +# n, m = len(x0), len(x1) + +# dp_cost = [[0] * (m + 1) for _ in range(n + 1)] +# dp_pen = [[0] * (m + 1) for _ in range(n + 1)] +# ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag' | 'up' | 'left' + +# # Left edge: all 'up' moves with j=0 (< m) → penalize each step +# for i in range(1, n + 1): +# dp_cost[i][0] = i * gap_cost +# dp_pen[i][0] = i # i early 'up' moves +# ptr[i][0] = "up" + +# # Top edge: all 'left' moves with i=0 (< n) → penalize each step +# for j in range(1, m + 1): +# dp_cost[0][j] = j * gap_cost +# dp_pen[0][j] = j # j early 'left' moves +# ptr[0][j] = "left" + +# for i in range(1, n + 1): +# xi = x0[i - 1] +# for j in range(1, m + 1): +# yj = x1[j - 1] + +# # diag +# cost_diag = dp_cost[i - 1][j - 1] + (0 if xi == yj else sub_cost) +# pen_diag = dp_pen[i - 1][j - 1] +# cand_diag = (cost_diag, pen_diag) + +# # up: add blank to z1, penalize if j < m (early) +# cost_up = dp_cost[i - 1][j] + gap_cost +# pen_up = dp_pen[i - 1][j] + (1 if j < m else 0) +# cand_up = (cost_up, pen_up) + +# # left: add blank to z0, penalize if i < n (early) +# cost_left = dp_cost[i][j - 1] + gap_cost +# pen_left = dp_pen[i][j - 1] + (1 if i < n else 0) +# cand_left = (cost_left, pen_left) + +# # choose (cost,pen) min; deterministic tie-break: diag > left > up +# best = min(cand_diag, cand_left, cand_up) +# dp_cost[i][j], dp_pen[i][j] = best +# if best == cand_diag: +# ptr[i][j] = "diag" +# elif best == cand_left: +# ptr[i][j] = "left" +# else: +# ptr[i][j] = "up" + +# # traceback +# z0, z1 = [], [] +# i, j = n, m +# while i > 0 or j > 0: +# p = ptr[i][j] +# if p == "diag": +# z0.append(x0[i - 1]) +# z1.append(x1[j - 1]) +# i -= 1 +# j -= 1 +# elif p == "up": +# z0.append(x0[i - 1]) +# z1.append(BLANK) +# i -= 1 +# else: # 'left' +# z0.append(BLANK) +# z1.append(x1[j - 1]) +# j -= 1 + +# z0.reverse() +# z1.reverse() +# return dict(z0=z0, z1=z1) + + +def strip_blanks(z: list[int]) -> list[int]: + # IMPORTANT: do NOT strip BOS; we only remove BLANKs + return [t for t in z if t != BLANK] + + +@dataclass +class Edit: + kind: str # "SUB" | "DEL" | "INS" + pos: int # position (for SUB/DEL) or token-row idx for INS (incl. 
BOS row 0) + token: int | None # token for SUB/INS, else None + + +def build_remaining_edits(zt: list[int], z1: list[int]) -> list[Edit]: + edits: list[Edit] = [] + + def count_nonblank_prefix(z: list[int], j: int) -> int: + c = 0 + for k in range(j): + if z[k] != BLANK: + c += 1 + return c + + for j, (a, b) in enumerate(zip(zt, z1)): + if a == b: + continue + nb = count_nonblank_prefix( + zt, j + ) # counts BOS as 1, first content token will be nb=1 before its column + + if a == BLANK and b != BLANK: + # INSERT after row (nb-1): BOS insert => nb=1 -> gap=0; general case works too + gap = max(nb - 1, 0) + edits.append(Edit("INS", gap, b)) + + elif a != BLANK and b == BLANK: + # DELETE token at row nb (first content token => nb=1, allowed; BOS is never BLANK so nb>=1) + pos = nb + # if pos > 0: # forbid BOS (row 0) + edits.append(Edit("DEL", pos, None)) + + else: # a != BLANK, b != BLANK, a != b + # SUB token at row nb + pos = nb + # if pos > 0: # forbid BOS (row 0) + edits.append(Edit("SUB", pos, b)) + return edits + + +class EditFlowTrainer(transformers.Trainer): + """ + Trainer for Edit Flows where the model returns: + - sub_logits: [B,L,V] (token dist for SUB) + - ins_logits: [B,L,V] (token dist for INS) + - sub_rate_hat: [B,L] (normalized rates; NO kappa factor) + - del_rate_hat: [B,L] + - ins_rate_hat: [B,L] + True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)). + """ + + def __init__( + self, + *args, + scheduler: BaseKappaScheduler | None = None, + normalize_per_position: bool = True, + time_epsilon: float = 1e-3, + max_w: float | None = None, + **kwargs, + ): + self.scheduler = scheduler or CubicKappaScheduler() + self.normalize_per_position = normalize_per_position + self.time_epsilon = time_epsilon + self.max_w = max_w + super().__init__(*args, **kwargs) + + def compute_loss( + self, + model: transformers.PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs: bool = False, + **kwargs, + ): + device = self.model.device + B = len(inputs["x0_ids"]) + + # -------- 1) Align with blanks (z0,z1) and sample time t -------- + aligns = [ + align_with_blanks(x0, x1) + for x0, x1 in zip(inputs["x0_ids"], inputs["x1_ids"]) + ] + z0_list = [a["z0"] for a in aligns] + z1_list = [a["z1"] for a in aligns] + assert all(len(z0) == len(z1) for z0, z1 in zip(z0_list, z1_list)) + assert all(z0[0] != BLANK for z0 in z0_list) # BOS must remain + + t = (1 - self.time_epsilon) * torch.rand(B, 1, device=device) # [B,1] + k = self.scheduler.kappa(t).to(device) # [B,1] + w = self.scheduler.weight(t).squeeze(1).to(device) # [B] + if self.max_w: + w = w.clamp(max=self.max_w) + + # -------- 2) Sample z_t by κ-mixing (vectorized per example) -------- + # Keep python lists -> tensors per-example to reuse build_remaining_edits + zt_list: list[list[int]] = [] + for z0, z1, kb in zip(z0_list, z1_list, k.squeeze(1).tolist()): + # per-column Bernoulli(κ) mix; BOS is equal in z0/z1 so it stays BOS + choose_target = torch.rand(len(z0)) < kb + zt = [b if choose_target[j] else a for j, (a, b) in enumerate(zip(z0, z1))] + zt_list.append(zt) + + # -------- 3) Strip blanks to x_t and compute remaining edits -------- + xt_list = [strip_blanks(zt) for zt in zt_list] + edits_list: list[list[Edit]] = [ + build_remaining_edits(zt, z1) for zt, z1 in zip(zt_list, z1_list) + ] + + # -------- 4) Collate x_t for the model -------- + x_tok, x_mask = pad_1d( + xt_list, pad_val=self.processing_class.pad_token_id + ) # [B,Lmax], [B,Lmax] + x_tok = x_tok.to(device) + x_mask = 
x_mask.to(device) + + # -------- 5) Forward pass -------- + out = model(input_ids=x_tok, attention_mask=x_mask, t=t.to(device)) + # Rename for clarity: model returns normalized rates (no kappa) + sub_rate_hat = out["sub_rate_hat"] # [B,L] + del_rate_hat = out["del_rate_hat"] # [B,L] + ins_rate_hat = out["ins_rate_hat"] # [B,L] + logQ_sub = F.log_softmax(out["sub_logits"], dim=-1) # [B,L,V] + logQ_ins = F.log_softmax(out["ins_logits"], dim=-1) # [B,L,V] + + # *** NEW: zero-cost anchor to "touch" every head even if unused this step *** + # Using .sum() * 0.0 keeps a graph dependency without changing the loss value. + # Include both logits (for SUB/INS heads) and rates (for SUB/DEL/INS heads). + # This is important for Deepspeed ZeRO stage 2/3 to avoid skipping unused parameters. + anchor = ( + sub_rate_hat.sum() * 0.0 + + del_rate_hat.sum() * 0.0 + + ins_rate_hat.sum() * 0.0 + + logQ_sub.sum() * 0.0 + + logQ_ins.sum() * 0.0 + ) + + # Utility + def safe_log(x: torch.Tensor) -> torch.Tensor: + return torch.log(x.clamp_min(1e-12)) + + # -------- 6) Survival term -------- + # Survival = E[sum of true intensities over valid rows] + # true intensity = w[b] * rate_hat[b, i] + mask_f = x_mask.float() + # L = mask_f.sum(dim=1).clamp_min(1.0) # [B] number of positions (incl. BOS) + L1 = torch.tensor( + [len(x1) for x1 in inputs["x1_ids"]], device=device, dtype=torch.float + ).clamp_min(1.0) + denom = L1 if self.normalize_per_position else torch.ones_like(L1) + + Lambda_hat = ((sub_rate_hat + del_rate_hat + ins_rate_hat) * mask_f).sum( + dim=1 + ) # [B] + loss_surv = ((w * Lambda_hat) / denom).mean() + + # -------- 7) Positive edit terms -------- + # For each remaining edit e: -log true rate(e) - log token prob(e) if tokenized + # loss_pos_per = sub_rate_hat.new_zeros(B) # [B] + # for b, edits in enumerate(edits_list): + # if not edits: + # continue + # cur_len = int(x_mask[b].sum().item()) + # for e in edits: + # pos = e.pos + # assert 0 <= pos < cur_len, f"pos {pos} out of range {cur_len}" + # if e.kind == "SUB": + # loss_pos_per[b] -= logQ_sub[b, pos, e.token] + safe_log( + # sub_rate_hat[b, pos] + # ) + # elif e.kind == "DEL": + # loss_pos_per[b] -= safe_log(del_rate_hat[b, pos]) + # else: # "INS" + # loss_pos_per[b] -= logQ_ins[b, pos, e.token] + safe_log( + # ins_rate_hat[b, pos] + # ) + + # -------- 7) Positive edit terms (vectorized) -------- + pos_sub, tok_sub, pos_ins, tok_ins, pos_del = [], [], [], [], [] + for b, edits in enumerate(edits_list): + cur_len = int(x_mask[b].sum().item()) + ps, ts, pi, ti, pd = [], [], [], [], [] + for e in edits: + if not (0 <= e.pos < cur_len): + raise AssertionError( + f"pos {e.pos} out of range {cur_len} for b={b}" + ) + if e.kind == "SUB": + ps.append(e.pos) + ts.append(e.token) + elif e.kind == "INS": + pi.append(e.pos) + ti.append(e.token) + else: + pd.append(e.pos) + pos_sub.append( + torch.tensor(ps, device=x_tok.device, dtype=torch.long) if ps else None + ) + tok_sub.append( + torch.tensor(ts, device=x_tok.device, dtype=torch.long) if ts else None + ) + pos_ins.append( + torch.tensor(pi, device=x_tok.device, dtype=torch.long) if pi else None + ) + tok_ins.append( + torch.tensor(ti, device=x_tok.device, dtype=torch.long) if ti else None + ) + pos_del.append( + torch.tensor(pd, device=x_tok.device, dtype=torch.long) if pd else None + ) + + loss_pos_terms = [] + for b in range(B): + lp = x_tok.new_zeros(()) + if pos_sub[b] is not None: + lp = ( + lp + - ( + logQ_sub[b, pos_sub[b], tok_sub[b]] + + safe_log(sub_rate_hat[b, pos_sub[b]]) + ).sum() + ) + if 
pos_ins[b] is not None: + lp = ( + lp + - ( + logQ_ins[b, pos_ins[b], tok_ins[b]] + + safe_log(ins_rate_hat[b, pos_ins[b]]) + ).sum() + ) + if pos_del[b] is not None: + lp = lp - safe_log(del_rate_hat[b, pos_del[b]]).sum() + loss_pos_terms.append(lp) + loss_pos_per = torch.stack(loss_pos_terms) # [B] + + # # Average positive term per sequence (MC estimator across batch) + loss_pos = ((w * loss_pos_per) / denom).mean() + + # -------- 8) Total -------- + loss = loss_surv + loss_pos + anchor + return (loss, out) if return_outputs else loss + + +if __name__ == "__main__": + pass diff --git a/dllm/dllm/pipelines/editflow/utils.py b/dllm/dllm/pipelines/editflow/utils.py new file mode 100644 index 0000000..f2bf8bf --- /dev/null +++ b/dllm/dllm/pipelines/editflow/utils.py @@ -0,0 +1,218 @@ +import math +import random +from dataclasses import dataclass +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Text +from collections.abc import Callable + +import torch +import transformers + +from dllm.utils.utils import parse_spec + + +# ------------------------------- Collator (x0 source) -------------------------------- +@dataclass +class X0Sampler: + + def __call__(self, *args, **kwargs) -> list[int]: + raise NotImplementedError("Subclasses must implement __call__.") + + +@dataclass +class SampleX0Empty(X0Sampler): + """Return BOS-only (i.e., empty tail).""" + + def __call__(self, *args, **kwargs) -> list[int]: + return [] + + +@dataclass +class SampleX0Masks(X0Sampler): + """Return a run of mask tokens of given length.""" + + length: int = 128 + tokenizer: transformers.PreTrainedTokenizer = None + + def __call__(self, *args, **kwargs) -> list[int]: + mask_id = getattr(self.tokenizer, "mask_token_id", None) + if mask_id is None: + raise ValueError("tokenizer needs mask_token_id for mask-based sampler") + return [int(mask_id)] * self.length + + +# ---------------- Factory ---------------- # +_X0_SAMPLER_CLASSES: dict[str, type[X0Sampler]] = { + "empty": SampleX0Empty, + "masks": SampleX0Masks, +} + + +def make_x0_sampler(name: str, tokenizer: Any, **kwargs) -> X0Sampler: + try: + name, kvs = parse_spec(name) + cls = _X0_SAMPLER_CLASSES[name.lower()] + except KeyError: + raise ValueError( + f"Unknown x0 sampler '{name}'. 
Available: {list(_X0_SAMPLER_CLASSES)}" + ) + # merged_kwargs = {**kvs, **kwargs} + return cls(tokenizer=tokenizer, **kvs, **kwargs) + + +@dataclass +class EditFlowCollator: + tokenizer: transformers.PreTrainedTokenizer = None + x0_sampler: Callable | str | None = X0Sampler # can be func OR name + + def __post_init__(self): + if isinstance(self.x0_sampler, str): + self.x0_sampler = make_x0_sampler(self.x0_sampler, self.tokenizer) + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, list[Any]]: + if not features: + return {} + + keys = features[0].keys() + batch = {k: [ex[k] for ex in features] for k in keys} + batch["x1_ids"] = batch["input_ids"] + + if "prompt_len" not in batch: + assert self.tokenizer.bos_token_id is not None + bos = self.tokenizer.bos_token_id + batch["x1_ids"] = [ + x if x and x[0] == bos else [bos] + x for x in batch["x1_ids"] + ] + batch["x0_ids"] = [ + x1_ids[:1] + self.x0_sampler(x1_ids=x1_ids[1:]) + for x1_ids in batch["x1_ids"] + ] + else: + batch["x0_ids"] = [ + x1_ids[:prompt_len] + self.x0_sampler(x1_ids=x1_ids[prompt_len:]) + for x1_ids, prompt_len in zip(batch["x1_ids"], batch["prompt_len"]) + ] + + batch["return_loss"] = True + return batch + + +# ------------------------------- Trainer utils -------------------------------- +def pad_1d( + batch_lists: list[list[int]], pad_val: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pads a list of variable-length integer lists into: + - out: tensor of shape [B, Lmax] with padding value `pad_val` + - mask: tensor of shape [B, Lmax] with 1 for real tokens and 0 for padding (int mask) + """ + B = len(batch_lists) + Lmax = max((len(x) for x in batch_lists), default=0) + out = torch.full((B, Lmax), pad_val, dtype=torch.long) + mask = torch.zeros((B, Lmax), dtype=torch.long) # 0/1 mask (int) + + for b, x in enumerate(batch_lists): + if not x: + continue + L = len(x) + out[b, :L] = torch.tensor(x, dtype=torch.long) + mask[b, :L] = 1 # mark valid positions with 1 + + return out, mask + + +def init_editflow_from_src( + ef_model, src_model, lm_head_key: str = "lm_head", verbose: bool = True +): + """ + Initialize an EditFlowModel (ef_model) from a pretrained source model. + + If DeepSpeed ZeRO-3 is enabled (detected via HF's `is_deepspeed_zero3_enabled()`), + this function temporarily gathers full parameters for both models on rank 0, + performs the copy there, and then returns to sharded mode automatically. + Otherwise it behaves like a normal CPU/GPU single-process copy. + + Returns (missing_keys, unexpected_keys) from load_state_dict(strict=False). 
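+
+    Example (a hedged sketch, not a tested recipe; the EditFlow ModernBERT wrapper
+    exposes its LM head as ``self.decoder``, so ``lm_head_key="decoder"`` is assumed here)::
+
+        src = transformers.AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")
+        ef = EditFlowModernBertModel(EditFlowModernBertConfig.from_pretrained("answerdotai/ModernBERT-base"))
+        missing, unexpected = init_editflow_from_src(ef, src, lm_head_key="decoder")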
+ """ + import deepspeed + from transformers.integrations import is_deepspeed_zero3_enabled + + dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized() + rank = torch.distributed.get_rank() if dist_ok else 0 + + def _copy_once(): + src_sd = src_model.state_dict() + tgt_sd = ef_model.state_dict() + new_sd = OrderedDict() + + # 1) copy matching backbone tensors + for k, v in src_sd.items(): + if k in tgt_sd and tgt_sd[k].shape == v.shape: + new_sd[k] = v + + # 2) duplicate lm_head -> sub_logits & ins_logits (weight + optional bias) + lm_w = f"{lm_head_key}.weight" + lm_b = f"{lm_head_key}.bias" + + if lm_w in src_sd: + if "sub_logits.weight" in tgt_sd: + new_sd["sub_logits.weight"] = src_sd[lm_w] + if "ins_logits.weight" in tgt_sd: + new_sd["ins_logits.weight"] = src_sd[lm_w] + if lm_b in src_sd: + if "sub_logits.bias" in tgt_sd: + new_sd["sub_logits.bias"] = src_sd[lm_b] + if "ins_logits.bias" in tgt_sd: + new_sd["ins_logits.bias"] = src_sd[lm_b] + + # 3) non-strict load so new rate heads remain randomly initialized + missing, unexpected = ef_model.load_state_dict(new_sd, strict=False) + return new_sd, missing, unexpected + + if is_deepspeed_zero3_enabled(): + # All ranks enter/exit together; only rank 0 materializes full tensors. + params = list(ef_model.parameters()) + list(src_model.parameters()) + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + if rank == 0: + new_sd, missing, unexpected = _copy_once() + else: + new_sd, missing, unexpected = OrderedDict(), [], [] + + if dist_ok: + torch.distributed.barrier() + + if verbose and rank == 0: + _p = getattr(globals().get("dllm", None), "utils", None) + printer = getattr(_p, "print_main", print) if _p else print + printer( + f"[EditFlow init][ZeRO-3] Copied {len(new_sd)} tensors from Src Model." + ) + if missing: + printer(" Missing (expected for new rate heads, etc.):") + for k in missing: + printer(" -", k) + if unexpected: + printer(" Unexpected (check key names):") + for k in unexpected: + printer(" -", k) + return missing, unexpected + + # --- Non-ZeRO (or DS not present) path --- + new_sd, missing, unexpected = _copy_once() + if verbose: + _p = getattr(globals().get("dllm", None), "utils", None) + printer = getattr(_p, "print_main", print) if _p else print + printer(f"[EditFlow init] Copied {len(new_sd)} tensors from Src Model.") + if missing: + printer(" Missing (expected for new rate heads, etc.):") + for k in missing: + printer(" -", k) + if unexpected: + printer(" Unexpected (check key names):") + for k in unexpected: + printer(" -", k) + return missing, unexpected + + +if __name__ == "__main__": + pass diff --git a/dllm/dllm/pipelines/llada/__init__.py b/dllm/dllm/pipelines/llada/__init__.py new file mode 100644 index 0000000..7fcca10 --- /dev/null +++ b/dllm/dllm/pipelines/llada/__init__.py @@ -0,0 +1,7 @@ +from . 
import generator, trainer +from .models.modeling_llada import LLaDAModelLM +from .models.configuration_llada import LLaDAConfig +from .models.modeling_lladamoe import LLaDAMoEModelLM +from .models.configuration_lladamoe import LLaDAMoEConfig +from .generator import LLaDAGeneratorConfig, LLaDAGenerator +from .trainer import LLaDATrainer diff --git a/dllm/dllm/pipelines/llada/eval.py b/dllm/dllm/pipelines/llada/eval.py new file mode 100644 index 0000000..fdbf2de --- /dev/null +++ b/dllm/dllm/pipelines/llada/eval.py @@ -0,0 +1,357 @@ +""" +accelerate launch \ + --num_processes 2 \ + dllm/pipelines/llada/eval.py \ + --tasks gsm8k \ + --model llada \ + --num_fewshot 8 \ + --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" +""" + +from types import SimpleNamespace +from dataclasses import dataclass + +import accelerate +import torch +import torch.nn.functional as F +from datasets import Dataset +from tqdm import tqdm +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from lm_eval.models.utils import get_dtype + +import dllm +from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig + + +@dataclass +class LLaDAEvalConfig(LLaDAGeneratorConfig): + max_new_tokens: int = 1024 + max_length: int = 4096 + steps: int = 1024 + block_length: int = 1024 + + pretrained: str = "" + dtype: str | torch.dtype = "auto" + batch_size: int = 32 + mc_num: int = 128 + is_check_greedy: bool = True + device: str = "cuda" + + +@register_model("llada") +class LLaDAEvalHarness(LM): + def __init__( + self, + config: LLaDAEvalConfig | None = None, + **kwargs, + ): + super().__init__() + if config is None: + config = LLaDAEvalConfig() + + # Pull args from config, allow kwargs to override + pretrained = kwargs.get("pretrained", config.pretrained) + dtype = kwargs.get("dtype", config.dtype) + batch_size = kwargs.get("batch_size", config.batch_size) + mc_num = kwargs.get("mc_num", config.mc_num) + is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy) + device = kwargs.get("device", config.device) + cfg = kwargs.get("cfg", config.cfg_scale) + steps = kwargs.get("steps", config.steps) + max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens) + block_length = kwargs.get("block_length", config.block_length) + max_length = kwargs.get("max_length", config.max_length) + remasking = kwargs.get("remasking", config.remasking) + + accelerator = accelerate.Accelerator() + + # Get GLOBAL rank from torch.distributed (not accelerator) + if torch.distributed.is_initialized(): + self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15) + self._world_size = ( + torch.distributed.get_world_size() + ) # ← GLOBAL world size (16) + else: + self._rank = 0 + self._world_size = 1 + + # Use accelerator for device placement + self.model = dllm.utils.get_model( + SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype)) + ) + self.model.eval() + + if accelerator.num_processes > 1: + # Let accelerator handle device placement + self.model = accelerator.prepare(self.model) + self.device = ( + accelerator.device + ) # ← Accelerator figures out local device correctly + self.accelerator = accelerator + else: + # Single GPU + self.model = self.model.to(device) + self.device = torch.device(device) + self.accelerator = None + + self.tokenizer = dllm.utils.get_tokenizer( + 
SimpleNamespace(model_name_or_path=pretrained, model=self.model) + ) + + # generation params + self.mask_id = self.tokenizer.mask_token_id + self.batch_size = int(batch_size) + self.max_length = max_length + self.max_new_tokens = int(max_new_tokens) + self.block_length = int(block_length) + self.steps = int(steps) + self.cfg = float(cfg) + self.remasking = remasking + self.is_check_greedy = is_check_greedy + + # loglikelihood params + self.mc_num = int(mc_num) + assert mc_num % self.batch_size == 0 + self.sampling_eps = 0.0 + + def apply_chat_template( + self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + return chat_templated + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def _forward_process( + self, batch: torch.Tensor, prompt_index: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + b, l = batch.shape + + target_len = (l - prompt_index.sum()).item() + k = torch.randint(1, target_len + 1, (), device=batch.device) + + x = torch.round( + torch.linspace( + float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device + ) + ).long() + x = ((x - 1) % target_len) + 1 + assert x.min() >= 1 and x.max() <= target_len + + indices = torch.arange(target_len, device=batch.device).repeat(b, 1) + is_mask = indices < x.unsqueeze(1) + + for i in range(b): + is_mask[i] = is_mask[i][torch.randperm(target_len)] + + is_mask = torch.cat( + ( + torch.zeros( + b, prompt_index.sum(), dtype=torch.bool, device=batch.device + ), + is_mask, + ), + dim=1, + ) + + noisy_batch = torch.where(is_mask, self.mask_id, batch) + + return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l) + + @torch.no_grad() + def get_logits( + self, batch: torch.Tensor, prompt_index: torch.Tensor + ) -> torch.Tensor: + if self.cfg > 0.0: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + logits = self.model(batch).logits + + if self.cfg > 0.0: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (self.cfg + 1) * (logits - un_logits) + return logits[:, : batch.shape[1]] + + @torch.no_grad() + def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float: + seq = torch.concatenate([prefix, target])[None, :] + seq = seq.repeat((self.batch_size, 1)).to(self.device) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + + loss_acc = [] + for _ in range(self.mc_num // self.batch_size): + perturbed_seq, p_mask = self._forward_process(seq, prompt_index) + + mask_indices = perturbed_seq == self.mask_id + + logits = self.get_logits(perturbed_seq, prompt_index) + + loss = ( + F.cross_entropy( + logits[mask_indices], seq[mask_indices], reduction="none" + ) + / p_mask[mask_indices] + ) + loss = loss.sum() / self.batch_size + loss_acc.append(loss.item()) + + return -sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def suffix_greedy_prediction( + self, prefix: torch.Tensor, target: torch.Tensor + ) -> 
bool: + if not self.is_check_greedy: + return False + + seq = torch.full( + (1, len(prefix) + len(target)), self.mask_id, device=self.device + ) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + prefix, target = prefix.to(self.device), target.to(self.device) + seq[0, : len(prefix)] = prefix + + for i in range(len(target)): + mask_index = seq == self.mask_id + logits = self.get_logits(seq, prompt_index)[mask_index] + x0 = torch.argmax(logits, dim=-1) + + p = torch.softmax(logits.to(torch.float32), dim=-1) + confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze( + dim=-1 + ) + _, index = torch.sort(confidence, descending=True) + x0[index[1:]] = self.mask_id + seq[mask_index] = x0.clone() + correct = target == seq[0, len(prefix) :] + correct = torch.all(correct) + return correct + + def _encode_pair( + self, context: str, continuation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer(context + continuation)["input_ids"] + context_enc = self.tokenizer(context)["input_ids"] + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [] + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] + + assert max(prompt_len) <= 4096 + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + + ll = self.get_loglikelihood(prefix, target) + + is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) + + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + torch.cuda.empty_cache() + return out + + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: + raise NotImplementedError + + def generate_until(self, requests: list[Instance]) -> list[str]: + def _tokenize(e): + return { + "question": self.tokenizer(e["question"])["input_ids"], + "question_text": e["question"], + "until": e["until"], + } + + ds = [ + {"question": req.args[0], "until": req.args[1]["until"]} for req in requests + ] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer) + + for elem in tqdm(ds, desc="Generating..."): + prompt = [elem["question"].to(self.device)] + stop_tokens = elem["until"] + generated_ids = generator.generate( + inputs=prompt, + steps=self.steps, + max_new_tokens=self.max_new_tokens, + block_length=self.block_length, + temperature=0.0, + cfg_scale=self.cfg, + remasking=self.remasking, + ) + generated_answer = self.tokenizer.decode( + generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False + ) + for stop_seq in stop_tokens: + if stop_seq in generated_answer: + generated_answer = generated_answer.split(stop_seq)[0] + + # remove special tokens + generated_answer_ids = self.tokenizer(generated_answer)["input_ids"] + 
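+            # Round-trip trick: re-encoding the truncated answer and decoding it again
+            # with skip_special_tokens=True strips any special tokens (EOS, mask, etc.)
+            # that survived the stop-sequence split above.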
generated_answer = self.tokenizer.decode( + generated_answer_ids, skip_special_tokens=True + ) + out.append(generated_answer) + if self.accelerator is not None: + self.accelerator.wait_for_everyone() + + return out + + +if __name__ == "__main__": + cli_evaluate() diff --git a/dllm/dllm/pipelines/llada/generator.py b/dllm/dllm/pipelines/llada/generator.py new file mode 100644 index 0000000..9247af2 --- /dev/null +++ b/dllm/dllm/pipelines/llada/generator.py @@ -0,0 +1,379 @@ +""" +reference: https://github.com/ML-GSAI/LLaDA/blob/main/generate.py +""" + +import math +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F + +from dllm.utils.generation_utils import get_num_transfer_tokens +from dllm.core.generation.generator import ( + GeneratorOutput, + GeneratorConfig, + BaseGenerator, +) + + +def add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor: + """ + The Gumbel max is a method for sampling categorical distributions. + According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. + Thus, we use float64. + """ + if temperature == 0: + return logits + logits = logits.to(torch.float64) + noise = torch.rand_like(logits, dtype=torch.float64) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +@dataclass +class LLaDAGeneratorConfig(GeneratorConfig): + max_new_tokens: int = 128 + max_length: int = ( + None # There's no explicit length_limit except for the tokenizer/model context + ) + block_length: int = 128 + steps: int = 128 + temperature: float = 0.0 + remasking: str = "low_confidence" + stochastic_transfer: bool = False + cfg_scale: float = 0.0 + cfg_keep_tokens: list[int] | None = None + + +@dataclass +class LLaDAGenerator(BaseGenerator): + @torch.no_grad() + def generate( + self, + inputs: list[torch.Tensor | list], + config: LLaDAGeneratorConfig | None = None, + **kwargs, + ) -> GeneratorOutput | torch.Tensor: + if config is None: + config = LLaDAGeneratorConfig() + + # ----- pull args from config, allow kwargs to override ----- + steps = kwargs.get("steps", config.steps) + max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens) + max_length = kwargs.get("max_length", config.max_length) + block_length = kwargs.get("block_length", config.block_length) + temperature = kwargs.get("temperature", config.temperature) + cfg_scale = kwargs.get("cfg_scale", config.cfg_scale) + cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens) + remasking = kwargs.get("remasking", config.remasking) + stochastic_transfer = kwargs.get( + "stochastic_transfer", config.stochastic_transfer + ) + return_dict_in_generate = kwargs.get( + "return_dict_in_generate", config.return_dict_in_generate + ) + + assert 1 <= block_length + assert 1 <= steps + mask_id = self.tokenizer.mask_token_id + eos_id = self.tokenizer.eos_token_id + + # ----- Shape bookkeeping: per-sample prompt lengths and final canvas width ----- + if isinstance(inputs[0], list): + inputs = [ + torch.as_tensor(p, dtype=torch.long, device=self.model.device) + for p in inputs + ] + prompt_lens = [p.shape[0] for p in inputs] + + if max_new_tokens: + max_length = max_new_tokens + max(prompt_lens) + else: + max_new_tokens = max_length - max(prompt_lens) + + B = len(inputs) + T = max_length + + # ----- Initialize canvas with EOS, copy inputs, and append mask tail ----- + x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device) + for i, p in 
enumerate(inputs): + x[i, : prompt_lens[i]] = p # keep original prompt tokens + x[i, prompt_lens[i] : prompt_lens[i] + max_new_tokens] = ( + mask_id # append `max_new_tokens` masks to be generated + ) + attention_mask = (x != eos_id).long() if B > 1 else None + + # Tokens that were *given* at the start (non-mask, non-EOS). + # These will be masked in the unconditional forward pass for CFG. + # Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG + unmasked_index = (x != mask_id) & (x != eos_id) + if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0): + keep_mask = torch.isin( + x, torch.as_tensor(cfg_keep_tokens, device=self.model.device) + ) + unmasked_index = unmasked_index & ~keep_mask + + # ----- Block scheduling over the appended mask tail ----- + num_blocks = math.ceil(max_new_tokens / block_length) + steps = math.ceil(steps / num_blocks) # per-block step budget + histories = [x.clone()] if return_dict_in_generate else None + + for b in range(num_blocks): + # Build a per-sample mask *within this block* (aligned to each prompt's tail) + block_mask_index = torch.zeros( + (B, block_length), dtype=torch.bool, device=x.device + ) + + for j in range(B): + start = prompt_lens[j] + b * block_length + end = min(start + block_length, prompt_lens[j] + max_new_tokens, T) + if start < end: + width = end - start + block_mask_index[j, :width] = ( + x[j, start:end] == mask_id + ) # which positions in this block are still masked + + # Decide how many tokens to reveal per step in this block + num_transfer_tokens = get_num_transfer_tokens( + mask_index=block_mask_index, + steps=steps, + scheduler=self.scheduler, + stochastic=stochastic_transfer, + ) + + # Some steps may be skipped if there are no transfers + effective_steps = num_transfer_tokens.size(1) + + # ----- Iterative reveal inside the current block ----- + for i in range(effective_steps): + mask_index = x == mask_id # current global mask map + + # Optional CFG: second forward where original prompt tokens are masked out + if cfg_scale > 0.0: + un_x = x.clone() + un_x[unmasked_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self.model( + x_, attention_mask=attention_mask + ).logits # Use attention mask here + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self.model( + x, attention_mask=attention_mask + ).logits # Use attention mask here + + # Argmax decoding with optional Gumbel-Max noise for exploration + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax( + logits_with_noise, dim=-1 + ) # [B, T] predicted token ids + + # Per-position confidence used to pick which masks to commit this step + if remasking == "low_confidence": + p = F.softmax(logits, dim=-1) + x0_p = torch.squeeze( + torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1 + ) # [B, T] confidence of predicted token + elif remasking == "random": + x0_p = torch.rand( + (x0.shape[0], x0.shape[1]), device=x0.device + ) # random scores + else: + raise NotImplementedError(remasking) + + # Restrict selection window to the *current block's* tail region + for j in range(B): + x0_p[j, prompt_lens[j] + (b + 1) * block_length :] = -np.inf + + # Only allow updates at currently masked positions; keep others fixed + x0 = torch.where(mask_index, x0, x) + confidence = torch.where( + mask_index, x0_p, -np.inf + ) # consider masked positions only + + # Pick exactly `num_transfer_tokens[j, i]` highest-confidence positions per sample + 
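+                # Note: positions outside the active block and already-revealed positions
+                # were set to -inf above, so topk only picks currently-masked positions in
+                # this block (assuming the per-step budget never exceeds the masks left there).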
transfer_index = torch.zeros_like( + x0, dtype=torch.bool, device=x0.device + ) + for j in range(confidence.shape[0]): + _, select_index = torch.topk( + confidence[j], k=num_transfer_tokens[j, i] + ) + transfer_index[j, select_index] = True + + # Commit chosen predictions into the canvas + x[transfer_index] = x0[transfer_index] + if histories is not None: + histories.append(x.clone()) + + # ----- Output format ----- + if not return_dict_in_generate: + return x + else: + return GeneratorOutput(sequences=x, histories=histories) + + @torch.no_grad() + def infill( + self, inputs: list[torch.Tensor | list], config, **kwargs + ) -> GeneratorOutput | torch.Tensor: + """ + Fill in-place the <|mdm_mask|> tokens contained in `inputs`. + The whole (padded) sequence is split into block windows of length + `block_length`; within each window we progressively "unmask" positions + according to the scheduler and chosen remasking strategy. + + Notes: + - Right padding uses EOS. + - CFG masks out *originally known* (non-mask, non-EOS) tokens in the + unconditional branch, identical to `generate`. + - Only masked positions are ever updated; non-mask tokens are left intact. + """ + # ----- pull args from config, allow kwargs to override ----- + steps = kwargs.get("steps", config.steps) + block_length = kwargs.get("block_length", config.block_length) + temperature = kwargs.get("temperature", config.temperature) + cfg_scale = kwargs.get("cfg_scale", config.cfg_scale) + cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens) + remasking = kwargs.get("remasking", config.remasking) + stochastic_transfer = kwargs.get( + "stochastic_transfer", config.stochastic_transfer + ) + return_dict_in_generate = kwargs.get( + "return_dict_in_generate", config.return_dict_in_generate + ) + + mask_id = self.tokenizer.mask_token_id + eos_id = self.tokenizer.eos_token_id + + # ----- Build canvas: right-pad with EOS to the max length in the batch ----- + if isinstance(inputs[0], list): + inputs = [ + torch.as_tensor(p, dtype=torch.long, device=self.model.device) + for p in inputs + ] + + B = len(inputs) + seq_lens = [t.shape[0] for t in inputs] + T = max(seq_lens) + + # Default to a single block spanning the whole sequence + if block_length is None: + block_length = T + + assert 1 <= block_length + assert 1 <= steps + + x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device) + for i, t in enumerate(inputs): + x[i, : seq_lens[i]] = t + attention_mask = (x != eos_id).long() if B > 1 else None + + # Tokens that were *given* at the start (non-mask, non-EOS). + # These will be masked in the unconditional forward pass for CFG. 
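+        # When cfg_scale > 0, the two branches are combined below as
+        # un_logits + (cfg_scale + 1) * (logits - un_logits); cfg_scale == 0.0
+        # therefore reduces to the plain conditional logits.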
+ # Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG + unmasked_index = (x != mask_id) & (x != eos_id) + if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0): + keep_mask = torch.isin( + x, torch.as_tensor(cfg_keep_tokens, device=self.model.device) + ) + unmasked_index = unmasked_index & ~keep_mask + + # ----- Blockwise schedule over the *entire* (padded) sequence ----- + num_blocks = math.ceil(T / block_length) + steps_per_block = math.ceil(steps / num_blocks) + histories = [x.clone()] if return_dict_in_generate else None + + # Create attention mask where eos_token_id is masked (set to 0) + attention_mask = (x != eos_id).long() + + for b in range(num_blocks): + start = b * block_length + stop = min(start + block_length, T) + + # Per-sample view of which positions in this block are masks + block_mask_index = torch.zeros( + (B, block_length), dtype=torch.bool, device=self.model.device + ) + widths = [] + for j in range(B): + # Width limited by sample's true length and sequence end + width = max(0, min(seq_lens[j], stop) - start) + widths.append(width) + if width > 0: + block_mask_index[j, :width] = x[j, start : start + width] == mask_id + + # Decide how many tokens to reveal at each step in this block + num_transfer_tokens = get_num_transfer_tokens( + mask_index=block_mask_index, + steps=steps_per_block, + scheduler=self.scheduler, + stochastic=stochastic_transfer, + ) + + # Some blocks may have no masks => effective_steps == 0 + effective_steps = num_transfer_tokens.size(1) + + for s in range(effective_steps): + mask_index_full = x == mask_id + + # ----- Forward pass (+ optional CFG) ----- + if cfg_scale > 0.0: + un_x = x.clone() + un_x[unmasked_index] = mask_id + x_ = torch.cat([x, un_x], dim=0) + logits = self.model( + x_, attention_mask=attention_mask + ).logits # Use attention mask here + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (cfg_scale + 1) * (logits - un_logits) + else: + logits = self.model( + x, attention_mask=attention_mask + ).logits # Use attention mask here + + # Greedy with optional Gumbel-Max noise + logits_with_noise = add_gumbel_noise(logits, temperature=temperature) + x0 = torch.argmax(logits_with_noise, dim=-1) # [B, T] + + # Confidence used for choosing which masks to commit this step + if remasking == "low_confidence": + p = F.softmax(logits, dim=-1) + x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze( + -1 + ) # [B, T] + elif remasking == "random": + x0_p = torch.rand((B, T), device=self.model.device) + else: + raise NotImplementedError(remasking) + + # Restrict selection to the *current* block only + for j in range(B): + end_j = start + widths[j] + # Outside current block => impossible to select + x0_p[j, :start] = -np.inf + x0_p[j, end_j:] = -np.inf + + # Only consider currently-masked positions as candidates + x0 = torch.where(mask_index_full, x0, x) + confidence = torch.where(mask_index_full, x0_p, -np.inf) + + # Pick exactly num_transfer_tokens[j, s] positions per sample + transfer_index = torch.zeros_like(x, dtype=torch.bool) + for j in range(B): + k = int(num_transfer_tokens[j, s].item()) + if k > 0: + _, select_idx = torch.topk(confidence[j], k=k) + transfer_index[j, select_idx] = True + + # Commit selected predictions into the canvas + x[transfer_index] = x0[transfer_index] + if histories is not None: + histories.append(x.clone()) + + # ----- Output format ----- + if not return_dict_in_generate: + return x + else: + return GeneratorOutput(sequences=x, histories=histories) diff 
--git a/dllm/dllm/pipelines/llada/models/__init__.py b/dllm/dllm/pipelines/llada/models/__init__.py new file mode 100644 index 0000000..918f25d --- /dev/null +++ b/dllm/dllm/pipelines/llada/models/__init__.py @@ -0,0 +1,19 @@ +from .configuration_llada import LLaDAConfig +from .modeling_llada import LLaDAModelLM +from .configuration_lladamoe import LLaDAMoEConfig +from .modeling_lladamoe import LLaDAMoEModelLM + +# Register with HuggingFace Auto classes for local usage +try: + from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM + + AutoConfig.register("llada", LLaDAConfig) + AutoModel.register(LLaDAConfig, LLaDAModelLM) + AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM) + + AutoConfig.register("lladamoe", LLaDAMoEConfig) + AutoModel.register(LLaDAMoEConfig, LLaDAMoEModelLM) + AutoModelForMaskedLM.register(LLaDAMoEConfig, LLaDAMoEModelLM) +except ImportError: + # transformers not available or Auto classes not imported + pass diff --git a/dllm/dllm/pipelines/llada/models/configuration_llada.py b/dllm/dllm/pipelines/llada/models/configuration_llada.py new file mode 100644 index 0000000..e58b71d --- /dev/null +++ b/dllm/dllm/pipelines/llada/models/configuration_llada.py @@ -0,0 +1,459 @@ +""" +LLaDA configuration +""" +from transformers import PretrainedConfig + +from enum import Enum +from os import PathLike +from typing import Union +from dataclasses import asdict, dataclass, field +from glob import glob +from pathlib import Path +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + + +__all__ = [ + "ActivationType", + "ActivationCheckpointingStrategy", + "BlockType", + "LayerNormType", + "InitFnType", + "ModelConfig", +] + +PathOrStr = Union[str, PathLike] + + +class StrEnum(str, Enum): + """ + This is equivalent to Python's :class:`enum.StrEnum` since version 3.11. + We include this here for compatibility with older version of Python. + """ + + def __str__(self) -> str: + return self.value + + def __repr__(self) -> str: + return f"'{str(self)}'" + + +class LayerNormType(StrEnum): + default = "default" + """ + The default LayerNorm implementation, equivalent to PyTorch's built-in version. + """ + + low_precision = "low_precision" + """ + A low-precision version of the default LayerNorm. + """ + + rms = "rms" + """ + An RMSNorm implementation. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + gemma_rms = "gemma_rms" + """ + An RMSNorm implementation by gemmma. When using ``torch.compile`` this is + probably the fastest implementation. + """ + + amd_compatible = "amd_compatible" + """ + LayerNorm implemented manually to work around an issue with ROCm. + """ + + +class ActivationType(StrEnum): + gelu = "gelu" + relu = "relu" + silu = "silu" + swiglu = "swiglu" + + +class BlockType(StrEnum): + sequential = "sequential" + parallel = "parallel" + + llama = "llama" + """ + A block similar to the sequential block with slightly different + implementations of operations like attention to imitate the behavior of Llama. + """ + + +class InitFnType(StrEnum): + mitchell = "mitchell" + """ + The strategy suggested to us by Mitchell Wortsman from UW. + This uses a truncated normal distribution with an adaptive standard deviation that depends + on the size of the weights as well as the depth of the layer. + """ + + normal = "normal" + """ + All weights are initialized from the same normal distribution. 
+ """ + + kaiming_normal = "kaiming_normal" + """ + All weights are initialized with the Kaiming method from a normal distribution. + Note this currently won't work with FSDP. + """ + + fan_in = "fan_in" + """ + "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in`` + is the input dimensionality of the kernel. + """ + + full_megatron = "full_megatron" + """ + This is what metaseq calls "full megatron init". It is the init used for Llama 2. + """ + + +@dataclass +class ModelConfig(): + """ + LLaDA (model) configuration. + """ + + # Note that the defaults for these attributes are equivalent to the base GPT2 model. + + d_model: int = 768 + """ + The hidden size of the model. + """ + + n_heads: int = 12 + """ + The number of self-attention heads. + """ + + n_kv_heads: Optional[int] = None + """ + The number of heads to use for keys and values. Defaults to `n_heads`. + Set this to ``None`` or ``n_heads`` for normal multi-head attention. + Set this to 1 for multi-query attention. + Set it to some in-between value for Llama2-style grouped query attention. + """ + + n_layers: int = 12 + """ + The number of layers/blocks. + """ + + mlp_ratio: int = 4 + """ + The ratio of the inner MLP dimensionality to ``d_model``. + This is only used when ``mlp_hidden_size`` is not set. + """ + + mlp_hidden_size: Optional[int] = None + """ + Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`. + """ + + activation_type: ActivationType = ActivationType.swiglu + """ + The activation function to use within the MLP layers. + """ + + block_type: BlockType = BlockType.sequential + """ + The transformer block implementation. + """ + + block_group_size: int = 1 + """ + The number of blocks to group together into a single parent block. + This has no affect on the number of parameters in the model and is only used to wrap groups + of blocks together with a single FSDP wrapper during training. + """ + + alibi: bool = False + """ + If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``. + """ + + alibi_bias_max: float = 8.0 + """ + Maximum absolute value of ALiBi bias. + """ + + rope: bool = False + """ + Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. + """ + + rope_full_precision: bool = True + """ + If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, + apply RoPE at the precision of the input. + """ + + flash_attention: bool = False + """ + If ``True``, use ``FlashAttention``. + """ + + attention_dropout: float = 0.1 + """ + The dropout probability within the attention modules. + """ + + multi_query_attention: Optional[bool] = None + """ + Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters + and is more efficient during inference. + """ + + attention_layer_norm: bool = False + """ + Apply layer norm to the keys and queries within the attention mechanism. + This can help stabilize training. + """ + + residual_dropout: float = 0.1 + """ + The dropout probability for the MLP and attention output within each block. + """ + + embedding_dropout: float = 0.1 + """ + The dropout probability for embeddings. + """ + + input_emb_norm: bool = False + """ + An input hidden_states norm implementation by gemmma. + """ + + layer_norm_type: LayerNormType = LayerNormType.default + """ + The layernorm implementation to use. 
+ """ + + layer_norm_with_affine: bool = True + """ + Whether to include bias and weight parameters for the layer norms. + This only affects layer norms that are immediately followed by a linear layer in the forward pass, + so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine` + to ``False``. + """ + + rms_norm_eps: float = 1e-05 + """ + The rms layernorm eps param. + """ + + attention_layer_norm_with_affine: bool = True + """ + Toggle affine transform for the QK norms. + """ + + max_sequence_length: int = 1024 + """ + The maximum input sequence length supported by the model. + """ + + rope_theta: float = 10000.0 + """ + The rope base param. + """ + + include_qkv_bias: Optional[bool] = False + """ + Whether or not to include bias parameters in qkv linear layers. + """ + + include_bias: bool = False + """ + Whether or not to include bias parameters in linear layers. + In PaLM, they got rid of all bias terms because they found that large + models tend to have near 0 bias terms anyway. + """ + + bias_for_layer_norm: Optional[bool] = None + """ + Whether or not to include bias parameters in layer norm. + This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in + layer norm. + When this is None (the default), it inherits the setting from include_bias. + """ + + scale_logits: bool = False + """ + If ``True``, scale the output logits by ``1 / sqrt(d_model)``. + """ + + vocab_size: int = 50257 + """ + Vocabulary size of the model. + """ + + embedding_size: Optional[int] = 50304 + """ + The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default + to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the + next multiple of 128 that's greater than ``vocab_size`` can improve throughput + substantially. + """ + + weight_tying: bool = True + """ + Whether to tie output linear weights to the input embedding. + """ + + eos_token_id: int = 50256 + """ + The ID of the end-of-sentence special token. + """ + + pad_token_id: int = 50256 + """ + The ID of the token to use for padding. Defaults to the ID of the EOS token. + """ + + mask_token_id: Optional[int] = 50256 + """ + The ID of the token to use for mask token. Defaults to the ID of the EOS token. + """ + + init_device: Optional[str] = None + """ + The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta". + """ + + init_fn: InitFnType = InitFnType.normal + """ + The weight initialization strategy. + """ + + init_std: float = 0.02 + """ + The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". + """ + + init_cutoff_factor: Optional[float] = None + """ + A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such + as "normal". Setting this to None means values are not cutoff. + """ + + precision: Optional[str] = None + """ + Precision used to train/evaluate with. You shouldn't set this directly. + See :data:`TrainConfig.precision` instead. 
+ """ + + @property + def effective_n_kv_heads(self) -> int: + if self.n_kv_heads is None: + if self.multi_query_attention is True: + return 1 + else: + return self.n_heads + else: + if self.multi_query_attention is None: + return self.n_kv_heads + if self.multi_query_attention: + n_kv_heads_should_be = 1 + else: + n_kv_heads_should_be = self.n_heads + if self.n_kv_heads == n_kv_heads_should_be: + return n_kv_heads_should_be + else: + raise Exception( + "You can't set `multi_query_attention` and `n_kv_heads` at the same time." + ) + +class ActivationCheckpointingStrategy(StrEnum): + whole_layer = "whole_layer" + """ + Checkpoint every transformer layer. + """ + + one_in_two = "one_in_two" + """ + Checkpoint one in two transformer layers. + """ + + one_in_three = "one_in_three" + """ + Checkpoint one in three transformer layers. + """ + + one_in_four = "one_in_four" + """ + Checkpoint one in four transformer layers. + """ + + two_in_three = "two_in_three" + """ + Checkpoint two out of every three transformer layers. + """ + + three_in_four = "three_in_four" + """ + Checkpoint three out of four of every transformer layers. + """ + + four_in_five = "four_in_five" + """ + Checkpoint four out of five of every transformer layers. + """ + + nine_in_ten = "nine_in_ten" + """ + Checkpoint nine out of ten of every transformer layers. + """ + + fine_grained = "fine_grained" + """ + Focus checkpointing on where it is cheap to recompute and saves most memory. + """ + + +class LLaDAConfig(PretrainedConfig): + model_type = "llada" + keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm + + def __init__(self, use_cache: bool = False, **kwargs): + model_config = ModelConfig() + all_kwargs = model_config.__dict__ + all_kwargs.update(kwargs) + all_kwargs.update({"use_cache": use_cache}) + all_kwargs.update( + { + "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"]) + } + ) + super().__init__(**all_kwargs) + + @property + def num_attention_heads(self): + return self.n_heads + + @property + def num_hidden_layers(self): + return self.n_layers + + @property + def hidden_size(self): + return self.d_model diff --git a/dllm/dllm/pipelines/llada/models/configuration_lladamoe.py b/dllm/dllm/pipelines/llada/models/configuration_lladamoe.py new file mode 100644 index 0000000..a12c97f --- /dev/null +++ b/dllm/dllm/pipelines/llada/models/configuration_lladamoe.py @@ -0,0 +1,96 @@ +""" +LLaDA MoE configuration +""" +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class LLaDAMoEConfig(PretrainedConfig): + model_type = "lladamoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=-1, + hidden_size=-1, + dense_intermediate_size=-1, + expert_intermediate_size=-1, + shared_expert_intermediate_size=-1, + num_hidden_layers=-1, + num_attention_heads=-1, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=False, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=-1, + partial_rotary_factor=-1, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + num_experts_per_tok=-1, + num_experts=-1, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=None, + qk_layernorm=None, + moe_layer_freq=[], + moe_router_enable_expert_bias=None, + moe_router_score_function=None, + routed_scaling_factor=1, 
+ router_num_group=-2, + router_topk_group=-2, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expert_intermediate_size = expert_intermediate_size + self.dense_intermediate_size = dense_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + self.qk_layernorm = qk_layernorm + self.moe_layer_freq = moe_layer_freq + self.moe_router_enable_expert_bias = moe_router_enable_expert_bias + self.moe_router_score_function = moe_router_score_function + self.partial_rotary_factor = partial_rotary_factor + self.routed_scaling_factor = routed_scaling_factor + self.router_num_group = router_num_group + self.router_topk_group = router_topk_group + + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/dllm/dllm/pipelines/llada/models/modeling_llada.py b/dllm/dllm/pipelines/llada/models/modeling_llada.py new file mode 100644 index 0000000..1c971d2 --- /dev/null +++ b/dllm/dllm/pipelines/llada/models/modeling_llada.py @@ -0,0 +1,1458 @@ +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import ( + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + cast, +) +from dataclasses import fields +from typing import List, Optional, Tuple, Union + +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.cache_utils import Cache + +from .configuration_llada import ( + LLaDAConfig, + StrEnum, + InitFnType, + ActivationType, + BlockType, + LayerNormType, + ModelConfig, + ActivationCheckpointingStrategy, +) + +if sys.version_info.minor > 8: + from collections.abc import MutableMapping +elif sys.version_info.minor == 8: + from typing import MutableMapping +else: + raise SystemExit("This script supports Python 3.8 or higher") + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "GemmaRMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "LLaDABlock", + "LLaDASequentialBlock", + "LLaDAModel", + "LLaDAOutput", + "LLaDAGenerateOutput", +] + + +log = logging.getLogger(__name__) + + +class ModuleType(StrEnum): + in_module = "in" + out_module = "out" 
+ emb = "emb" + final_out = "final_out" + + +def init_weights( + config: ModelConfig, + module: Union[nn.Linear, nn.Embedding], + d: Optional[int] = None, + layer_id: Optional[int] = None, + std_factor: float = 1.0, + type_of_module: Optional[ModuleType] = None, +) -> None: + """ + Initialize weights of a linear or embedding module. + + :param config: The model config. + :param module: The linear or embedding submodule to initialize. + :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions + for fused layers. + :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by + ``1 / sqrt(2 * (layer_id + 1))``. + """ + d = d if d is not None else config.d_model + if config.init_fn == InitFnType.normal: + std = config.init_std * std_factor + if config.init_cutoff_factor is not None: + cutoff_value = config.init_cutoff_factor * std + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value) + else: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif config.init_fn == InitFnType.mitchell: + std = std_factor / math.sqrt(d) + if layer_id is not None: + std = std / math.sqrt(2 * (layer_id + 1)) + nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std) + elif config.init_fn == InitFnType.kaiming_normal: + nn.init.kaiming_normal_(module.weight, nonlinearity="relu") + elif config.init_fn == InitFnType.fan_in: + std = std_factor / math.sqrt(d) + nn.init.normal_(module.weight, mean=0.0, std=std) + elif config.init_fn == InitFnType.full_megatron: + if type_of_module is None: + raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.") + + cutoff_factor = config.init_cutoff_factor + if cutoff_factor is None: + cutoff_factor = 3 + + if type_of_module == ModuleType.in_module: + # for att_proj (same as QKV), ff_proj + std = config.init_std + elif type_of_module == ModuleType.out_module: + # for attn_out, ff_out + std = config.init_std / math.sqrt(2.0 * config.n_layers) + elif type_of_module == ModuleType.emb: + # positional embeddings (wpe) + # token embeddings (wte) + std = config.init_std + elif type_of_module == ModuleType.final_out: + # final output (ff_out) + std = config.d_model**-0.5 + else: + raise RuntimeError(f"Unknown module type '{type_of_module}'") + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=std, + a=-cutoff_factor * std, + b=cutoff_factor * std, + ) + else: + raise NotImplementedError(config.init_fn) + + if isinstance(module, nn.Linear): + if module.bias is not None: + nn.init.zeros_(module.bias) + + if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False): + with torch.no_grad(): + module.weight.div_(math.sqrt(2 * config.n_layers)) + + +def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False): + """ + Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` + is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. 
+ """ + if check_neg_inf: + x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) + if check_pos_inf: + x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) + + +def activation_checkpoint_function(cfg: ModelConfig): + preserve_rng_state = ( + (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0) + ) + from torch.utils.checkpoint import checkpoint + + return partial( + checkpoint, + preserve_rng_state=preserve_rng_state, + use_reentrant=False, + ) + + +class BufferCache(dict, MutableMapping[str, torch.Tensor]): + """ + Cache for attention biases and other things that would normally be stored as buffers. + We avoid using buffers because we've run into various issues doing so with FSDP. + In general it appears the way FSDP handles buffers is not well-defined. + It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid + since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into + NaNs when they're synchronized due to casting or some other issue. + """ + + +def _non_meta_init_device(config: ModelConfig) -> torch.device: + if config.init_device is not None and config.init_device != "meta": + return torch.device(config.init_device) + else: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class Dropout(nn.Dropout): + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.p == 0.0: + return input + else: + return F.dropout(input, self.p, self.training, self.inplace) + + +class LayerNormBase(nn.Module): + def __init__( + self, + config: ModelConfig, + *, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + eps: float = 1e-05, + ): + super().__init__() + self.config = config + self.eps = eps + self.normalized_shape = (size or config.d_model,) + if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine): + self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device)) + use_bias = self.config.bias_for_layer_norm + if use_bias is None: + use_bias = self.config.include_bias + if use_bias: + self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device)) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("bias", None) + self.register_parameter("weight", None) + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase: + if config.layer_norm_type == LayerNormType.default: + return LayerNorm(config, size=size, low_precision=False, **kwargs) + elif config.layer_norm_type == LayerNormType.low_precision: + return LayerNorm(config, size=size, low_precision=True, **kwargs) + elif config.layer_norm_type == LayerNormType.rms: + return RMSLayerNorm(config, size=size, **kwargs) + elif config.layer_norm_type == LayerNormType.gemma_rms: + return GemmaRMSLayerNorm(config, size=size, **kwargs) + else: + raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'") + + def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. 
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype()) + elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype()) + else: + return tensor + + def reset_parameters(self): + if self.weight is not None: + torch.nn.init.ones_(self.weight) # type: ignore + if self.bias is not None: + torch.nn.init.zeros_(self.bias) # type: ignore + + +class LayerNorm(LayerNormBase): + """ + The default :class:`LayerNorm` implementation which can optionally run in low precision. + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + low_precision: bool = False, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-05, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps) + self.low_precision = low_precision + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.low_precision: + module_device = x.device + downcast_x = self._cast_if_autocast_enabled(x) + downcast_weight = ( + self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + ) + downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return F.layer_norm( + downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps + ) + else: + return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) + + +class RMSLayerNorm(LayerNormBase): + """ + RMS layer norm, a simplified :class:`LayerNorm` implementation + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-5, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + og_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(og_dtype) + + if self.weight is not None: + if self.bias is not None: + return self.weight * x + self.bias + else: + return self.weight * x + else: + return x + + +class GemmaRMSLayerNorm(LayerNormBase): + """ + Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation + """ + + def __init__( + self, + config: ModelConfig, + size: Optional[int] = None, + elementwise_affine: Optional[bool] = None, + eps: float = 1e-5, + ): + super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.autocast(enabled=False, device_type=x.device.type): + og_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + x = x.to(og_dtype) + + if self.weight is not None: + if self.bias is not None: + return x * (1 + self.weight) + self.bias + else: + return x * (1 + self.weight) + else: + return x + + +class RotaryEmbedding(nn.Module): + """ + [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864). + """ + + def __init__(self, config: ModelConfig, cache: BufferCache): + super().__init__() + self.config = config + self.__cache = cache + # Warm up cache. 
+ self.rope_theta = config.rope_theta + self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config)) + + def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + if ( + (pos_sin := self.__cache.get("rope_pos_sin")) is not None + and (pos_cos := self.__cache.get("rope_pos_cos")) is not None + and pos_sin.shape[-2] >= seq_len + and pos_cos.shape[-2] >= seq_len + ): + if pos_sin.device != device: + pos_sin = pos_sin.to(device) + self.__cache["rope_pos_sin"] = pos_sin + if pos_cos.device != device: + pos_cos = pos_cos.to(device) + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :] + + with torch.autocast(device.type, enabled=False): + dim = self.config.d_model // self.config.n_heads + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) + seq = torch.arange(seq_len, device=device, dtype=torch.float) + freqs = einsum("i , j -> i j", seq, inv_freq) + positions = torch.cat((freqs, freqs), dim=-1) + pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :] + self.__cache["rope_pos_sin"] = pos_sin + self.__cache["rope_pos_cos"] = pos_cos + return pos_sin, pos_cos + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, 2, hs // 2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.config.rope_full_precision: + q_, k_ = q.float(), k.float() + else: + q_, k_ = q, k + + with torch.autocast(q.device.type, enabled=False): + query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None + pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) + pos_sin = pos_sin.type_as(q_) + pos_cos = pos_cos.type_as(q_) + q_ = self.apply_rotary_pos_emb( + pos_sin[:, :, key_len - query_len : key_len, :], + pos_cos[:, :, key_len - query_len : key_len, :], + q_, + ) + k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_) + return q_.type_as(q), k_.type_as(k) + + +class Activation(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @property + @abstractmethod + def output_multiplier(self) -> float: + raise NotImplementedError + + @classmethod + def build(cls, config: ModelConfig) -> Activation: + if config.activation_type == ActivationType.gelu: + return cast(Activation, GELU(approximate="none")) + elif config.activation_type == ActivationType.relu: + return cast(Activation, ReLU(inplace=False)) + elif config.activation_type == ActivationType.silu: + return cast(Activation, SiLU(inplace=False)) + elif config.activation_type == ActivationType.swiglu: + return SwiGLU(config) + else: + raise NotImplementedError(f"Unknown activation: '{config.activation_type}'") + + +class GELU(nn.GELU): + @property + def output_multiplier(self) -> float: + return 1.0 + + +class ReLU(nn.ReLU): + @property + def output_multiplier(self) -> float: + return 1.0 + +class SiLU(nn.SiLU): + @property + def output_multiplier(self) -> float: + return 1.0 + +class SwiGLU(Activation): + def forward(self, x: 
torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + @property + def output_multiplier(self) -> float: + return 0.5 + + +def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: + att_bias = torch.triu( + torch.ones(seq_len, seq_len, device=device, dtype=torch.float), + diagonal=1, + ) + att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) + return att_bias.view(1, 1, seq_len, seq_len) # type: ignore + + +def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor: + if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len: + if causal_bias.device != device: + causal_bias = causal_bias.to(device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + with torch.autocast(device.type, enabled=False): + causal_bias = causal_attention_bias(seq_len, device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + + +def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor: + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len) + + # shape: (1, 1, seq_len, seq_len) + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1) + alibi_bias.abs_().mul_(-1) + + # shape: (n_heads,) + m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device) + m.mul_(config.alibi_bias_max / config.n_heads) + + # shape: (1, n_heads, seq_len, seq_len) + return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore + + +class LLaDABlock(nn.Module): + """ + A base class for transformer block implementations. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. + self.k_norm: Optional[LayerNormBase] = None + self.q_norm: Optional[LayerNormBase] = None + if config.attention_layer_norm: + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. 
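# Illustrative sketch, not from the patch: with the SwiGLU activation above, output_multiplier is
# 0.5, so ff_proj emits `hidden_size` features, chunk() splits them into a value half and a gate
# half, and ff_out (defined just above) therefore takes int(0.5 * hidden_size) input features.
import torch
import torch.nn.functional as F
hidden = torch.randn(1, 4, 8192)       # pretend mlp_hidden_size == 8192
x, gate = hidden.chunk(2, dim=-1)      # two 4096-dim halves
ff_out_input = F.silu(gate) * x        # shape (1, 4, 4096) == 0.5 * hidden_size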
+ if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func # type: ignore + + self.flash_attn_func = flash_attn_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + init_weights( + self.config, + self.attn_out, + d=self.config.d_model, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + init_weights( + self.config, + self.ff_out, + d=self.ff_out.in_features, + layer_id=self.layer_id, + type_of_module=ModuleType.out_module, + ) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + if strategy == ActivationCheckpointingStrategy.fine_grained: + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + else: + self._activation_checkpoint_fn = None + + @classmethod + def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: + target_dtype = input_dtype + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. + if bias.device.type == "cuda" and torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + target_dtype = torch.get_autocast_cpu_dtype() + if bias.dtype != target_dtype: + bias = bias.to(target_dtype) + ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) + return bias + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + """ + if self.flash_attn_func is not None and attn_mask is None: + r = self.flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False + ) + return r.transpose(1, 2) + else: + # torch's sdpa doesn't support GQA, so we're doing this + assert k.size(1) == v.size(1) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + + # Modify: MDM set causal to False, and with no attn_mask. + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. 
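# Illustrative sketch, not from the patch: the GQA branch above replicates each KV head so SDPA
# sees matching head counts, e.g. 32 query heads with 8 KV heads repeats every KV head 4 times.
import torch
q = torch.randn(1, 32, 16, 64)   # (B, n_heads, T, head_dim)
k = torch.randn(1, 8, 16, 64)    # (B, n_kv_heads, T, head_dim)
k_rep = k.repeat_interleave(32 // 8, dim=1, output_size=32)
assert k_rep.shape[1] == q.shape[1]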
+ if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. + # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + present = (k, v) if use_cache else None + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + + if self.config.rope: + # Apply rotary embeddings. + q, k = self.rotary_emb(q, k) + + if attention_bias is not None: + # Resize and cast attention bias. + # The current dtype of the attention bias might not match the dtype that the SDP attn function will + # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding + # as down-casting the attention bias to the autocast precision will result in -infs, which will + # cause the SDP attn function to produce NaNs. + attention_bias = self._cast_attn_bias( + attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype + ) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = self._scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=False, + ) + + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. + return self.attn_out(att), present + + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + raise NotImplementedError + + @classmethod + def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock: + if config.block_type == BlockType.sequential: + return LLaDASequentialBlock(layer_id, config, cache) + elif config.block_type == BlockType.llama: + return LLaDALlamaBlock(layer_id, config, cache) + else: + raise NotImplementedError(f"Unknown block type: '{config.block_type}'") + + +class LLaDASequentialBlock(LLaDABlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + # Attention input projection. Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + # Feed-forward input projection. 
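# Illustrative sketch, not from the patch: att_proj above fuses the Q/K/V projections into one
# linear layer whose output is later split with `.split(self.fused_dims, dim=-1)`. For example,
# d_model=4096 with 32 query heads and 8 KV heads gives fused dims of (4096, 1024, 1024).
d_model, n_heads, n_kv_heads = 4096, 32, 8
head_dim = d_model // n_heads                                         # 128
fused_dims = (d_model, n_kv_heads * head_dim, n_kv_heads * head_dim)  # (4096, 1024, 1024)
assert sum(fused_dims) == 6144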
+ self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights( + self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + init_weights( + self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module + ) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + if self._activation_checkpoint_fn is not None: + q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split( + self.fused_dims, dim=-1 + ) + else: + q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class LLaDALlamaBlock(LLaDABlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `LLaDASequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + self.__cache = cache + + # Attention input projection. 
Projects x -> (q, k, v) + head_dim = config.d_model // config.n_heads + q_proj_out_dim = config.d_model + k_proj_out_dim = config.effective_n_kv_heads * head_dim + v_proj_out_dim = config.effective_n_kv_heads * head_dim + self.q_proj = nn.Linear( + config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + self.k_proj = nn.Linear( + config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + self.v_proj = nn.Linear( + config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device + ) + + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + # new add + self.up_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None) # new add + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x, x_up = self.ff_proj(x), self.up_proj(x) # new add + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = x * x_up # new add + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class LLaDAOutput(NamedTuple): + logits: torch.FloatTensor + """ + A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities + for the next token *before* normalization via (log) softmax. 
+ """ + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] + """ + Attention keys and values from each block. + """ + + hidden_states: Optional[Tuple[torch.Tensor]] + """ + Hidden states from each block. + """ + + +class LLaDAGenerateOutput(NamedTuple): + token_ids: torch.LongTensor + """ + The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`. + These do *not* include the original input IDs. + """ + + scores: torch.FloatTensor + """ + The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`. + """ + + +class LLaDABlockGroup(nn.ModuleList): + def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None): + super().__init__(modules) + self.config = config + self.layer_offset = layer_offset + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.FloatTensor] = None, + layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + for block_idx, block in enumerate(self): + layer_past = None if layers_past is None else layers_past[block_idx] + block_idx += self.layer_offset + if ( + (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two + and block_idx % 2 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three + and block_idx % 3 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four + and block_idx % 4 == 0 + ) + ): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( # type: ignore + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + return x, attn_key_values + + def reset_parameters(self): + for block in self: + block.reset_parameters() + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + for block in self: + block.set_activation_checkpointing(strategy) + + +class LLaDAModel(nn.Module): + def __init__(self, config: ModelConfig, init_params: bool = True): + super().__init__() + self.config = config + self.__cache = BufferCache() + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise Exception("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise Exception("ALiBi and RoPE are mutually exclusive") + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + raise Exception("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! 
This could hurt throughput performance.", UserWarning + ) + + self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) + + if not ( + 0 < self.config.block_group_size <= self.config.n_layers + and self.config.n_layers % self.config.block_group_size == 0 + ): + raise Exception("n layers must be divisible by block group size") + + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding( + config.embedding_size or config.vocab_size, config.d_model, device=config.init_device + ), + emb_drop=Dropout(config.embedding_dropout), + ln_f=LayerNorm.build(config), + ) + ) + + blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)] + if self.config.block_group_size > 1: + block_groups = [ + LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size]) + for i in range(0, config.n_layers, config.block_group_size) + ] + self.transformer.update({"block_groups": nn.ModuleList(block_groups)}) + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not (self.config.alibi or self.config.rope): + self.transformer.update( + {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} + ) + if not config.weight_tying: + self.transformer.update( + { + "ff_out": nn.Linear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + device=config.init_device, + ) + } + ) + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. + if init_params and self.config.init_device != "meta": + self.reset_parameters() + self.__num_fwd_flops: Optional[int] = None + + # Warm up cache. + if self.config.alibi: + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) + self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + + def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]): + self.activation_checkpointing_strategy = strategy + if self.config.block_group_size != 1: + for block_group in self.transformer.block_groups: + block_group.set_activation_checkpointing(strategy) + else: + for block in self.transformer.blocks: + block.set_activation_checkpointing(strategy) + + @property + def device(self) -> torch.device: + device: torch.device = self.transformer.wte.weight.device # type: ignore + if device.type == "meta": + return _non_meta_init_device(self.config) + else: + return device + + def reset_parameters(self): + log.info("Initializing model parameters...") + # Top-level embeddings / linear layers. + init_weights( + self.config, + self.transformer.wte, # type: ignore + std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0, + type_of_module=ModuleType.emb, + ) + if hasattr(self.transformer, "wpe"): + init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore + + # Top-level layer norm. + self.transformer.ln_f.reset_parameters() # type: ignore + + # Output weights. + if hasattr(self.transformer, "ff_out"): + init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore + + # Let the blocks handle themselves. 
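# Illustrative sketch, not from the patch: block_group_size only changes how blocks are wrapped
# for FSDP. n_layers must be divisible by it, and each LLaDABlockGroup holds one contiguous slice
# of blocks; e.g. 32 layers with block_group_size=4 become 8 groups of 4 blocks each.
n_layers, block_group_size = 32, 4
offsets = list(range(0, n_layers, block_group_size))  # [0, 4, 8, ..., 28]
assert n_layers % block_group_size == 0 and len(offsets) == 8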
+ if self.config.block_group_size == 1:
+ for block in self.transformer.blocks:
+ block.reset_parameters()
+ else:
+ for block_group in self.transformer.block_groups:
+ block_group.reset_parameters()
+
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
+ -1
+ ] >= seq_len:
+ if alibi_bias.device != device:
+ alibi_bias = alibi_bias.to(device)
+ self.__cache["alibi_attention_bias"] = alibi_bias
+ return alibi_bias
+ with torch.autocast(device.type, enabled=False):
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
+ self.__cache["alibi_attention_bias"] = alibi_bias
+ return alibi_bias
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ input_embeddings: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_bias: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
+ use_cache: bool = False,
+ last_logits_only: bool = False,
+ output_hidden_states: Optional[bool] = None,
+ ) -> LLaDAOutput:
+ """
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+ embeddings. When provided, it is treated as the output of the input embedding layer.
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
+ which input IDs are masked. A `1` value in the mask means that
+ the corresponding input ID should *not* be ignored. A `0` means
+ that the corresponding input ID is masked.
+
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
+ library.
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
+ to introduce causal or other biases.
+
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
+ element in the sequence.
+
+ If the tensor is a float tensor, it will just be added to the attention
+ scores before the softmax.
+
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
+ :param past_key_values: Pre-computed keys and values for each attention block.
+ Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ :param use_cache: If `True`, return key and value tensors for each block.
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
+ This can speed up decoding when you only care about the next token.
+ """
+ # Basic MDM model config checks.
+ assert not self.config.alibi, "ALiBi length extrapolation is not supported for MDM."
+ assert self.config.rope, "RoPE must be used in Llama-Encoder for MDM."
+ assert (past_key_values is None and not use_cache), "The KV cache is not supported for MDM."
+
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+ if past_key_values:
+ assert len(past_key_values) == self.config.n_layers
+
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+ if past_key_values is None:
+ past_length = 0
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ # Get embeddings of input.
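# Illustrative sketch, not from the patch: the HF-style 0/1 attention_mask documented above is
# turned, further down in this forward pass, into an additive float bias: 0 where tokens are kept
# and a very large negative value where they are padding.
import torch
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
bias = (1.0 - attention_mask)[:, None, None, :] * torch.finfo(torch.float).min
# bias[..., :3] == 0; bias[..., 3] is ~-3.4e38, so padded keys receive ~zero attention weight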
+ # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + if self.config.input_emb_norm: + x = x * (self.config.d_model**0.5) + + if not (self.config.alibi or self.config.rope): + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Add input + positional embeddings and apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # Transform the attention mask into what the blocks expect. + if attention_mask is not None and 0.0 in attention_mask: + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + else: + attention_mask = None + + if attention_mask is not None: + attention_bias = attention_mask.to(dtype=torch.float) + else: + attention_bias = None + + attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + + # decoder layers + all_hidden_states = [] + + # Apply blocks one-by-one. + if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values[block_idx] + if ( + (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two + and block_idx % 2 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three + and block_idx % 3 == 0 + ) + or ( + self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four + and block_idx % 4 == 0 + ) + ): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( + block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, cache = block_group( + x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.extend(cache) + + if last_logits_only: + # shape: (batch_size, 1, d_model) + x = x[:, -1, :].unsqueeze(1) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. 
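# Illustrative sketch, not from the patch: when weight_tying is enabled, the branch just below
# reuses the token-embedding matrix as the output projection via F.linear.
import torch
import torch.nn.functional as F
x = torch.randn(2, 16, 768)             # (batch, seq, d_model) after the final layer norm
wte_weight = torch.randn(50304, 768)    # (embedding_size, d_model)
logits = F.linear(x, wte_weight, None)  # -> (2, 16, 50304)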
+ # shape: (batch_size, seq_len or 1, vocab_size)
+ if self.config.weight_tying:
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
+ else:
+ logits = self.transformer.ff_out(x) # type: ignore
+ if self.config.scale_logits:
+ logits.mul_(1 / math.sqrt(self.config.d_model))
+
+ return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
+
+
+def create_model_config_from_pretrained_config(config: LLaDAConfig):
+ """
+ Utility function
+ """
+
+ kwargs = {}
+ for field in fields(ModelConfig):
+ kwargs[field.name] = getattr(config, field.name)
+
+ model_config = ModelConfig(**kwargs)
+ return model_config
+
+
+class LLaDAModelLM(PreTrainedModel):
+ """
+ Extremely barebones HF model wrapper.
+ """
+
+ config_class = LLaDAConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["LLaDALlamaBlock"]
+
+ def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
+ super().__init__(config)
+
+ if not model:
+ model_config = create_model_config_from_pretrained_config(config)
+ # Initialize the model directly on CUDA, overriding any `init_device` from the config.
+ model_config.init_device = "cuda"
+ self.model = LLaDAModel(model_config, init_params=init_params)
+ else:
+ self.model = model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_bias: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[Cache] = None, # This is a hack mitigation of an issue in transformers `4.39.x`
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ if use_cache is None:
+ use_cache = self.config.use_cache
+
+ if output_attentions:
+ raise ValueError("output_attentions is not yet supported in LLaDA")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.forward(
+ input_ids=input_ids,
+ input_embeddings=inputs_embeds,
+ attention_mask=attention_mask,
+ attention_bias=attention_bias,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ )
+
+ logits = outputs.logits
+ hidden_states = outputs.hidden_states
+
+ loss = None
+ if labels is not None:
+ import warnings
+ warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ logits=logits,
+ past_key_values=outputs.attn_key_values,
+ hidden_states=hidden_states,
+ )
+
+ def can_generate(self) -> bool:
+ return True
+
+ def prepare_inputs_for_generation(
+ self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+ ):
+ if past_key_values:
+ # This is because we want the model to only process the last generated token.
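# Illustrative sketch, not from the patch; the import paths and the tiny hyperparameters are
# assumptions. The wrapper above can be built straight from an LLaDAConfig. Its __init__ forces
# init_device="cuda", so this needs a CUDA device; forward() also requires rope=True and no KV cache.
import torch
from dllm.pipelines.llada.models.configuration_llada import LLaDAConfig
from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM

cfg = LLaDAConfig(d_model=256, n_heads=4, n_layers=2, max_sequence_length=128, rope=True)
model = LLaDAModelLM(cfg, init_params=True)
ids = torch.randint(0, cfg.vocab_size, (1, 16), device="cuda")
out = model(input_ids=ids)
print(out.logits.shape)  # torch.Size([1, 16, 50304]) with the default embedding_size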
+ input_ids = input_ids[:, -1:] + model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values} + + model_inputs.update(kwargs) + model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache) + return model_inputs + + # TODO: these are required to make the implementation complete. + # def resize_position_embeddings(self, new_num_position_embeddings: int): + # pass + # + # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + # pass + # + # def _reorder_cache(self, past_key_values, beam_idx): + # pass + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.transformer.wte + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.transformer.wte = value + + def get_output_embeddings(self): + if self.config.weight_tying: + return self.model.transformer.wte + else: + return self.model.transformer.ff_out + + def set_output_embeddings(self, value: torch.nn.Module): + if self.config.weight_tying: + self.model.transformer.wte = value + else: + self.model.transformer.ff_out = value + + def tie_weights(self): + if self.config.weight_tying: + self.model.transformer.ff_out = self.model.transformer.wte diff --git a/dllm/dllm/pipelines/llada/models/modeling_lladamoe.py b/dllm/dllm/pipelines/llada/models/modeling_lladamoe.py new file mode 100644 index 0000000..1a5882e --- /dev/null +++ b/dllm/dllm/pipelines/llada/models/modeling_lladamoe.py @@ -0,0 +1,1168 @@ +"""LLaDA MoE model pytorch implementation""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import AutoConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_lladamoe import LLaDAMoEConfig + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LLaDAMoEConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. 
+ attention_mask (`torch.Tensor`, *optional*): + For diffusion language model, attention_mask is set to None by default. + If you pass an attention mask and expect the model to use it for computing other attention mechanisms, + it may lead to logits and aux_loss returned by the model being inconsistent with your expectations. + num_experts (`int`, *optional*): + Number of experts + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeRMSNorm -> LLaDAMoERMSNorm +class LLaDAMoERMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + LLaDAMoERMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LLaDAMoERMSNorm) + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeRotaryEmbedding -> LLaDAMoERotaryEmbedding +class LLaDAMoERotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + 
rope_type="default", + config: Optional[LLaDAMoEConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LLaDAMoERotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# copied from transformers.models.olmoe.modeling_olmoe.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# copied from transformers.models.olmoe.modeling_olmoe.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + rotary_dim = cos.shape[-1] + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rot = q[..., :rotary_dim] + q_pass = q[..., rotary_dim:] + + k_rot = k[..., :rotary_dim] + k_pass = k[..., rotary_dim:] + + q_rotated = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_rotated = (k_rot * cos) + (rotate_half(k_rot) * sin) + + q_final = torch.cat((q_rotated, q_pass), dim=-1) + k_final = torch.cat((k_rotated, k_pass), dim=-1) + + return q_final, k_final + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeMLP with OlmoeMLP->LLaDAMoEMLP +class LLaDAMoEMLP(nn.Module): + def __init__(self, config, mlp_type): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + if mlp_type == 'dense': + self.intermediate_size = config.dense_intermediate_size + elif mlp_type == 'expert': + self.intermediate_size = config.expert_intermediate_size + elif mlp_type == 'shared_expert': + self.intermediate_size = config.shared_expert_intermediate_size + else: + assert False, "unknown mlp type" + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# copied from transformers.models.olmoe.modeling_olmoe.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeAttention with OlmoeAttention->LLaDAMoEAttention +class LLaDAMoEAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LLaDAMoEConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # **For diffusion language model, we set is_causal to False by default.** + self.is_causal = False + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + if config.qk_layernorm: + self.q_norm = LLaDAMoERMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = LLaDAMoERMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + # attention_mask = None + + if attention_mask is not None: # no matter the length, we just slice it + converter = AttentionMaskConverter(is_causal=False) + extended_mask = converter.to_4d( + attention_mask_2d=attention_mask, + query_length=q_len, + dtype=attn_weights.dtype + ) + attn_weights = attn_weights + extended_mask + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, 
self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.olmoe.modeling_olmoe.FlashAttention2 with OlmoeFlashAttention2->LLaDAMoEFlashAttention2 +class LLaDAMoEFlashAttention2(LLaDAMoEAttention): + """ + LLaDAMoE flash attention module. This module inherits from `LLaDAMoEAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # copied from transformers.models.olmoe.modeling_olmoe.OlmoeFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + 
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LLaDAMoERMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + # attention_mask = None + self.is_causal = False + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeSdpaAttention with OlmoeSdpaAttention->LLaDAMoESdpaAttention +class LLaDAMoESdpaAttention(LLaDAMoEAttention): + """ + LLaDAMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LLaDAMoEAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "LLaDAModel is using LLaDAMoESdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + if 'q_norm' in dir(self): + query_states = self.q_norm(query_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + key_states = self.k_norm(key_states.reshape(-1, self.head_dim)).reshape(bsz, q_len, -1) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + is_causal = False + causal_mask = None + if attention_mask is not None: + converter = AttentionMaskConverter(is_causal=False) + causal_mask = converter.to_4d( + attention_mask_2d=attention_mask, + query_length=q_len, + dtype=query_states.dtype, + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
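+        # Workaround for the issue above: make query/key/value contiguous before calling SDPA whenever a
+        # custom attention mask is supplied on CUDA.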
+ if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLADAMOE_ATTENTION_CLASSES = { + "eager": LLaDAMoEAttention, + "flash_attention_2": LLaDAMoEFlashAttention2, + "sdpa": LLaDAMoESdpaAttention, +} + + +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeSparseMoeBlock with OlmoeSparseMoeBlock->LLaDAMoESparseMoeBlock +class LLaDAMoESparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = False + self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) + self.experts = nn.ModuleList([LLaDAMoEMLP(config, 'expert') for _ in range(self.num_experts)]) + self.score_func = config.moe_router_score_function + if config.moe_router_enable_expert_bias: + self.register_buffer("expert_bias", torch.zeros(self.num_experts)) + else: + self.expert_bias = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + if self.expert_bias is not None: + routing_weights += self.expert_bias + + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be selected + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
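+            # `index_add_` scatters each expert's weighted output back onto the token rows selected by
+            # `top_x`; summing the contributions of the top-k experts rebuilds the dense
+            # (batch_size * sequence_length, hidden_dim) activation. Illustrative shapes: `top_x` is
+            # (n_tokens_for_expert,) and `current_hidden_states` is (n_tokens_for_expert, hidden_dim).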
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class LLaDAMoEDecoderLayer(nn.Module): + def __init__(self, config: LLaDAMoEConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.mlp_type = 'dense' if config.moe_layer_freq[layer_idx] == 0 else 'moe' + + self.self_attn = LLADAMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LLaDAMoESparseMoeBlock(config) if self.mlp_type == 'moe' else LLaDAMoEMLP(config, 'dense') + self.input_layernorm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.shared_expert_intermediate_size is not None and self.mlp_type == 'moe': + self.shared_expert = LLaDAMoEMLP(config, 'shared_expert') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + For diffusion language model, attention_mask is set to None(full attention). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, + and should not be returned during inference. + use_cache (`bool`, *optional*): + For diffusion language model, use_cache is set to False by default. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + For diffusion language model, past_key_value is set to None by default. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + For diffusion language model, cache_position is set to None by default. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # **For diffusion language model, attention_mask is set to None(full attention) by default.** + use_cache = False + # attention_mask = None + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + shared_expert_states = hidden_states + + hidden_states = self.mlp(hidden_states) + + if hasattr(self, "shared_expert"): + hidden_states = hidden_states + self.shared_expert(shared_expert_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLADAMOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`LLaDAMoEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaDAMoE Model outputting raw hidden-states without any specific head on top.", + LLADAMOE_START_DOCSTRING, +) +# copied from transformers.models.olmoe.modeling_olmoe.OlmoeModel with OlmoePreTrainedModel->LLaDAMoEPreTrainedModel +class LLaDAMoEPreTrainedModel(PreTrainedModel): + config_class = LLaDAMoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LLaDAMoEDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLADAMOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices.
+            **For the diffusion language model, attention_mask is set to None (full attention) by default.**
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+            **For the diffusion language model, past_key_values cannot be applied by default.**
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            For the diffusion language model, use_cache and past_key_values cannot be enabled by default.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            **For the diffusion language model, cache_position cannot be applied by default.**
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaDAMoE Model outputting raw hidden-states without any specific head on top.",
+    LLADAMOE_START_DOCSTRING,
+)
+# copied from transformers.models.olmoe.modeling_olmoe.OlmoeModel with OlmoeModel->LLaDAMoEModel
+class LLaDAMoEModel(LLaDAMoEPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`LLaDAMoEDecoderLayer`] + Args: + config: LLaDAMoEConfig + """ + + def __init__(self, config: LLaDAMoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LLaDAMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LLaDAMoERMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LLaDAMoERotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLADAMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + assert (not use_cache and past_key_values is None and cache_position is None), "The cache mechanism is not suppotred for LLaDA MoE by default." + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = attention_mask + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + 
position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class LLaDAMoEModelLM(LLaDAMoEPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LLaDAMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLADAMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + For the current inference code of the diffusion language model, passing the parameters `labels` and `num_logits_to_keep` to compute loss is not supported. + Please note that for the diffusion language model, you cannot use model.generate() to generate responses. Please use the provided sampling code to generate model outputs. 
+ + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> model = AutoModel.from_pretrained("path/to/LLaDAMoE") + >>> tokenizer = AutoTokenizer.from_pretrained("path/to/LLaDAMoE") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = generate() # Please use the customized generate method instead of model.generate(). + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + assert (labels is None and num_logits_to_keep == 0), "LLaDAMoE model does not support calculate loss in the forward pass." + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/dllm/dllm/pipelines/llada/trainer.py b/dllm/dllm/pipelines/llada/trainer.py new file mode 100644 index 0000000..180ffd0 --- /dev/null +++ b/dllm/dllm/pipelines/llada/trainer.py @@ -0,0 +1,3 @@ +from dllm.core.trainers import MDLMTrainer + +LLaDATrainer = MDLMTrainer diff --git a/dllm/dllm/pipelines/rnd/__init__.py b/dllm/dllm/pipelines/rnd/__init__.py new file mode 100644 index 0000000..ab6bd16 --- /dev/null +++ b/dllm/dllm/pipelines/rnd/__init__.py @@ -0,0 +1,7 @@ +# from dllm.pipelines.rnd import generate, trainer +from . import models +from .models import RND1LM, RND1Config, RND1GenerationConfig + +# from dllm.pipelines.rnd.models.modeling_rnd import RND1LM +# from dllm.pipelines.rnd.models.configuration_rnd import RND1Config +from .trainer import RNDTrainer diff --git a/dllm/dllm/pipelines/rnd/models/__init__.py b/dllm/dllm/pipelines/rnd/models/__init__.py new file mode 100644 index 0000000..030b38c --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2025 Radical Numerics Inc. 
+# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +Radical Numerics Diffusion (RND1) - Diffusion-based Language Model. +""" + +from .configuration_rnd import RND1Config +from .modeling_rnd import ( + RND1LM, + RND1Model, + RND1PreTrainedModel, + RND1Attention, + RND1DecoderLayer, + RND1SparseMoeBlock, +) +from .generation_config import RND1GenerationConfig +from .generation_utils import RND1GenerationMixin +from .sampling import ( + diffusion_sample, + apply_top_k_filtering, + apply_top_p_filtering, +) +from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar + +__version__ = "0.1.0" + +__all__ = [ + "RND1Config", + "RND1GenerationConfig", + "RND1LM", + "RND1Model", + "RND1PreTrainedModel", + "RND1Attention", + "RND1DecoderLayer", + "RND1SparseMoeBlock", + "RND1GenerationMixin", + "TerminalVisualizer", + "SimpleProgressBar", +] + +# Register with HuggingFace Auto classes for local usage +try: + from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM + + AutoConfig.register("rnd1", RND1Config) + AutoModel.register(RND1Config, RND1LM) + AutoModelForMaskedLM.register(RND1Config, RND1LM) +except ImportError: + # transformers not available or Auto classes not imported + pass diff --git a/dllm/dllm/pipelines/rnd/models/configuration_rnd.py b/dllm/dllm/pipelines/rnd/models/configuration_rnd.py new file mode 100644 index 0000000..1de2982 --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/configuration_rnd.py @@ -0,0 +1,124 @@ + +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +RND1 Model Configuration. + +This module defines the configuration class for RND1 models. +The default settings are derived from Qwen/Qwen3-30B-A3B and augmented +with RND1-specific parameters. +""" + +from transformers.configuration_utils import PretrainedConfig + +# Qwen3-30B-A3B / checkpoint defaults +CONFIG_DEFAULTS = { + "attention_bias": False, + "attention_dropout": 0.0, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 40960, + "max_window_layers": 48, + "mlp_only_layers": [], + "moe_intermediate_size": 768, + "norm_topk_prob": True, + "num_attention_heads": 32, + "num_experts": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 48, + "num_key_value_heads": 4, + "output_router_logits": False, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": False, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.001, + "sliding_window": False, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "use_cache": False, + "use_sliding_window": False, + "vocab_size": 151936, +} + + +class RND1Config(PretrainedConfig): + """ + Configuration class for RND1 models. + + This configuration extends Qwen3MoeConfig with additional parameters + specific to the RND1 (Radical Numerics Diffusion v1) architecture. 
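+    For example, ``RND1Config(moe_backend="hf", num_diffusion_steps=256)`` (an illustrative call) forces
+    ``use_cache=False`` and ``is_causal=False`` and then applies the Qwen3-30B-A3B checkpoint defaults from
+    ``CONFIG_DEFAULTS`` on top of any keyword arguments that were passed in.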
+ + Args: + moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer") + num_diffusion_steps: Default number of diffusion steps for generation + mask_token_id: Token ID used for masking (default: 151669 for Qwen) + **kwargs: Additional arguments passed to Qwen3MoeConfig + """ + + model_type = "rnd1" + + def __init__( + self, + moe_backend: str = "hf", + num_diffusion_steps: int = 256, + mask_token_id: int = 151669, + **kwargs, + ): + # Force non-causal and no caching for RND1 + kwargs["use_cache"] = False + kwargs["is_causal"] = False + + super().__init__(**kwargs) + + # Set defaults after pretrained init to prevent overrides + self.set_config_defaults() + + # QoL: set attn impl directly from config + if "attn_implementation" in kwargs: + self._attn_implementation = kwargs["attn_implementation"] + + # RND1-specific parameters + self.moe_backend = moe_backend + self.num_diffusion_steps = num_diffusion_steps + self.mask_token_id = mask_token_id + + # Ensure bidirectional attention and no caching + self.is_causal = False + self.use_cache = False + + def set_config_defaults(self): + """ + Ensure model defaults are set according to final training checkpoint + + Qwen3MoeConfig defaults don't match Qwen/Qwen3-30B-A3B settings from which + RND1 is derived. + """ + for k, v in CONFIG_DEFAULTS.items(): + setattr(self, k, v) + + def to_dict(self): + """ + Serializes configuration to dictionary with auto_map for Hub. + + The auto_map ensures that when users load from HuggingFace Hub, + the correct custom classes are automatically resolved. + """ + data = super().to_dict() + data.setdefault( + "auto_map", + { + "AutoConfig": "configuration_rnd.RND1Config", + "AutoModel": "modeling_rnd.RND1Model", + "AutoModelForMaskedLM": "modeling_rnd.RND1LM", + }, + ) + return data diff --git a/dllm/dllm/pipelines/rnd/models/generation_config.py b/dllm/dllm/pipelines/rnd/models/generation_config.py new file mode 100644 index 0000000..c4da69d --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/generation_config.py @@ -0,0 +1,77 @@ +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +RND1 Generation Configuration. + +This module defines the generation configuration for RND1 models, +controlling the diffusion-based generation process. +""" + +from typing import Optional +from transformers.generation.configuration_utils import GenerationConfig + + +class RND1GenerationConfig(GenerationConfig): + """ + Configuration class for RND1 generation parameters. + + This class extends the base GenerationConfig to include parameters + specific to diffusion-based language generation. 
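+    For example, ``RND1GenerationConfig(num_diffusion_steps=128, temperature=0.1, greedy=True)`` (an
+    illustrative call) configures 128 denoising steps with greedy unmasking; caching is always disabled.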
+ + Args: + max_length: Maximum sequence length + num_diffusion_steps: Number of denoising steps in the diffusion process + mask_token_id: Token ID used for masking during diffusion + temperature: Temperature for sampling (higher = more random) + top_k: Optional top-k filtering + top_p: Optional nucleus (top-p) filtering + greedy: Whether to use greedy decoding (True) or stochastic sampling (False) + **kwargs: Additional arguments passed to GenerationConfig + """ + + def __init__( + self, + max_length: int = 256, + num_diffusion_steps: int = 256, + mask_token_id: int = 151669, + temperature: float = 0.1, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + greedy: bool = False, + bos_token_id: int = None, + eos_token_id: int = None, + pad_token_id: int = None, + use_cache: bool = False, + **kwargs, + ): + # Force no caching for RND generation + # kwargs['use_cache'] = False + kwargs.pop('use_cache', None) + super().__init__( + max_length=max_length, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + do_sample=not greedy, + use_cache=False, + **kwargs, + ) + + # RND-specific parameters + self.num_diffusion_steps = num_diffusion_steps + self.mask_token_id = mask_token_id + self.greedy = greedy + + def to_dict(self): + """Convert configuration to dictionary.""" + output = super().to_dict() + output["num_diffusion_steps"] = self.num_diffusion_steps + output["mask_token_id"] = self.mask_token_id + output["greedy"] = self.greedy + return output diff --git a/dllm/dllm/pipelines/rnd/models/generation_utils.py b/dllm/dllm/pipelines/rnd/models/generation_utils.py new file mode 100644 index 0000000..b551fb6 --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/generation_utils.py @@ -0,0 +1,187 @@ +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +RND1 Generation Utilities. + +This module provides generation utilities and mixins for RND1 models, +including the main GenerationMixin class that integrates with HuggingFace. +""" + +import torch +import torch.nn as nn +from typing import Optional, Union, Dict, Any +from transformers import GenerationMixin as HFGenerationMixin +from transformers.generation import GenerationConfig + +from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering + + +class RND1GenerationMixin(HFGenerationMixin): + """ + Generation mixin for RND1 models. + + This mixin provides generation methods compatible with HuggingFace's + generation API while using RND1's diffusion-based sampling internally. + """ + + def generate( + self, + inputs: Optional[torch.LongTensor] = None, + generation_config: Optional[GenerationConfig] = None, + # RND1-specific parameters + prefix_ids: Optional[torch.LongTensor] = None, + suffix_ids: Optional[torch.LongTensor] = None, + infill_length: Optional[int] = None, + return_dict_in_generate: Optional[bool] = None, + **kwargs, # Accept all kwargs to be compatible with pipelines + ) -> Union[torch.LongTensor, Dict[str, Any]]: + """ + Generate text using RND1's diffusion-based sampling. + + Follows HuggingFace's standard generate API, using diffusion sampling + internally. Supports both standard generation and infilling. 
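+        For example, ``model.generate(inputs, generation_config=RND1GenerationConfig(max_new_tokens=64))``
+        (an illustrative call) denoises a masked 64-token continuation of ``inputs`` over the configured
+        number of diffusion steps.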
+ + Args: + inputs: Input token IDs to use as prefix (standard HF parameter) + generation_config: Generation configuration object + prefix_ids: Alternative to inputs for infilling tasks + suffix_ids: Optional suffix for infilling tasks + infill_length: Length of infill region (for infilling) + return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput + **kwargs: Additional arguments (accepted for compatibility) + + Returns: + Generated token IDs or GenerateDecoderOnlyOutput + """ + if generation_config is not None: + gen_config = generation_config + model_kwargs = kwargs.copy() + else: + # Only prepare config from kwargs if no config was provided + gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs) + + device = next(self.parameters()).device + + if inputs is not None: + prefix_ids = inputs.to(device) + elif prefix_ids is not None: + prefix_ids = prefix_ids.to(device) + else: + prefix_ids = None + + if suffix_ids is not None: + suffix_ids = suffix_ids.to(device) + + eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645) + eos_token_id = None if eos_token_id == -1 else eos_token_id + pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None) + bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None) + mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669)) + + if infill_length is not None and prefix_ids is not None: + # Infilling mode: use specified infill_length + prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0 + suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0 + seq_len = prefix_len + infill_length + suffix_len + else: + # Standard generation mode + if prefix_ids is not None: + prefix_len = prefix_ids.shape[1] + if gen_config.max_new_tokens is not None: + seq_len = prefix_len + gen_config.max_new_tokens + else: + seq_len = gen_config.max_length or self.config.max_position_embeddings + else: + seq_len = gen_config.max_length or self.config.max_position_embeddings + + num_diffusion_steps = getattr(gen_config, "num_diffusion_steps", + getattr(self.config, "num_diffusion_steps", 256)) + + temperature = float(getattr(gen_config, "temperature", 1.0)) + top_k = getattr(gen_config, "top_k", None) + top_p = getattr(gen_config, "top_p", None) + + greedy = getattr(gen_config, "greedy", + not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True) + + + with torch.inference_mode(): + sequences = diffusion_sample( + model=self, + seq_len=seq_len, + num_steps=num_diffusion_steps, + mask_token_id=mask_token_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + greedy=greedy, + prefix_ids=prefix_ids, + suffix_ids=suffix_ids, + infill_length=infill_length, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + device=device, + visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs + ) + + if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False): + from transformers.generation.utils import GenerateDecoderOnlyOutput + return GenerateDecoderOnlyOutput(sequences=sequences) + + return sequences + + def generate_with_visualization( + self, + tokenizer, + inputs: Optional[torch.LongTensor] = None, + generation_config: Optional[GenerationConfig] = None, + suffix_ids: Optional[torch.LongTensor] = None, + infill_length: Optional[int] = None, + **kwargs, + ) -> torch.LongTensor: + """ + Generate with live 
visualization (for demos). + + This method requires a tokenizer to display the generation process. + For production use, prefer `generate()`. + + Args: + tokenizer: Tokenizer for decoding tokens to text + inputs: Input token IDs to use as prefix + generation_config: Generation configuration object + suffix_ids: Optional suffix token IDs + infill_length: Length of infill region + **kwargs: Additional arguments for backward compatibility + + Returns: + Generated token IDs as LongTensor + """ + from .terminal_visualizer import TerminalVisualizer + visualizer = TerminalVisualizer(tokenizer, show_visualization=True) + + return self.generate( + inputs=inputs, + generation_config=generation_config, + suffix_ids=suffix_ids, + infill_length=infill_length, + visualizer=visualizer, + return_dict_in_generate=False, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + **kwargs, + ) -> Dict[str, Any]: + """ + Prepare inputs for generation (required by HuggingFace). + + For RND1, we don't use the standard autoregressive generation, + so this just returns the input_ids. + """ + return {"input_ids": input_ids} diff --git a/dllm/dllm/pipelines/rnd/models/modeling_rnd.py b/dllm/dllm/pipelines/rnd/models/modeling_rnd.py new file mode 100644 index 0000000..84253bb --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/modeling_rnd.py @@ -0,0 +1,653 @@ +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +RND1 model implementation. + +This module implements the RND1 architecture with bidirectional attention for +diffusion-based language modeling. Includes support for Mixture of Experts (MoE) +with multiple backend options (HF, vLLM, SGLang, FlashInfer). 
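+
+Example (illustrative sketch; the checkpoint path, prompt, and token budget are
+placeholders, and the `dllm` package is assumed to be importable):
+
+    from transformers import AutoTokenizer
+    from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
+
+    tokenizer = AutoTokenizer.from_pretrained("path/to/rnd1-checkpoint")
+    model = RND1LM.from_pretrained("path/to/rnd1-checkpoint").eval()
+    prompt_ids = tokenizer("Write a haiku about rain.", return_tensors="pt").input_ids
+    out = model.generate(inputs=prompt_ids, max_new_tokens=64)
+    print(tokenizer.decode(out[0], skip_special_tokens=True))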
+ +Based on the Qwen3Moe architecture: +https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +""" + +from __future__ import annotations + +import os +from typing import Optional, Tuple, List, Union + +import torch +from torch import nn + +from transformers.utils import logging +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeModelOutputWithPast, + MaskedLMOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationConfig + +from .configuration_rnd import RND1Config +from .generation_utils import RND1GenerationMixin + +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + Qwen3MoeRMSNorm, + Qwen3MoeRotaryEmbedding, + Qwen3MoeMLP, + apply_rotary_pos_emb +) +import torch.nn.functional as F + + +try: + from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts as fused_experts_vllm, fused_topk as fused_topk_vllm + from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm +except Exception: + fused_experts_vllm = None + fused_topk_vllm = None + +try: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe + # from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm + from sglang.srt.layers.moe.topk import StandardTopKOutput +except Exception: + sglang_fused_moe = None + StandardTopKOutput = None + + +try: + import flashinfer.fused_moe as fused_moe + ## TODO: below needs flashinfer>=0.4.0, but has some bug atm + # from flashinfer.norm import rmsnorm as flashinfer_rmsnorm + # class FlashInferRMSNorm(Qwen3MoeRMSNorm): + # """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface""" + # def forward(self, hidden_states): + # return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon) + +except Exception: + fused_moe = None + +logger = logging.get_logger(__name__) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """Expand key/value heads to match query heads for grouped-query attention.""" + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class RND1Attention(nn.Module): + """RND1 attention layer with bidirectional attention for diffusion modeling.""" + + def __init__(self, config: RND1Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.scaling = self.head_dim ** -0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, 
config.hidden_size, bias=config.attention_bias) + + if config.moe_backend == "vllm": + RMSNormClass = VLLMRMSNorm + else: + RMSNormClass = Qwen3MoeRMSNorm + self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps) + + self.sliding_window = getattr(config, "sliding_window", None) + + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + dual_cache: Optional[bool] = False, + replace_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]]]: + + bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa") + + if use_sdpa: + if attention_mask is not None and isinstance(attention_mask, torch.Tensor): + if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]: + attention_mask = attention_mask.to(dtype=query_states.dtype) + + assert not self.is_causal, f"Attention layer {self.layer_idx} is causal" + attn_out = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, + attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=self.is_causal, + ) + attn_out = attn_out.transpose(1, 2).contiguous() + attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim) + attn_out = self.o_proj(attn_out) + return attn_out, None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: + # TODO: modify this to boolean masks + attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]] + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_out = torch.matmul(attn_weights, value_states) + attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1) + attn_out = self.o_proj(attn_out) + + return attn_out, None + + +class RND1DecoderLayer(nn.Module): + """RND1 decoder layer with bidirectional attention for diffusion language modeling.""" + + def __init__(self, config: RND1Config, layer_idx: int): + super().__init__() + self.self_attn = RND1Attention(config, layer_idx) + self.mlp = RND1SparseMoeBlock(config) + if config.moe_backend == "vllm": + RMSNormClass = 
VLLMRMSNorm + else: + RMSNormClass = Qwen3MoeRMSNorm + self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + replace_position: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attn_out, attn_weights = self.self_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + replace_position=replace_position, + ) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + ff_out = self.mlp(hidden_states) + if isinstance(ff_out, tuple): + ff_out = ff_out[0] + hidden_states = residual + ff_out + + return hidden_states, attn_weights + + +class RND1SparseMoeBlock(nn.Module): + """RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer).""" + + def __init__(self, config: RND1Config): + super().__init__() + self.config = config + self.backend = getattr(config, "moe_backend", "hf") + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size) + + self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)] + ) + + # Cached weight tensors for optimized backends + self._w1 = None + self._w2 = None + if self.backend == "sglang": + if sglang_fused_moe is None or StandardTopKOutput is None: + raise RuntimeError("sglang is not available, cannot use sglang backend") + elif self.backend == "flashinfer": + if fused_moe is None: + raise RuntimeError("flashinfer is not available, cannot use flashinfer backend") + elif self.backend == "vllm": + if fused_experts_vllm is None or fused_topk_vllm is None: + raise RuntimeError("vllm is not available, cannot use vllm backend") + + @torch.no_grad() + def _initialize_weights( + self, + free_experts: bool = True, + mode: str = "vllm", + ) -> None: + logger.info(f"Initializing weights for {mode} backend") + # Stack directly on device where weights already reside (loaded by HF) + gate_list: List[torch.Tensor] = [] + up_list: List[torch.Tensor] = [] + down_list: List[torch.Tensor] = [] + + # Collect weight references without any device moves + for expert in self.experts: + gate_list.append(expert.gate_proj.weight.data) + up_list.append(expert.up_proj.weight.data) + down_list.append(expert.down_proj.weight.data) + + gate_w_stacked = torch.stack(gate_list, dim=0).contiguous() + up_w_stacked = torch.stack(up_list, dim=0).contiguous() + down_w_stacked = torch.stack(down_list, dim=0).contiguous() + + if mode == "flashinfer": + w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering + else: + w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1) + w2 = down_w_stacked + self._w1 = w1 + self._w2 = w2 + + + if free_experts: + # Free per-expert modules to reclaim memory 
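+            # (note: after this point only the fused `_w1`/`_w2` tensors remain, so the
+            # unfused "hf" expert loop in `forward` can no longer be used)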
+            logger.info(f"Freeing experts for {mode} backend")
+            del self.experts
+            self.experts = None
+
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass with expert routing and computation."""
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        x = hidden_states.view(-1, hidden_dim)
+
+        # Expert routing
+        router_logits = self.gate(x)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+
+        if self.backend == "vllm":
+            routing_weights, selected_experts, *_ = fused_topk_vllm(
+                hidden_states=x,
+                gating_output=router_logits,
+                topk=self.top_k,
+                renormalize=self.norm_topk_prob,
+            )
+        else:
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:
+                routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+
+        # if self.backend == "hf":
+        #     final_hidden_states = torch.zeros(
+        #         (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        #     )
+
+        #     expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+        #     expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        #     for expert_idx in expert_hit:
+        #         expert_layer = self.experts[expert_idx]
+        #         idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+        #         current_state = x[top_x]
+        #         current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+        #         final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        #     out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        #     return out, router_logits.view(batch_size, sequence_length, -1)
+        if self.backend == "hf":
+            # Accumulation buffer: [B*S, H]
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+
+            # expert_mask: [E, top_k, tokens]
+            expert_mask = torch.nn.functional.one_hot(
+                selected_experts, num_classes=self.num_experts
+            ).permute(2, 1, 0).contiguous()
+
+            # Visit every expert in order; even if this rank has no hits for an expert we
+            # still run its forward pass, to avoid ZeRO-3 control-flow divergence across ranks
+            for e in range(self.num_experts):
+                expert_layer = self.experts[int(e)]
+
+                # Gather the indices of tokens routed to this expert
+                idx, top_x = torch.where(expert_mask[e])  # idx in [0, top_k), shapes: [n_tok_e]
+                current_state = x[top_x]  # [n_tok_e, H]; n_tok_e may be 0
+                # if top_x.numel() == 0:
+                #     print("0")
+
+                # Run the forward pass even on an empty batch; most Linear/MLP layers are a
+                # no-op on 0-row inputs, but this keeps the ZeRO-3 parameter path aligned
+                expert_out = expert_layer(current_state)  # [n_tok_e, H]
+
+                # Apply the routing weights
+                w = routing_weights[top_x, idx]  # [n_tok_e]
+                expert_out = expert_out * w.unsqueeze(-1)  # [n_tok_e, H]
+
+                # Accumulate back into the global buffer; a legal no-op when n_tok_e = 0
+                final_hidden_states.index_add_(0, top_x, expert_out.to(hidden_states.dtype))
+
+            out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+            return out, router_logits.view(batch_size, sequence_length, -1)
+
+        elif self.backend == "flashinfer":
+            # if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
+            #     self._initialize_flashinfer_weights()
+            if self._w1 is None or self._w2 is None:
+                self._initialize_weights(mode="flashinfer")
+
+            result = fused_moe.cutlass_fused_moe(
+                input=x,
+                token_selected_experts=selected_experts.to(torch.int),
+                token_final_scales=routing_weights.to(torch.float32),
+                fc1_expert_weights=self._w1,
+                fc2_expert_weights=self._w2,
+                output_dtype=x.dtype,
+                quant_scales=None,
+            )
+            if isinstance(result, (list, tuple)):
+                out_flat = result[0]
+            
else: + out_flat = result + out = out_flat.view(batch_size, sequence_length, hidden_dim) + return out, router_logits.view(batch_size, sequence_length, -1) + + elif self.backend == "sglang": + if self._w1 is None or self._w2 is None: + self._initialize_weights(mode="sglang") + + topk_output = StandardTopKOutput( + topk_weights=routing_weights, + topk_ids=selected_experts, + router_logits=router_logits, + ) + + out_flat = sglang_fused_moe( + hidden_states=x, + w1=self._w1, + w2=self._w2, + topk_output=topk_output, + ) + out = out_flat.view(batch_size, sequence_length, hidden_dim) + return out, router_logits.view(batch_size, sequence_length, -1) + + elif self.backend == "vllm": + if self._w1 is None or self._w2 is None: + self._initialize_weights() + + out_flat = fused_experts_vllm( + x, + self._w1, + self._w2, + routing_weights, + selected_experts, + ) + out = out_flat.view(batch_size, sequence_length, hidden_dim) + return out, router_logits.view(batch_size, sequence_length, -1) + + else: + raise ValueError(f"Invalid backend: {self.backend}") + + +class RND1PreTrainedModel(PreTrainedModel): + """Base class for RND1 models with weight initialization and loading support.""" + config_class = RND1Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["RND1DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + """Initialize weights using normal distribution.""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + weights_only: bool = True, + **kwargs, + ): + """Load pretrained model with generation config.""" + _model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + + resume_download = kwargs.get("resume_download", None) + proxies = kwargs.get("proxies", None) + subfolder = kwargs.get("subfolder", "") + from_auto_class = kwargs.get("_from_auto", False) + from_pipeline = kwargs.get("_from_pipeline", None) + + _model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + ) + + # If configured to 
use a fused backend, pack fused tensors once after load + try: + cfg = getattr(_model, "config", None) + backend = getattr(cfg, "moe_backend", "hf") if cfg is not None else "hf" + if backend in ("sglang", "vllm"): + # Walk decoder layers and initialize fused weights + model_core = getattr(_model, "model", _model) + layers = getattr(model_core, "layers", None) + if isinstance(layers, nn.ModuleList): + for layer in layers: + mlp = getattr(layer, "mlp", None) + if hasattr(mlp, "_initialize_weights"): + mlp._initialize_weights( + free_experts=True, + mode=backend, + ) + except Exception as _e: + logger.warning(f"Backend {backend} weight processing skipped: {_e}") + + return _model + + +class RND1Model(RND1PreTrainedModel): + """RND1 transformer model with bidirectional attention for diffusion language modeling.""" + + def __init__(self, config: RND1Config): + super().__init__(config) + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) + if config.moe_backend == "vllm": + RMSNormClass = VLLMRMSNorm + else: + RMSNormClass = Qwen3MoeRMSNorm + self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps) + + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + + self.post_init() + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> MoeModelOutputWithPast: + """Forward pass through the RND1 model.""" + + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0) + + if isinstance(attention_mask, torch.Tensor): + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + + for layer in self.layers: + hidden_states, _ = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + router_logits=None, + ) + + +class RND1LM(RND1PreTrainedModel, RND1GenerationMixin): + """Radical Numerics Diffusion Language Model with bidirectional attention.""" + + def __init__(self, config: RND1Config): + super().__init__(config) + self.model = RND1Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings layer.""" + return self.model.embed_tokens + + def set_input_embeddings(self, value): + """Set the input embeddings layer.""" + self.model.embed_tokens = value + + def get_output_embeddings(self): + """Get the output embeddings layer (lm_head).""" + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings layer (lm_head).""" + self.lm_head 
= new_embeddings + + @classmethod + def can_generate(cls) -> bool: + """Indicates this model can generate text.""" + return True + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, + ) -> MaskedLMOutput: + """Forward pass with optional loss computation.""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + logits = self.lm_head(outputs.last_hidden_state) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + return MaskedLMOutput( + loss=loss, + logits=logits, + ) diff --git a/dllm/dllm/pipelines/rnd/models/sampling.py b/dllm/dllm/pipelines/rnd/models/sampling.py new file mode 100644 index 0000000..0483b27 --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/sampling.py @@ -0,0 +1,260 @@ +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +RND1 sampling module for masked diffusion generation. + +This module implements entropy-based token selection for iterative denoising +in diffusion language models. Supports both greedy and stochastic sampling +with optional prefix/suffix constraints and infilling. +""" + +import torch +import torch.nn as nn +from typing import Optional, Union + + +def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor: + """ + Apply top-k filtering to logits: with non-top-k values set to -inf + """ + top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1) + filtered_logits = torch.full_like(logits, float('-inf')) + filtered_logits.scatter_(-1, top_k_indices, top_k_values) + return filtered_logits + + +def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor: + """ + Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf + """ + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above threshold + sorted_indices_to_remove = cumulative_probs > p + sorted_indices_to_remove[..., 0] = False # Keep at least one token + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + + indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) + return logits.masked_fill(indices_to_remove, float('-inf')) + + +@torch.no_grad() +def diffusion_sample( + model: nn.Module, + seq_len: int = 256, + num_steps: int = 256, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: float = 1.0, + greedy: bool = True, + mask_token_id: int = 151669, + prefix_ids: Optional[torch.LongTensor] = None, + suffix_ids: Optional[torch.LongTensor] = None, + infill_length: Optional[int] = None, + eos_token_id: int = 151645, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + visualizer: Optional[object] = None, +) -> torch.LongTensor: + """ + Perform masked diffusion sampling with entropy-based token selection. 
+ + Args: + model: The RND1 language model + seq_len: Target sequence length + num_steps: Number of denoising steps + top_k: Optional top-k filtering for sampling (None = no filtering) + top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering) + When both top_k and top_p are set, top_k is applied first, then top_p + temperature: Temperature for sampling (higher = more random, lower = more deterministic) + Values close to 0 are clamped to 1e-8 to avoid division by zero + greedy: Whether to use greedy sampling (True) or stochastic (False) + mask_token_id: Token ID for masked positions (default: 151669) + prefix_ids: Optional prefix token IDs to preserve + suffix_ids: Optional suffix token IDs to preserve + infill_length: Length of infill region between prefix/suffix + eos_token_id: End of sequence token ID (default: 151645) + pad_token_id: Padding token ID (default: None, uses 0 if needed) + bos_token_id: Beginning of sequence token ID (default: None) + device: Device for computation (None = infer from model) + visualizer: Optional visualizer for live visualization + + Returns: + Generated token IDs as LongTensor + """ + model.eval() + + if device is None: + device = next(model.parameters()).device + else: + device = torch.device(device) + + if pad_token_id is None: + pad_token_id = 0 + + # Build initial masked sequence + # When prefix_ids is provided, we create a sequence of length seq_len where: + # - The prefix occupies the first pre_len positions + # - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated + if prefix_ids is not None or suffix_ids is not None: + if prefix_ids is not None: + prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device) + pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0 + else: + pre_len = 0 + + if suffix_ids is not None: + suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device) + suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0 + else: + suf_len = 0 + + reserved = (1 if eos_token_id is not None else 0) + used = pre_len + suf_len + reserved + + if used > seq_len: + raise ValueError( + f"Combined length of prefix ({pre_len}), suffix ({suf_len}), " + f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). " + f"Please increase seq_len or reduce input lengths." + ) + elif used == seq_len: + raise ValueError( + f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) " + f"+ special tokens ({reserved}) = seq_len ({seq_len}). " + f"Need at least 1 position for generation." 
+ ) + + infill_length = min(infill_length or (seq_len - used), seq_len - used) + + x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device) + pos = 0 + # if bos_token_id is not None: + # x[0, pos] = bos_token_id; pos += 1 + if eos_token_id is not None: + x[0, -1] = eos_token_id + if pre_len > 0: + x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len] + pos += pre_len + fill_start, fill_end = pos, pos + infill_length + x[0, fill_start:fill_end] = mask_token_id + # print(fill_start, fill_end, seq_len, used, x[0, -1]) + pos = fill_end + if suf_len > 0: + x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len] + pos += suf_len + + init_maskable = torch.zeros_like(x, dtype=torch.bool) + init_maskable[0, fill_start:fill_end] = True + else: + x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device) + if bos_token_id is not None: + x[0, 0] = bos_token_id + if eos_token_id is not None: + x[0, -1] = eos_token_id + init_maskable = x.eq(mask_token_id) + + if bos_token_id is not None: + init_maskable[:, 0] = False + if eos_token_id is not None: + init_maskable &= x.ne(eos_token_id) + init_maskable &= x.ne(pad_token_id) + + maskable = init_maskable.clone() + xt = x.clone() + + if visualizer: + visualizer.start_visualization(xt, maskable, num_steps) + + def forward_scores(tokens): + """Compute predictions and entropy scores for next tokens.""" + # Try with input_ids parameter first (standard HF models) + try: + model_output = model(input_ids=tokens) + except TypeError: + # Fall back to positional argument + model_output = model(tokens) + + # Apply temperature scaling (with safety for near-zero temperature) + safe_temperature = max(temperature, 1e-8) # Prevent division by zero + logits = model_output.logits / safe_temperature + + # Apply filtering strategies + # Note: When both top_k and top_p are provided, they are applied sequentially: + # First top_k filters to k tokens, then top_p filters from those k tokens + if top_k is not None and top_k > 0: + logits = apply_top_k_filtering(logits, top_k) + + if top_p is not None and 0 < top_p < 1.0: + logits = apply_top_p_filtering(logits, top_p) + + # Convert to log probabilities + logp = torch.log_softmax(logits, dim=-1) + + # Greedy or stochastic sampling + if greedy: + pred_next = logp.argmax(-1) + else: + pred_next = torch.distributions.Categorical(logits=logp).sample() + + conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1) + + p = logp.exp() + ent_next = -(p * logp).sum(-1) + + # Shift predictions: pos i predicts token i+1 + pred_i = tokens.clone() + conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min) + ent_i = torch.zeros_like(ent_next) + + pred_i[:, 1:] = pred_next[:, :-1] + conf_i[:, 1:] = conf_next[:, :-1] + ent_i[:, 1:] = ent_next[:, :-1] + + return pred_i, conf_i, ent_i + + pred_i, conf_i, ent_i = forward_scores(xt) + total_masked = init_maskable.sum(1, keepdim=True) + finf = torch.finfo(conf_i.dtype) + + for step in range(num_steps - 1, 0, -1): + rate = step / num_steps + cutoff_len = (total_masked * rate).long().clamp(min=0) + + # Choose HIGH-entropy tokens to keep masked + sel_scores = ent_i.masked_fill(~maskable, -finf.max) + B, L = sel_scores.shape + k_max = cutoff_len.max().item() + if k_max > 0: + sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True) + keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool) + for b in range(B): + k_b = int(cutoff_len[b].item()) + if k_b > 0: + keep_mask[b, idx[b, :k_b]] = True + else: + keep_mask = 
torch.zeros_like(sel_scores, dtype=torch.bool) + + to_unmask = maskable & ~keep_mask + if to_unmask.any(): + xt[to_unmask] = pred_i[to_unmask] + maskable[to_unmask] = False + + if visualizer: + visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i) + + if maskable.any(): + pred_i, conf_i, ent_i = forward_scores(xt) + + if maskable.any(): + xt[maskable] = pred_i[maskable] + + if visualizer: + visualizer.stop_visualization() + + return xt diff --git a/dllm/dllm/pipelines/rnd/models/terminal_visualizer.py b/dllm/dllm/pipelines/rnd/models/terminal_visualizer.py new file mode 100644 index 0000000..f34d88c --- /dev/null +++ b/dllm/dllm/pipelines/rnd/models/terminal_visualizer.py @@ -0,0 +1,251 @@ +# Copyright 2025 Radical Numerics Inc. +# +# This source code is licensed under the Apache License, Version 2.0, found in the +# LICENSE file in the root directory of this source tree. + +""" +Terminal visualization for RND1 generation. + +This module provides real-time visualization of the diffusion denoising process, +showing token evolution and generation progress in the terminal using rich +formatting when available. +""" + +import torch +from typing import Optional +from tqdm import tqdm + +try: + from rich.console import Console + from rich.live import Live + from rich.text import Text + from rich.panel import Panel + from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn + from rich.layout import Layout + RICH_AVAILABLE = True +except ImportError: + RICH_AVAILABLE = False + + +class TerminalVisualizer: + """ + Rich-based visualization for diffusion process with live updates. + + Provides real-time visualization of the token denoising process during + diffusion-based language generation, with colored highlighting of masked + positions and progress tracking. + """ + + def __init__(self, tokenizer, show_visualization: bool = True): + """ + Initialize the terminal visualizer. + + Args: + tokenizer: The tokenizer for decoding tokens to text + show_visualization: Whether to show visualization (requires rich) + """ + self.tokenizer = tokenizer + self.show_visualization = show_visualization and RICH_AVAILABLE + if not RICH_AVAILABLE and show_visualization: + print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.") + self.show_visualization = False + + if self.show_visualization: + self.console = Console() + self.live = None + self.progress = None + self.layout = None + else: + self.pbar = None + + self.current_tokens = None + self.mask_positions = None + self.total_steps = 0 + self.current_step = 0 + + def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int): + """ + Start the visualization. 
+ + Args: + initial_tokens: Initial token IDs (possibly masked) + mask_positions: Boolean mask indicating which positions are masked + total_steps: Total number of diffusion steps + """ + if not self.show_visualization: + self.pbar = tqdm(total=total_steps, desc="Diffusion") + return + + self.current_tokens = initial_tokens.clone() + self.mask_positions = mask_positions + self.total_steps = total_steps + self.current_step = 0 + + self.layout = Layout() + self.layout.split_column( + Layout(name="header", size=3), + Layout(name="text", ratio=1), + Layout(name="progress", size=3) + ) + + self.progress = Progress( + TextColumn("[bold blue]Diffusion"), + BarColumn(), + MofNCompleteColumn(), + TextColumn("•"), + TextColumn("[cyan]Masks: {task.fields[masks]}"), + TimeRemainingColumn(), + ) + self.progress_task = self.progress.add_task( + "Generating", + total=total_steps, + masks=mask_positions.sum().item() + ) + + self.live = Live(self.layout, console=self.console, refresh_per_second=4) + self.live.start() + self._update_display() + + def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int, + entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None): + """ + Update visualization for current step. + + Args: + tokens: Current token IDs + maskable: Boolean mask of remaining masked positions + step: Current step number + entropy: Optional entropy scores for each position + confidence: Optional confidence scores for each position + """ + if not self.show_visualization: + if self.pbar: + self.pbar.update(1) + masks = maskable.sum().item() if maskable is not None else 0 + self.pbar.set_postfix({'masks': masks}) + return + + self.current_tokens = tokens.clone() + self.mask_positions = maskable + self.current_step = step + + masks_remaining = maskable.sum().item() if maskable is not None else 0 + self.progress.update( + self.progress_task, + advance=1, + masks=masks_remaining + ) + + self._update_display() + + def _update_display(self): + """Update the live display.""" + if not self.live: + return + + header = Text("RND1-Base Generation", style="bold magenta", justify="center") + self.layout["header"].update(Panel(header, border_style="bright_blue")) + + text_display = self._format_text_with_masks() + self.layout["text"].update( + Panel( + text_display, + title="[bold]Generated Text", + subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]", + border_style="cyan" + ) + ) + + self.layout["progress"].update(Panel(self.progress)) + + def _format_text_with_masks(self) -> Text: + """ + Format text with colored masks. 
+ + Returns: + Rich Text object with formatted tokens + """ + text = Text() + + if self.current_tokens is None: + return text + + token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens + mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions + + for i, token_id in enumerate(token_ids): + if mask_flags is not None and i < len(mask_flags) and mask_flags[i]: + # Alternate colors for visual effect + text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red") + else: + try: + token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False) + # Skip special tokens in display + if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "", ""]: + # Color based on position + text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan") + except: + continue + + return text + + def stop_visualization(self): + """Stop the visualization and display final result.""" + if not self.show_visualization: + if self.pbar: + self.pbar.close() + print("\n✨ Generation complete!\n") + return + + if self.live: + self.live.stop() + + self.console.print("\n[bold green]✨ Generation complete![/bold green]\n") + + # Display final text + if self.current_tokens is not None: + try: + token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens + final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) + + self.console.print(Panel( + final_text, + title="[bold]Final Generated Text", + border_style="green", + padding=(1, 2) + )) + except: + pass + + +class SimpleProgressBar: + """ + Simple progress bar fallback when rich is not available. + + Provides basic progress tracking using tqdm when the rich library + is not installed. + """ + + def __init__(self, total_steps: int): + """ + Initialize simple progress bar. + + Args: + total_steps: Total number of steps + """ + self.pbar = tqdm(total=total_steps, desc="Diffusion") + + def update(self, masks_remaining: int = 0): + """ + Update progress bar. 
+ + Args: + masks_remaining: Number of masks still remaining + """ + self.pbar.update(1) + self.pbar.set_postfix({'masks': masks_remaining}) + + def close(self): + """Close the progress bar.""" + self.pbar.close() + print("\n✨ Generation complete!\n") diff --git a/dllm/dllm/pipelines/rnd/trainer.py b/dllm/dllm/pipelines/rnd/trainer.py new file mode 100644 index 0000000..bc9696f --- /dev/null +++ b/dllm/dllm/pipelines/rnd/trainer.py @@ -0,0 +1,23 @@ +from typing import Any + +import torch + +from dllm.core.trainers import MDLMTrainer + + +class RNDTrainer(MDLMTrainer): + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def _preprocess_inputs(self, inputs): + labels = inputs["labels"] + assert (labels[:, 0] == -100).all() + + def _postprocess_outputs(self, outputs): + logits = outputs.logits + outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1) diff --git a/dllm/dllm/tools/chat.py b/dllm/dllm/tools/chat.py new file mode 100644 index 0000000..7bb59ff --- /dev/null +++ b/dllm/dllm/tools/chat.py @@ -0,0 +1,242 @@ +import shutil +from typing import List, Literal + +import textwrap + +import dllm + + +# ============================================================ +# Utility helpers +# ============================================================ + +try: + L = shutil.get_terminal_size().columns + if not isinstance(L, int) or L <= 0: + L = 120 +except Exception: + L = 120 +DIV = "=" * L +SUB = "-" * L + + +def banner_line(text: str, width: int = L, fill: str = "=") -> str: + """Return a centered banner line with given width and fill.""" + text = f" {text.strip()} " + fill_len = width - len(text) + if fill_len <= 0: + return text + left = fill_len // 2 + right = fill_len - left + return f"{fill * left}{text}{fill * right}" + + +def print_wrapped(text: str, width: int = L): + """Print text with automatic line wrapping.""" + wrapped = textwrap.fill(text, width=width) + print(wrapped) + + +def boxed(text: str, width: int = L, padding: int = 1): + """Render a centered box with the given text and width.""" + lines = text.splitlines() + content_width = max(len(line) for line in lines) + box_width = min(width, content_width + padding * 2 + 2) + + # compute left margin for centering + terminal_width = width + left_margin = max((terminal_width - box_width) // 2, 0) + margin = " " * left_margin + + top = margin + "┌" + "─" * (box_width - 2) + "┐" + bottom = margin + "└" + "─" * (box_width - 2) + "┘" + + print(top) + for line in lines: + inner = line.center(content_width) + print(margin + "│" + " " * padding + inner + " " * padding + "│") + print(bottom) + + +def decode_trim(tokenizer, seq_ids_list, input_ids_list) -> str: + """ + Return only the generated text, truncated at the first EOS **after** the prompt. + + Args: + tokenizer: HF tokenizer with eos_token_id / pad_token_id. + seq_ids: Full sequence token ids from the model (prompt + generation). + input_ids: The prompt token ids that were fed into the model. + + Behavior: + - Finds the first eos_token_id that occurs at or after len(input_ids). + - Slices generation up to (but not including) that EOS. + - Decodes only the generation span, skipping special/pad tokens. 
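+
+    Example (illustrative; the ids below are invented and assume
+    tokenizer.eos_token_id == 151645):
+        >>> full_seqs = [[11, 12, 901, 902, 151645, 0]]  # prompt + generation + EOS + padding
+        >>> prompts = [[11, 12]]
+        >>> decode_trim(tokenizer, full_seqs, prompts)   # -> [decoded text of tokens 901, 902]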
+ """ + # Make sure we can index these + sequences = [] + for seq_ids, input_ids in zip(seq_ids_list, input_ids_list): + full = list(seq_ids) + prompt = list(input_ids) + + # Skip left padding tokens (necessary for dream) + pad_id = getattr(tokenizer, "pad_token_id", None) + if pad_id is not None: + while full and full[0] == pad_id: + full.pop(0) + + start = len(prompt) + end = len(full) + + eos_id = getattr(tokenizer, "eos_token_id", None) + eot_id = getattr(tokenizer, "eot_token_id", None) + if eos_id is not None: + for i in range(start, len(full)): + if full[i] in (eos_id, eot_id): + end = i + break + + gen_ids = full[start:end] + text = tokenizer.decode(gen_ids, skip_special_tokens=True) + # in case there is no eos_id or eot_id, just strings + eos = getattr(tokenizer, "eos_token", None) + eot = getattr(tokenizer, "eot_token", None) + if eos: + text = text.split(eos)[0] + if eot: + text = text.split(eot)[0] + # return text.strip() + sequences.append(text) + return sequences + + +def render_menu(round_idx: int): + """Render a boxed menu of possible actions.""" + if round_idx == 0: + text = ( + "Possible next actions:\n" + "[1] Continue this chat\n" + "[2] End this chat and start a new one\n" + "[3] Exit" + ) + else: + text = ( + f"(Round {round_idx})\n" + "Possible next actions:\n" + "[1] Continue this chat\n" + "[2] End this chat and start a new one\n" + "[3] Exit" + ) + + print() # spacing + boxed(text) + + +def prompt_choice() -> Literal["1", "2", "3"]: + while True: + print("Select action [1/2/3]: ") + choice = input().strip() + if choice in ("1", "2", "3"): + return choice + print(banner_line("", fill=" ")) + + +def build_chat_inputs(tokenizer, messages: List[dict], add_generation_prompt: bool): + """Tokenize chat messages into inputs tensor.""" + return tokenizer.apply_chat_template( + messages, + add_generation_prompt=add_generation_prompt, + tokenize=True, + ) + + +def visualize_histories(tokenizer, histories): + try: + terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer( + tokenizer=tokenizer + ) + terminal_visualizer.visualize(histories, rich=True) + except Exception as e: + print(f"(Visualization skipped: {e})") + + +# ============================================================ +# Modes +# ============================================================ +def single_turn_generate(generator, gen_config, visualize: bool): + print() + print(banner_line("continuation mode")) + model, tokenizer = generator.model, generator.tokenizer + + while True: + print(banner_line("", fill=" ")) + try: + # user_text = input("Prompt > ").strip() + print("[Prompt] > ") + user_text = input().strip() + except (EOFError, KeyboardInterrupt): + print("\n" + banner_line("Exiting. 
Bye!", width=len(DIV))) + return + + # if not user_text: + # print("(Empty input, skipped)\n") + # continue + + inputs = tokenizer([user_text], add_special_tokens=False)["input_ids"] + outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True) + text = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0] + + print(banner_line("Output")) + print_wrapped(text if text else "") + print(DIV + "\n") + + if visualize: + visualize_histories(tokenizer, outputs.histories) + + +def multi_turn_chat(generator, gen_config, visualize: bool): + # """Chat mode with chat template & message history.""" + print() + print(banner_line("multi-turn chat mode")) + print(banner_line("", fill=" ")) + model, tokenizer = generator.model, generator.tokenizer + + messages: List[dict] = [] + round_idx = 0 + + while True: + try: + print("[You]:") + user_msg = input().strip() + except (EOFError, KeyboardInterrupt): + print("\nExiting. Bye!") + return + + messages.append({"role": "user", "content": user_msg}) + inputs = build_chat_inputs(tokenizer, [messages], add_generation_prompt=True) + + outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True) + reply = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0] + + print(DIV) + print_wrapped("[Assistant]: " + reply if reply else "") + print(DIV + "\n") + + messages.append({"role": "assistant", "content": reply}) + + if visualize: + visualize_histories(tokenizer, outputs.histories) + + render_menu(round_idx) + choice = prompt_choice() + if choice == "1": + print(banner_line("", fill=" ")) + round_idx += 1 + continue + elif choice == "2": + print(banner_line("", fill=" ")) + messages = [] + round_idx = 0 + continue + else: + print("\nExiting. Bye!") + return diff --git a/dllm/dllm/tools/download_hf_dataset.py b/dllm/dllm/tools/download_hf_dataset.py new file mode 100644 index 0000000..0e40c46 --- /dev/null +++ b/dllm/dllm/tools/download_hf_dataset.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +import tyro +from huggingface_hub import snapshot_download + + +@dataclass +class ScriptArguments: + dataset_id: str = "Anthropic/hh-rlhf" + allow_patterns: str = None + + +script_args = tyro.cli(ScriptArguments) + +# Replace with the dataset repo you want, e.g. "wikitext" +dataset_id = script_args.dataset_id + +# Replace with your desired local directory +local_dir = f"/mnt/lustrenew/mllm_aligned/shared/datasets/huggingface/{dataset_id}" + +# Download the dataset snapshot +snapshot_download( + repo_id=dataset_id, + repo_type="dataset", # 👈 tell HF it's a dataset + local_dir=local_dir, + local_dir_use_symlinks=False, # ensures real files, not symlinks + allow_patterns=script_args.allow_patterns, +) + +print(f"Dataset downloaded to: {local_dir}") diff --git a/dllm/dllm/tools/download_hf_model.py b/dllm/dllm/tools/download_hf_model.py new file mode 100644 index 0000000..434bf91 --- /dev/null +++ b/dllm/dllm/tools/download_hf_model.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +import tyro +from huggingface_hub import snapshot_download + + +@dataclass +class ScriptArguments: + model_id: str = "GSAI-ML/LLaDA-8B-Instruct" + + +script_args = tyro.cli(ScriptArguments) + +# Replace with the model repo you want, e.g. 
"bert-base-uncased" +model_id = script_args.model_id + +# Replace with your desired local directory +local_dir = f"/mnt/lustrenew/mllm_aligned/shared/models/huggingface/{model_id}" + +# Download the model snapshot +snapshot_download( + repo_id=model_id, + local_dir=local_dir, + local_dir_use_symlinks=False, # ensures real files, not symlinks +) + +print(f"Model downloaded to: {local_dir}") diff --git a/dllm/dllm/tools/generate.py b/dllm/dllm/tools/generate.py new file mode 100644 index 0000000..4640904 --- /dev/null +++ b/dllm/dllm/tools/generate.py @@ -0,0 +1 @@ +# TODO diff --git a/dllm/dllm/tools/merge_peft_adapter.py b/dllm/dllm/tools/merge_peft_adapter.py new file mode 100644 index 0000000..66a02d7 --- /dev/null +++ b/dllm/dllm/tools/merge_peft_adapter.py @@ -0,0 +1,80 @@ +""" +Merge a PEFT/LoRA adapter into its base model (auto-detected from adapter_config.json). + +Usage: + python dllm_trainer/tools/merge_peft_adapter.py \ + --adapter_model_name_or_path your-org/your-lora \ + --output_model_name_or_path ./merged-model \ + --dtype bf16 +""" + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModel, AutoTokenizer, HfArgumentParser + +import dllm # so that no need to trust_remote_code + +DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float} + + +@dataclass +class ScriptArguments: + adapter_model_name_or_path: str | None = field( + default=None, metadata={"help": "Adapter repo or local path"} + ) + output_model_name_or_path: str | None = field( + default=None, + metadata={"help": "Where to save the merged model (folder or repo id)"}, + ) + dtype: str | None = field(default="fp16", metadata={"help": "fp16|bf16|fp32"}) + push_to_hub: bool | None = field( + default=False, metadata={"help": "Push merged weights to the Hub"} + ) + # Optional override if adapter config lacks base info: + base_model_name_or_path: str | None = field( + default=None, + metadata={"help": "Override base model if adapter config lacks it"}, + ) + + +def main(): + parser = HfArgumentParser(ScriptArguments) + args = parser.parse_args_into_dataclasses()[0] + + assert args.adapter_model_name_or_path, "please provide the adapter (repo or path)" + assert args.output_model_name_or_path, "please provide output_model_name_or_path" + assert args.dtype in DTYPE_MAP, f"dtype must be one of {list(DTYPE_MAP.keys())}" + + # Read base path from adapter_config.json + peft_cfg = PeftConfig.from_pretrained(args.adapter_model_name_or_path) + base_id = args.base_model_name_or_path or getattr( + peft_cfg, "base_model_name_or_path", None + ) + assert base_id, ( + "adapter_config.json does not include base_model_name_or_path; " + "pass --base_model_name_or_path to override." 
+ ) + + # Load base model and tokenizer + model = AutoModel.from_pretrained( + base_id, return_dict=True, dtype=DTYPE_MAP[args.dtype] + ) + tokenizer = AutoTokenizer.from_pretrained(base_id) + + # Attach adapter, merge, and unload PEFT layers + model = PeftModel.from_pretrained(model, args.adapter_model_name_or_path) + model.eval() + model = model.merge_and_unload() # plain transformers model + + # Save locally + model.save_pretrained(args.output_model_name_or_path) + tokenizer.save_pretrained(args.output_model_name_or_path) + + print(f"✓ merged model saved to: {args.output_model_name_or_path}") + + +if __name__ == "__main__": + main() diff --git a/dllm/dllm/tools/preprocess_pt_dataset.py b/dllm/dllm/tools/preprocess_pt_dataset.py new file mode 100644 index 0000000..1145404 --- /dev/null +++ b/dllm/dllm/tools/preprocess_pt_dataset.py @@ -0,0 +1,109 @@ +""" +python dllm/tools/preprocess_pt_dataset.py +""" + +import os +from dataclasses import dataclass, asdict +from functools import partial + +import datasets +import tyro +import transformers +from pprint import pprint + +import dllm + + +@dataclass +class ScriptArguments: + """Preprocess PT dataset""" + + model_name_or_path: str = "answerdotai/ModernBERT-large" + dataset_args: str = "OpenCoder-LLM/opc-annealing-corpus[lang:python]" # required + output_dir: str = "data/pt/modernbert/opc-annealing-corpus[lang:python]" # required + text_field: str = "text" + max_length: int = 1024 + insert_eos: bool = True + drop_tail: bool = True + remove_columns: bool = False + num_proc: int = 32 + + def __post_init__(self): + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +def preprocess_pt_dataset( + dataset: datasets.DatasetDict, + tokenizer: transformers.PreTrainedTokenizer, + output_dir: str, + text_field: str = "text", + max_length: int = 1024, + insert_eos: bool = True, + drop_tail: bool = True, + remove_columns: bool = False, + num_proc: int = 32, +): + processed = dataset.map( + partial( + dllm.utils.tokenize_and_group, + tokenizer=tokenizer, + text_field=text_field, + seq_length=max_length, + insert_eos=insert_eos, + drop_tail=drop_tail, + ), + batched=True, + num_proc=num_proc, + remove_columns=dataset["train"].column_names, + ) + + # Keep only the three required columns to save space. + if remove_columns: + keep = {"input_ids", "labels"} + + def strip_cols(ds: datasets.Dataset) -> datasets.Dataset: + drop = [c for c in ds.column_names if c not in keep] + return ds.remove_columns(drop) if drop else ds + + if isinstance(processed, datasets.DatasetDict): + for split in list(processed.keys()): + processed[split] = strip_cols(processed[split]) + else: + processed = strip_cols(processed) + + output_dir = os.path.join( + output_dir, + f"max_length-{max_length}-insert_eos-{insert_eos}-drop_tail-{drop_tail}", + ) + os.makedirs(output_dir, exist_ok=True) + processed.save_to_disk(output_dir) + print(f"[OK] Saved to: {output_dir}") + + +def main(): + # Parse with tyro + args = tyro.cli(ScriptArguments) + dllm.utils.print_args(args) + + tokenizer = dllm.utils.get_tokenizer(args) + + # Load your raw dataset (must contain a "messages" field per example). 
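+    # (Note: for pretraining data the column actually tokenized is `text_field` from
+    # ScriptArguments, "text" by default; the "messages" wording above appears to be
+    # carried over from the SFT preprocessing script.)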
+ dataset = dllm.data.load_pt_dataset(args.dataset_args, streaming=False) + + preprocess_pt_dataset( + dataset=dataset, + tokenizer=tokenizer, + output_dir=args.output_dir, + text_field=args.text_field, + max_length=args.max_length, + insert_eos=args.insert_eos, + drop_tail=args.drop_tail, + remove_columns=args.remove_columns, + num_proc=args.num_proc, + ) + + +if __name__ == "__main__": + main() diff --git a/dllm/dllm/tools/preprocess_sft_dataset.py b/dllm/dllm/tools/preprocess_sft_dataset.py new file mode 100644 index 0000000..3fce214 --- /dev/null +++ b/dllm/dllm/tools/preprocess_sft_dataset.py @@ -0,0 +1,117 @@ +""" +Example: + +PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --sft_map_fn_path "examples.dream.sft.sft_map_fn" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "data/sft/dream/tulu-3-sft-mixture" \ + --num_proc 64 +""" + +import os +import importlib +from dataclasses import dataclass +from functools import partial + +import datasets +import tyro + +import dllm + + +@dataclass +class ScriptArguments: + """Preprocess SFT dataset""" + + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base" + sft_map_fn_path: str = "dllm.utils.default_sft_map_fn" + dataset_args: str = "HuggingFaceTB/smoltalk" # required + output_dir: str = "data/sft/llada/smoltalk" # required + mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100 + num_proc: int = 32 + remove_columns: bool = False + + def __post_init__(self): + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +def preprocess_sft_dataset( + dataset: datasets.DatasetDict, + map_fn: callable, + output_dir: str, + remove_columns: bool = False, + num_proc: int = 32, +): + processed = dataset.map( + map_fn, + batched=False, + num_proc=num_proc, + load_from_cache_file=True, + writer_batch_size=512, + desc="offline preprocessing", + ) + + # Keep only the three required columns to save space. + if remove_columns: + keep = {"input_ids", "labels", "prompt_len", "attention_mask"} + + def strip_cols(ds: datasets.Dataset) -> datasets.Dataset: + drop = [c for c in ds.column_names if c not in keep] + return ds.remove_columns(drop) if drop else ds + + if isinstance(processed, datasets.DatasetDict): + for split in list(processed.keys()): + processed[split] = strip_cols(processed[split]) + else: + processed = strip_cols(processed) + + os.makedirs(output_dir, exist_ok=True) + processed.save_to_disk(output_dir) + print(f"[OK] Saved to: {output_dir}") + + +def main(): + # Parse with tyro + args = tyro.cli(ScriptArguments) + dllm.utils.print_args(args) + + tokenizer = dllm.utils.get_tokenizer(args) + + # Load your raw dataset (must contain a "messages" field per example). + dataset = dllm.data.load_sft_dataset(args.dataset_args) + + # 4. 
Dynamically import the function based on the argument + try: + # Split the path into module and function name + module_path, function_name = args.sft_map_fn_path.rsplit(".", 1) + + # Import the module + module = importlib.import_module(module_path) + + # Get the function from the module + sft_map_fn = getattr(module, function_name) + + except (ImportError, AttributeError, ValueError) as e: + print(f"Error: Could not import '{args.sft_map_fn_path}'.") + print(f"Details: {e}") + return + + map_fn = partial( + sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=args.mask_prompt_loss, + ) + preprocess_sft_dataset( + dataset=dataset, + map_fn=map_fn, + output_dir=args.output_dir, + remove_columns=args.remove_columns, + num_proc=args.num_proc, + ) + + +if __name__ == "__main__": + main() diff --git a/dllm/dllm/utils/__init__.py b/dllm/dllm/utils/__init__.py new file mode 100644 index 0000000..1ac787f --- /dev/null +++ b/dllm/dllm/utils/__init__.py @@ -0,0 +1,6 @@ +from . import configs, generation_utils, model_utils, utils +from .configs import * +from .generation_utils import * +from .data_utils import * +from .model_utils import * +from .utils import * diff --git a/dllm/dllm/utils/configs.py b/dllm/dllm/utils/configs.py new file mode 100644 index 0000000..e1cb6f2 --- /dev/null +++ b/dllm/dllm/utils/configs.py @@ -0,0 +1,77 @@ +import os +from dataclasses import dataclass, field + +import transformers + +from dllm.utils.utils import resolve_with_base_env, get_default_logger + +logger = get_default_logger(__name__) + + +@dataclass +class ModelArguments: + model_name_or_path: str = None # overwrite this + dtype: str = "bfloat16" + load_in_4bit: bool = False + attn_implementation: str = None + # --- fold PEFT args here --- + lora: bool = False + target_modules: str = "all-linear" + r: int = 32 + lora_alpha: int = 64 + lora_dropout: float = 0.05 + bias: str = "none" + modules_to_save: str = None + + def __post_init__(self): + self.model_name_or_path = resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class DataArguments: + dataset_args: str = None # overwrite this + num_proc: int = 8 + disable_caching: bool = False + max_length: int = 1024 + truncation: str = field( + default="right", + metadata={ + "help": ( + 'The truncation strategy to use ("filter" or "right"). ' + '"filter" only keeps sequences that are shorter than max_length; ' + '"right" only keeps the rightmost max_length tokens for each sequence.' + ) + }, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + output_dir: str = None # overwrite this + report_to: str = "wandb" + overwrite_output_dir: bool = True + seed: int = 42 + per_device_train_batch_size: int = 4 + per_device_eval_batch_size: int = 4 + gradient_accumulation_steps: int = 1 + learning_rate: float = 2e-5 + lr_scheduler_type: str = "cosine" + warmup_ratio: float = 0.1 + bf16: bool = True + num_train_epochs: float = 4 + logging_steps: float = 10 + eval_on_start: bool = False + eval_strategy: str = "steps" + eval_steps: float = 0.25 + save_steps: float = 0.25 + save_only_model: bool = True + + def __post_init__(self): + super().__post_init__() + if self.group_by_length: + logger.info( + "training_args.group_by_length=True: preprocessing " + "may take some time after `trainer.train()` starts." 
+ ) diff --git a/dllm/dllm/utils/data_utils.py b/dllm/dllm/utils/data_utils.py new file mode 100644 index 0000000..f68c2bc --- /dev/null +++ b/dllm/dllm/utils/data_utils.py @@ -0,0 +1,222 @@ +import random +import warnings +from dataclasses import dataclass +from itertools import chain +from typing import TYPE_CHECKING + +import torch +import datasets +import transformers + +if TYPE_CHECKING: + from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments + + +def tokenize_and_group( + examples, + tokenizer, + text_field: str = "text", + seq_length: int = 1024, + insert_eos: bool = False, + drop_tail: bool = True, + add_special_tokens: bool = False, +): + # 1) Tokenize (batched input) + tokenized = tokenizer(examples[text_field], add_special_tokens=add_special_tokens) + ids = tokenized["input_ids"] + + # --- optionally append EOS to each sample --- + if insert_eos: + eos_id = getattr(tokenizer, "eos_token_id") + assert eos_id + # append EOS only if the sample doesn't already end with it + ids = [seq + ([] if (seq and seq[-1] == eos_id) else [eos_id]) for seq in ids] + # ---------------------------------------------------------------- + + # 2) Flatten and concatenate all token lists + concatenated = list(chain.from_iterable(ids)) + if not concatenated: + return {"input_ids": [], "labels": []} # Safe return for empty batch + + # 3) Calculate the total length based on drop_tail + if drop_tail: + total_len = (len(concatenated) // seq_length) * seq_length + concatenated = concatenated[:total_len] # Truncate the last incomplete chunk + else: + total_len = len(concatenated) + + # Split into fixed-length chunks + chunks = [concatenated[i : i + seq_length] for i in range(0, total_len, seq_length)] + + return { + "input_ids": chunks, + "labels": [c[:] for c in chunks], # Labels are the same as input_ids + } + + +def clip_row(row: dict, max_length: int, truncation: str = "right") -> dict: + for key in ("input_ids", "labels", "attention_mask"): + if key in row: + if truncation == "right": + row[key] = row[key][:max_length] + elif truncation == "left": + row[key] = row[key][-max_length:] + else: + raise NotImplementedError + return row + + +def post_process_dataset( + dataset: datasets.DatasetDict, data_args: "DataArguments" +) -> datasets.DatasetDict: + if data_args.truncation == "filter": + return dataset.filter( + lambda row: len(row["input_ids"]) <= data_args.max_length, + num_proc=data_args.num_proc, + desc=f"Filtering samples with length <= {data_args.max_length}", + ) + elif data_args.truncation == "right": + # do this only if dataset has "prompt_len" + if "prompt_len" in dataset.column_names["train"]: + dataset = dataset.filter( + lambda row: row["prompt_len"] <= data_args.max_length, + num_proc=data_args.num_proc, + desc=f"Filtering samples with `prompt_len` <= {data_args.max_length}", + ) + return dataset.map( + lambda row: clip_row(row, data_args.max_length, truncation="right"), + num_proc=data_args.num_proc, + desc=f"Right-truncating samples to max_length={data_args.max_length}", + ) + else: + raise NotImplementedError + + +def clip_row_streaming(row: dict, max_length: int, truncation: str = "right") -> dict: + """Clip whole sequence OR (if prompt_len present) preserve prompt and clip only the response.""" + if truncation not in {"right", "left"}: + raise NotImplementedError(f"Unknown truncation: {truncation}") + + def clip(seq): + return seq[:max_length] if truncation == "right" else seq[-max_length:] + + def clip_preserve_prompt(seq, prompt_len: int): + prompt = 
seq[:prompt_len] + resp = seq[prompt_len:] + budget = max(0, max_length - len(prompt)) + resp = resp[:budget] if truncation == "right" else resp[-budget:] + return prompt + resp + + prompt_len = row.get("prompt_len", None) + for k in ("input_ids", "labels", "attention_mask"): + if k in row and isinstance(row[k], list): + row[k] = ( + clip_preserve_prompt(row[k], prompt_len) + if isinstance(prompt_len, int) and prompt_len >= 0 + else clip(row[k]) + ) + return row + + +def post_process_dataset_streaming( + dataset: datasets.IterableDatasetDict, + data_args: "DataArguments", +) -> datasets.IterableDatasetDict: + + def _train_has_prompt_len_streaming(dataset: datasets.IterableDatasetDict) -> bool: + """Replicates: 'if \"prompt_len\" in dataset.column_names[\"train\"]' for streaming.""" + it = dataset["train"].take(1) + try: + ex = next(iter(it)) + except StopIteration: + return False + return "prompt_len" in ex + + mode = data_args.truncation + max_len = data_args.max_length + + if mode == "filter": + # Keep rows with len(input_ids) <= max_len (emulate .filter with generator map) + def keep_if_short(row): + if ( + "input_ids" in row + and isinstance(row["input_ids"], list) + and len(row["input_ids"]) <= max_len + ): + yield row # keep + # else: drop (yield nothing) + + return datasets.IterableDatasetDict( + {name: ds.map(keep_if_short) for name, ds in dataset.items()} + ) + + elif mode == "right": + ds_out = dataset + + # Do this only if TRAIN split has "prompt_len" (same condition as your non-streaming code) + if _train_has_prompt_len_streaming(ds_out): + + def keep_if_prompt_fits(row): + pl = row.get("prompt_len", None) + if isinstance(pl, int) and pl <= max_len: + yield row # keep + elif pl is None: + # If a row lacks prompt_len but train had it, the non-streaming code would try to access it and fail. + # Here we conservatively drop such rows to mirror "requires prompt_len <= max_len". + return + # else: drop + + ds_out = datasets.IterableDatasetDict( + {name: ds.map(keep_if_prompt_fits) for name, ds in ds_out.items()} + ) + + # Then clip right (same clipping as clip_row) + def clip_right(row): + return clip_row(row, max_len, truncation="right") + + return datasets.IterableDatasetDict( + {name: ds.map(clip_right) for name, ds in ds_out.items()} + ) + + else: + raise NotImplementedError + + +@dataclass +class NoAttentionMaskCollator(transformers.DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + outputs = super().__call__(features, return_tensors) + # fintune on padding ; should not mask them out + outputs.pop("attention_mask") + return outputs + + +def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict: + """ + Build input_ids and labels for SFT. 
+ + Args: + row: a dataset row with `messages` + tokenizer: a HF tokenizer + mask_prompt_loss: whether to mask prompt tokens (set their labels to -100) + + Returns: + dict with keys: input_ids, labels, and optionally prompt_len + """ + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], tokenize=True, add_generation_prompt=False + ) + labels = prompt_response_tokens.copy() + + if mask_prompt_loss: + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], tokenize=True, add_generation_prompt=True + ) + labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens) + return { + "input_ids": prompt_response_tokens, + "labels": labels, + "prompt_len": len(prompt_tokens), + } + + return {"input_ids": prompt_response_tokens, "labels": labels} diff --git a/dllm/dllm/utils/generation_utils.py b/dllm/dllm/utils/generation_utils.py new file mode 100644 index 0000000..af1120c --- /dev/null +++ b/dllm/dllm/utils/generation_utils.py @@ -0,0 +1,53 @@ +import torch + +from dllm.core.schedulers import BaseAlphaScheduler + + +def get_num_transfer_tokens( + mask_index: torch.Tensor, + steps: int, + scheduler: BaseAlphaScheduler, + stochastic: bool = False, +) -> torch.Tensor: + mask_num = mask_index.sum(dim=1, keepdim=True) + num_transfer_tokens = torch.zeros( + mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64 + ) + for i in range(mask_num.size(0)): + for t, s, j in zip(range(steps, 0, -1), range(steps - 1, -1, -1), range(steps)): + s /= steps + t /= steps + reverse_transfer_prob = 1 - scheduler.reverse_mask_prob(s=s, t=t) + if not stochastic: + x = mask_num[i, 0].to(torch.float64) * reverse_transfer_prob + num_transfer_tokens[i, j] = torch.round(x).to(torch.int64) + else: + n = mask_num[i, 0].to(torch.float64) + num_transfer_tokens[i, j] = ( + torch.distributions.Binomial(n, reverse_transfer_prob) + .sample() + .to(torch.int64) + ) + num_transfer_tokens[i, j] = torch.minimum( + num_transfer_tokens[i, j], mask_num[i, 0] + ) + mask_num[i, 0] -= num_transfer_tokens[i, j] + if mask_num[i, 0].item() == 0: + break + # Note: because llada is not conditioned on time, this allows us to skip steps with no unmasking (i.e. transfer). + # Clear all zeros per row (compact) and right-pad with zeros + # Remove zeros per row, then pad only up to the max length across rows + rows = [] + max_len = 0 + for i in range(num_transfer_tokens.size(0)): + nonzero = num_transfer_tokens[i][num_transfer_tokens[i] > 0] + rows.append(nonzero) + max_len = max(max_len, nonzero.numel()) + # Pad each row to max_len + padded_rows = [] + for r in rows: + if r.numel() < max_len: + pad = torch.zeros(max_len - r.numel(), dtype=r.dtype, device=r.device) + r = torch.cat([r, pad]) + padded_rows.append(r) + return torch.stack(padded_rows, dim=0) diff --git a/dllm/dllm/utils/model_utils.py b/dllm/dllm/utils/model_utils.py new file mode 100644 index 0000000..71415d5 --- /dev/null +++ b/dllm/dllm/utils/model_utils.py @@ -0,0 +1,180 @@ +import torch +import accelerate +import transformers +from peft import prepare_model_for_kbit_training + +from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft +from dllm.utils.configs import ModelArguments, TrainingArguments + + +def get_model( + model_args, + config: transformers.PretrainedConfig | None = None, +) -> transformers.PreTrainedModel: + """ + Load a model with flexible input sources. + + Args: + model_args: An optional dataclass or namespace containing model parameters. 
+ model_name_or_path: Optional direct model path or name (overrides model_args.model_name_or_path). + dtype: Dtype (string or torch.dtype). + load_in_4bit: Whether to load using 4-bit quantization (can override model_args.load_in_4bit). + + Returns: + transformers.PreTrainedModel + """ + model_name_or_path = getattr(model_args, "model_name_or_path") + dtype = getattr(model_args, "dtype", "bfloat16") + load_in_4bit = getattr(model_args, "load_in_4bit", False) + attn_implementation = getattr(model_args, "attn_implementation", None) + + # Device map: skip when ZeRO-3 + device_map = ( + {"": accelerate.PartialState().local_process_index} + if not transformers.modeling_utils.is_deepspeed_zero3_enabled() + and torch.cuda.is_available() + else None + ) + + quant_config = None + if load_in_4bit and transformers.utils.is_bitsandbytes_available(): + quant_config = transformers.BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=dtype, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + + params = { + "dtype": dtype, + "device_map": device_map, + "quantization_config": quant_config, + "attn_implementation": attn_implementation, + "config": config, + } + + try: + model = transformers.AutoModelForMaskedLM.from_pretrained( + model_name_or_path, **params + ) + except: + model = transformers.AutoModel.from_pretrained(model_name_or_path, **params) + + # --- if quantized, prepare for LoRA / QLoRA training --- + if load_in_4bit and quant_config is not None: + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False) + + # Optionally train with lora + model = load_peft(model, model_args) + + return model + + +def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer: + """ + Load a tokenizer with flexible input sources. + + Args: + model_args: Optional dataclass or namespace containing model parameters. + model: Optional model instance to configure tokenizer behavior. + model_name_or_path: Optional direct model name or path (overrides model_args.model_name_or_path). 
+ + Returns: + transformers.PreTrainedTokenizer + """ + # Lazy imports to avoid circular dependencies + from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM + from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM + from dllm.pipelines.dream.models.modeling_dream import DreamModel + from dllm.pipelines.rnd.models.modeling_rnd import RND1LM + from transformers import ( + BertPreTrainedModel, + RobertaPreTrainedModel, + ModernBertPreTrainedModel, + ) + + model_name_or_path = getattr(model_args, "model_name_or_path") + + # ---------------- Tokenizer loading ---------------- + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name_or_path, + padding_side="right", + ) + + assert tokenizer.eos_token != None or tokenizer.pad_token != None + + if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + if not tokenizer.eos_token: + tokenizer.eos_token = tokenizer.pad_token + if not tokenizer.bos_token: + tokenizer.bos_token = tokenizer.pad_token + + # If model is not provided, return as-is + model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path) + model_cls = transformers.AutoModel._model_mapping[type(model_cfg)] + + # ---------------- Model-specific customization ---------------- + if issubclass(model_cls, LLaDAModelLM): + tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"}) + tokenizer.eot_token = "<|eot_id|>" + # tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) # can not do this for llada base directly + # TODO: for llada base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348} + # fix bugs in chat template + tokenizer.chat_template = """\ +{% set loop_messages = messages %} +{% for message in loop_messages %} +{% if loop.index0 == 0 %}{{ bos_token }}{% endif %} +<|start_header_id|>{{ message['role'] }}<|end_header_id|> + +{{ message['content'] | trim }}<|eot_id|> +{%- endfor %} +{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %} +<|start_header_id|>assistant<|end_header_id|> + +{% endif %} +""" + elif issubclass(model_cls, LLaDAMoEModelLM): + tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) + tokenizer.eot_token = "<|role_end|>" + tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) + elif issubclass(model_cls, DreamModel): + tokenizer.eot_token = "<|im_end|>" + tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) + elif issubclass(model_cls, RND1LM): + tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) + elif issubclass( + model_cls, + (BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel), + ): + tokenizer.eot_token = "[/Answer]" + tokenizer.chat_template = """\ +{% if messages[0]['role'] == 'system' %} +[SYS] +{{ messages[0]['content'] | trim }} +[/SYS] + +{% set loop_messages = messages[1:] %} +{% else %} +{% set loop_messages = messages %} +{% endif -%} +{%- for message in loop_messages %} +{% if message['role'] == 'user' %} +[Question] +{{ message['content'] | trim }} +[/Question] + +{% elif message['role'] == 'assistant' %} +[Answer] +{{ message['content'] | trim }} +[/Answer] + +{% endif %} +{% endfor -%} +{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %} +[Answer] +{% endif %} +""" + else: + print_main("no tokenizer customization for model class:", model_cls) + return tokenizer diff --git a/dllm/dllm/utils/utils.py 
b/dllm/dllm/utils/utils.py new file mode 100644 index 0000000..ab7c3a5 --- /dev/null +++ b/dllm/dllm/utils/utils.py @@ -0,0 +1,284 @@ +import os +import re +import sys +import logging +from contextlib import contextmanager +from dataclasses import dataclass, asdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments + +import pprint +import torch +import peft +import accelerate +import transformers + + +def resolve_with_base_env(path: str, env_name: str) -> str: + """ + If `env_name` is set and `path` is NOT absolute, NOT a URL/scheme, + and does not already exist locally, prepend the `env_name` directory. + + If the resulting path does not exist, return the base environment directory instead. + Otherwise return `path` unchanged. + """ + base = os.getenv(env_name, "").strip() + if not base: + return path + if os.path.isabs(path): + return path + if os.path.exists(path): + return path + + candidate = os.path.join(base.rstrip("/"), path.lstrip("/")) + if os.path.exists(candidate): + return candidate + else: + raise FileNotFoundError + + +@contextmanager +def init_device_context_manager(device: str | torch.device | None = None): + """ + Temporarily set torch default dtype and default device so that tensors + created inside the context are allocated on `device` with dtype `dtype`. + Restores previous settings on exit. + """ + if transformers.integrations.is_deepspeed_zero3_enabled(): + yield + return + + # Resolve device + if device is None: + try: + from accelerate import PartialState + + idx = PartialState().local_process_index + except Exception: + idx = 0 + device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu" + elif isinstance(device, int): + device = f"cuda:{device}" + + try: + torch.set_default_device(device) + yield + finally: + torch.set_default_device("cpu") + + +def print_main(*args, **kwargs): + """ + Print only from the global main process (rank 0 across all nodes). + Usage: print_main("Hello from main process!") + """ + if accelerate.PartialState().is_main_process: + print(*args, **kwargs) + + +def pprint_main(*args, **kwargs): + """ + Print (with pprint) only from the global main process (rank 0 across all nodes). 
+ Usage: print_main("Hello from main process!") + """ + if accelerate.PartialState().is_main_process: + pprint.pprint(*args, **kwargs) + + +def load_peft( + model: transformers.PreTrainedModel, model_args: "ModelArguments" +) -> transformers.PreTrainedModel: + """ + e.g., + --modules_to_save "lm_head" --target_modules "q_proj,k_proj,v_proj,o_proj,up_proj,down_proj,gate_proj" + --target_modules "all-linear" + """ + if not getattr(model_args, "lora", False): + return model + target_modules = ( + model_args.target_modules.split(",") if model_args.target_modules else None + ) + # if it’s a single 'all-linear', drop the list and use the string directly + if ( + target_modules + and len(target_modules) == 1 + and target_modules[0].strip() == "all-linear" + ): + target_modules = target_modules[0] + modules_to_save = ( + model_args.modules_to_save.split(",") if model_args.modules_to_save else None + ) + peft_config = peft.LoraConfig( + r=model_args.r, + target_modules=target_modules, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, + bias=model_args.bias, + modules_to_save=modules_to_save, + ) + model = peft.get_peft_model(model, peft_config) + if accelerate.PartialState().is_main_process: + print(model) + model.print_trainable_parameters() + return model + + +def print_args_main( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "TrainingArguments", +): + print_main("\n===== Parsed arguments =====") + for name, args in [ + ("model_args", model_args), + ("data_args", data_args), + ("training_args", training_args), + ]: + d = asdict(args) + # keep it tiny: just show first few entries + short = {k: d[k] for k in list(d)} # adjust number as you like + print_main(f"{name}:") + pprint_main(short, width=100, compact=True, sort_dicts=False) + print_main("============================\n") + + +def print_args(args): + print_main("\n===== Parsed arguments =====") + d = asdict(args) + # keep it tiny: just show first few entries + short = {k: d[k] for k in list(d)} # adjust number as you like + pprint_main(short, width=100, compact=True, sort_dicts=False) + print_main("============================\n") + + +def disable_caching_allocator_warmup(): + try: + from transformers import modeling_utils as _mu + + def _noop(*args, **kwargs): + return + + _mu.caching_allocator_warmup = _noop + except Exception: + pass + + +def disable_dataset_progress_bar_except_main(): + # state = accelerate.PartialState() # figures out your rank/world automatically + from datasets.utils.logging import disable_progress_bar, enable_progress_bar + + if accelerate.PartialState().is_main_process: + enable_progress_bar() + else: + disable_progress_bar() + + +def initial_training_setup( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "TrainingArguments", +): + transformers.set_seed(training_args.seed) + disable_caching_allocator_warmup() + disable_dataset_progress_bar_except_main() + if getattr(data_args, "disable_caching", False): + disable_dataset_caching() + + +def disable_dataset_caching(): + from datasets import disable_caching + + disable_caching() + tmp_root = f"/tmp/hf_cache_rank{accelerate.PartialState().process_index}" + os.environ["HF_DATASETS_CACHE"] = tmp_root + os.environ["HF_DATASETS_TEMP_DIR"] = tmp_root + os.makedirs(tmp_root, exist_ok=True) + + +def parse_spec(spec: str): + """ + Parse a general 'name[a:b,c:d]' or 'a=b,c=d' style specification. + + Supports: + - Bare name, e.g. 
"foo/bar" + - Optional bracket suffix with comma-separated entries: + key:value or key:int_value (underscores allowed) + - Optional "key=value" pairs outside the bracket. + + Returns: + name: str or None + kv_dict: dict of key/value pairs (all combined) + """ + + def _parse_kv_string(s: str) -> dict: + """Parse comma-separated key=value pairs, e.g. 'a=1,b=2'.""" + return dict(part.split("=", 1) for part in s.split(",") if "=" in part) + + s = spec.strip() + + # Extract bracket content if present + m = re.search(r"\[(.*?)\]$", s) + bracket_kvs = {} + numeric_kvs = {} + if m: + bracket = m.group(1).strip() + if bracket: + for part in bracket.split(","): + part = part.strip() + if not part: + continue + if ":" not in part: + raise ValueError( + f"Invalid entry '{part}' in '{spec}' (expected key:value)." + ) + key, value = part.split(":", 1) + key = key.strip() + value = value.strip() + + # Integers (with optional underscores) + if re.fullmatch(r"\d(?:_?\d)*", value): + numeric_kvs[key] = int(value.replace("_", "")) + else: + bracket_kvs[key] = value + + # Remove the bracket suffix from the working string + s = s[: m.start()].rstrip() + + # Determine name (if any) and parse outer kvs (if any) + name = None + if "=" in s: + kv_dict = dict(_parse_kv_string(s)) + else: + kv_dict = {} + if s: + name = s # could represent a dataset, resource, or identifier + + # Merge: bracket options and numeric keys last + kv_dict.update(bracket_kvs) + kv_dict.update(numeric_kvs) + + return name, kv_dict + + +def get_default_logger(name): + logger = logging.getLogger(name) + if accelerate.PartialState().is_main_process: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + handler = logging.StreamHandler(sys.stdout) # print to terminal + formatter = logging.Formatter( + fmt=( + "\x1b[38;5;110m[%(asctime)s " + "\x1b[38;5;174m%(levelname)s " + "\x1b[38;5;109m%(name)s" + "/%(lineno)d-%(processName)s\x1b[38;5;110m] " + "\x1b[0m%(message)s" + ), + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger diff --git a/dllm/examples/bert/README.md b/dllm/examples/bert/README.md new file mode 100644 index 0000000..522ef42 --- /dev/null +++ b/dllm/examples/bert/README.md @@ -0,0 +1,190 @@ +# Generative BERT + +[![Hugging Face Checkpoints](https://img.shields.io/badge/Hugging%20Face-Checkpoints-yellow)](https://huggingface.co/collections/dllm-collection/bert-chat) +[![W&B Report](https://img.shields.io/badge/W&B-Report-white?logo=weightsandbiases)](https://api.wandb.ai/links/asap-zzhou/101h5xvg) + +This directory provides two key sets of resources: + +1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text. +2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as Chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg). + +
+![chat](assets/chat.gif)
+
+*Chat with ModernBERT-large-chat-v0. See Inference for details.*
+ +## Files overview +``` +# example entry points for training / inference / evaluation +examples/bert +├── chat.py # Interactive inference example +├── eval.sh # Automatic evaluation script +├── generate.py # Inference example +├── pt.py # Pretraining example +├── README.md # Documentation (you are here) +└── sft.py # Supervised finetuning example +``` + +## Warmup + +In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text. +You can use any BERT model instead for example, by `--model_name_or_path "FacebookAI/roberta-large"`. + +### Pretrain + +To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run: +```shell +accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/bert/pt.py \ + --model_name_or_path "answerdotai/ModernBERT-large" \ + --dataset_args "Trelis/tiny-shakespeare" \ + --text_field "Text" \ + --insert_eos False \ + --max_length 128 \ + --num_train_epochs 20 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --save_steps 0.1 \ + --output_dir "models/ModernBERT-large/tiny-shakespeare" +``` + +To run inference with the model: +```shell +# just press enter (empty prompt) if you want the model to generate text from scratch +python -u examples/bert/chat.py \ + --model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \ + --chat False --remasking "random" --steps 128 --max_new_tokens 128 +``` + +### SFT + +To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run: +```shell +accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \ + examples/bert/sft.py \ + --model_name_or_path "answerdotai/ModernBERT-large" \ + --dataset_args "tatsu-lab/alpaca" \ + --max_length 512 \ + --num_train_epochs 20 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --save_steps 0.1 \ + --output_dir "models/ModernBERT-large/alpaca" +``` + +To chat with the model: +```shell +python -u examples/bert/chat.py \ + --model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True +``` + +## BERT Chat +Here we show the exact commands we use to train and interact with the BERT Chat models: +[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0). +For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg). 
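+
+Before the exact commands, a note on preprocessing: both the Warmup SFT run above and the BERT Chat recipes below go through `examples/bert/sft.py`, which maps every `messages` example with `dllm.utils.default_sft_map_fn` and, since `--mask_prompt_loss` defaults to `True`, sets the prompt positions of `labels` to `-100` so that no loss is taken on them. A minimal sketch of that step (assuming `BASE_MODELS_DIR` is unset so the model id resolves to the Hugging Face Hub; the example row is made up):
+
+```python
+import dllm
+
+args = dllm.utils.ModelArguments(model_name_or_path="answerdotai/ModernBERT-large")
+tokenizer = dllm.utils.get_tokenizer(model_args=args)  # attaches the [Question]/[Answer] chat template used for BERT-style models
+
+row = {"messages": [
+    {"role": "user", "content": "What is 2 + 2?"},
+    {"role": "assistant", "content": "4"},
+]}
+out = dllm.utils.default_sft_map_fn(row, tokenizer=tokenizer, mask_prompt_loss=True)
+# out["input_ids"]: chat-templated prompt + response token ids
+# out["labels"]:    the same ids, with the first out["prompt_len"] positions set to -100
+```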
+ +### Training + +To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run: +```shell +accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \ + examples/bert/sft.py \ + --model_name_or_path "answerdotai/ModernBERT-base" \ + --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \ + --max_length 1024 \ + --num_train_epochs 10 \ + --per_device_train_batch_size 48 \ + --per_device_eval_batch_size 48 \ + --save_steps 0.1 \ + --output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024" +``` + +To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run: +```shell +accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \ + examples/bert/sft.py \ + --model_name_or_path "answerdotai/ModernBERT-large" \ + --dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \ + --max_length 1024 \ + --num_train_epochs 10 \ + --per_device_train_batch_size 48 \ + --per_device_eval_batch_size 48 \ + --save_steps 0.1 \ + --output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024" +``` + +### Inference + +To chat with the model: +```shell +python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True +``` + +## Evaluation +> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation. + +For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run: +```shell +# Use model_args to adjust the generation arguments for evalution. 
+accelerate launch --num_processes 4 \ + dllm/pipelines/bert/eval.py \ + --tasks "mmlu_pro" \ + --model "bert" \ + --apply_chat_template \ + --num_fewshot 0 \ + --model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256" +``` + +To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run: +```shell +bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0" +bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" +``` + +### Evaluation results + + + + + + +|                     | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU | +|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 | +| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 | +| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(reported & evaluated) | 48.6 | 22.0 | 50.5 | 18.3 | 3.1 | 39.2 | 55.0 | 48.2 | 46.6 | +| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(reported & evaluated) | 41.2 | 11.3 | 37.2 | 18.2 | 2.1 | 35.0 | 52.0 | 36.9 | 32.2 | +| [`gpt2`](https://huggingface.co/openai-community/gpt2)(reported & evaluated) | 46.0 | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 | +| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported & evaluated) | 55.5 | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 |53.1 | 39.4 | 0.3 | + + +
+*Table 1. Evaluation results of ModernBERT-base-chat-v0, ModernBERT-large-chat-v0, Qwen1.5-0.5B, Qwen1.5-0.5B-Chat, gpt2, and gpt2-medium. Underlined entries are results from official reports: the GPT-2 paper, the Qwen 1.5 blog, and the Qwen2-0.5B-Instruct model card. All other results are evaluated using our framework.*
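+
+You can also drive generation directly from Python instead of going through `examples/bert/chat.py`. The sketch below condenses `examples/bert/generate.py`; it assumes `BASE_MODELS_DIR` is unset (or the checkpoint is available locally) and that the base `LLaDAGeneratorConfig` exposes the fields the example scripts override:
+
+```python
+import dllm
+from dllm.pipelines import llada
+from dllm.tools.chat import decode_trim
+
+args = dllm.utils.ModelArguments(model_name_or_path="dllm-collection/ModernBERT-large-chat-v0")
+model = dllm.utils.get_model(model_args=args).eval()
+tokenizer = dllm.utils.get_tokenizer(model_args=args)
+generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
+
+gen_config = llada.LLaDAGeneratorConfig(
+    steps=128, max_new_tokens=128, block_length=32, temperature=0.0, remasking="low_confidence"
+)
+messages = [[{"role": "user", "content": "Write a short poem about autumn."}]]
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
+print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
+```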
diff --git a/dllm/examples/bert/assets/chat.gif b/dllm/examples/bert/assets/chat.gif new file mode 100644 index 0000000..1f070fa Binary files /dev/null and b/dllm/examples/bert/assets/chat.gif differ diff --git a/dllm/examples/bert/chat.py b/dllm/examples/bert/chat.py new file mode 100644 index 0000000..53f91da --- /dev/null +++ b/dllm/examples/bert/chat.py @@ -0,0 +1,71 @@ +""" +Interactive chat / generation script for Bert models. + +Examples +-------- +# Raw multi-turn generation (default) +python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True +""" + +import sys +from dataclasses import dataclass +import transformers + +import dllm +from dllm.pipelines import llada +from dllm.tools.chat import multi_turn_chat, single_turn_generate + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0" + seed: int = 42 + chat: bool = True + visualize: bool = True + + def __post_init__(self): + # same base-path resolution logic as in generate.py + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(llada.LLaDAGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + block_length: int = 32 + temperature: float = 0.0 + remasking: str = "low_confidence" + + +def main(): + parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) + script_args, gen_config = parser.parse_args_into_dataclasses() + transformers.set_seed(script_args.seed) + + model = dllm.utils.get_model(model_args=script_args).eval() + tokenizer = dllm.utils.get_tokenizer(model_args=script_args) + generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer) + + if script_args.chat: + multi_turn_chat( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + else: + print("\nSingle-turn generation (no chat template).") + single_turn_generate( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted. 
Bye!") + sys.exit(0) diff --git a/dllm/examples/bert/eval.sh b/dllm/examples/bert/eval.sh new file mode 100644 index 0000000..a15446e --- /dev/null +++ b/dllm/examples/bert/eval.sh @@ -0,0 +1,50 @@ + #!/usr/bin/env bash +# ===== Mandatory for proper import and evaluation ===== +export PYTHONPATH=.:$PYTHONPATH +export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation +export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset + +# ===== Optional but recommended for stability and debugging ===== +export PYTHONBREAKPOINT=0 # Disable interactive breakpoints +export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks +export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs +export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging + +# ===== Basic Settings ===== +model_name_or_path="dllm-collection/ModernBERT-large-chat-v0" +num_gpu=4 +while [[ $# -gt 0 ]]; do + case "$1" in + --model_name_or_path) + model_name_or_path="$2"; shift 2 ;; + --num_gpu) + num_gpu="$2"; shift 2 ;; + esac +done + +# ===== Common arguments ===== +common_args="--model bert --apply_chat_template" # BERT model is default to use chat template + +# ======================= +# BERT Instruct (Chat) Tasks +# ======================= + +accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \ + --tasks hellaswag_gen --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128" + +accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \ + --tasks mmlu_generative --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128" + +accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \ + --tasks mmlu_pro --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256" + +accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \ + --tasks arc_challenge_chat --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128" + +accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \ + --tasks winogrande --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128" diff --git a/dllm/examples/bert/generate.py b/dllm/examples/bert/generate.py new file mode 100644 index 0000000..455c9bd --- /dev/null +++ b/dllm/examples/bert/generate.py @@ -0,0 +1,73 @@ +""" +python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH" +""" + +from dataclasses import dataclass + +import transformers + +import dllm +from dllm.tools.chat import decode_trim +from dllm.pipelines import llada + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0" + seed: int = 42 + visualize: bool = True + + def __post_init__(self): + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(llada.LLaDAGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + block_length: int = 64 + temperature: float = 0.0 + remasking: str = 
"low_confidence" + + +parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) +script_args, gen_config = parser.parse_args_into_dataclasses() +transformers.set_seed(script_args.seed) + +# Load model & tokenizer +model = dllm.utils.get_model(model_args=script_args).eval() +tokenizer = dllm.utils.get_tokenizer(model_args=script_args) +generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer) +terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer( + tokenizer=tokenizer +) + +# --- Example 1: Batch generation --- +print("\n" + "=" * 80) +print("TEST: bert.generate()".center(80)) +print("=" * 80) + +messages = [ + [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}], + [{"role": "user", "content": "Please write an educational python function."}], +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, +) +outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) + +for iter, s in enumerate(sequences): + print("\n" + "-" * 80) + print(f"[Case {iter}]") + print("-" * 80) + print(s.strip() if s.strip() else "") +print("\n" + "=" * 80 + "\n") + +if script_args.visualize: + terminal_visualizer.visualize(outputs.histories, rich=True) diff --git a/dllm/examples/bert/pt.py b/dllm/examples/bert/pt.py new file mode 100644 index 0000000..c98720c --- /dev/null +++ b/dllm/examples/bert/pt.py @@ -0,0 +1,127 @@ +""" +Local users +------------ +- 1 GPU: + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/bert/pt.py + +- 8 GPUs (DDP): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml \ + examples/bert/pt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. +------------ +- 8 GPUs (DDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "ddp" \ + --script_path "examples/bert/pt.py" +""" + +import os +import functools +from dataclasses import dataclass, field + +import transformers +import accelerate + +import dllm + +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "answerdotai/ModernBERT-large" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "Trelis/tiny-shakespeare" + text_field: str = "Text" + max_length: int = 128 + streaming: bool = False + drop_tail: bool = True + insert_eos: bool = field( + default=True, + metadata={ + "help": "False when adjacent samples from the datasets are semantically coherent." 
+ }, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/ModernBERT-base/tiny-shakespeare" + num_train_epochs: int = 20 + learning_rate: float = 1e-4 + per_device_train_batch_size: int = 64 + per_device_eval_batch_size: int = 64 + eval_steps: float = 0.1 + save_steps: float = 0.1 + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + model = dllm.utils.get_model(model_args=model_args) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_pt_dataset( + data_args.dataset_args, + streaming=data_args.streaming, + ) + dataset = dataset.map( + functools.partial( + dllm.utils.tokenize_and_group, + tokenizer=tokenizer, + text_field=data_args.text_field, + seq_length=data_args.max_length, + insert_eos=data_args.insert_eos, + drop_tail=data_args.drop_tail, + ), + batched=True, + remove_columns=dataset["train"].column_names, + **({} if data_args.streaming else {"num_proc": data_args.num_proc}), + **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}), + ) + if data_args.streaming: + dataset = dataset.shuffle(seed=training_args.seed) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = dllm.core.trainers.MDLMTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/bert/sft.py b/dllm/examples/bert/sft.py new file mode 100644 index 0000000..266ce51 --- /dev/null +++ b/dllm/examples/bert/sft.py @@ -0,0 +1,127 @@ +""" +Local users +------------ +- 1 GPU: + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/bert/sft.py + +- 8 GPUs (DDP): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml \ + examples/bert/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (DDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "ddp" \ + --script_path "examples/bert/sft.py" + +- 2 Nodes, 16 GPUs (DDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "ddp" \ + --script_path "examples/bert/sft.py" +""" + +import os +from dataclasses import dataclass, field +from functools import partial + +import transformers +import accelerate + +import dllm + +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "answerdotai/ModernBERT-large" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "tatsu-lab/alpaca" + max_length: int = 512 + load_preprocessed_data: bool = False + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/ModernBERT-large/alpaca" + group_by_length: bool = True + learning_rate: float = 1e-4 + num_train_epochs: int = 20 + per_device_train_batch_size: int = 64 + per_device_eval_batch_size: int = 64 + eval_steps: float = 0.1 + save_steps: float = 0.1 + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + model = dllm.utils.get_model(model_args=model_args) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset( + data_args.dataset_args, + load_preprocessed_data=data_args.load_preprocessed_data, + ) + if not data_args.load_preprocessed_data: + map_fn = partial( + dllm.utils.default_sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=data_args.mask_prompt_loss, + ) + dataset = dataset.map( + map_fn, + num_proc=data_args.num_proc, + desc="Mapping dataset to SFT format", + ) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = dllm.core.trainers.MDLMTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=dllm.utils.NoAttentionMaskCollator( + tokenizer, + return_tensors="pt", + padding=True, + label_pad_token_id=tokenizer.pad_token_id, # finetune on padding + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/dream/README.md b/dllm/examples/dream/README.md new file mode 100644 index 0000000..49b0180 --- /dev/null 
+++ b/dllm/examples/dream/README.md @@ -0,0 +1,187 @@ +# Dream + +> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) | 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream) + +Resources and examples for training (finetuning & pretraining) and evaluating diffusion language models **Dream**. + +## Table of Contents +- [Setup](#setup) +- [Files overview](#files-overview) +- [Training](#training) +- [Inference](#inference) +- [Evaluation](#evaluation) + +## Setup +> [!IMPORTANT] +> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logps`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details. +> + + +## Files overview +``` +# tools relevant with Dream +dllm/pipelines/dream +├── __init__.py # Package initialization +├── models/ +│ ├── configuration_dream.py # Dream model configuration +│ ├── generation_utils.py # Diffusion-based generation logic +│ ├── modeling_dream.py # Core Dream model architecture +│ └── tokenization_dream.py # Tokenizer implementation for Dream +├── generator.py # Inference logic +├── trainer.py # Training logic (pretraining and SFT) +└── utils.py # Auxiliary utilities and helper functions + +# example entry points for training / inference / evaluation +examples/dream +├── chat.py # Interactive inference example +├── eval.sh # Automatic evaluation script +├── generate.py # Inference example +├── pt.py # Pretraining example +├── README.md # Documentation (you are here) +└── sft.py # Supervised finetuning example +``` + + +## Training + +### Finetuning +For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run: +```shell +accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/dream/sft.py \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 2e-5 +``` +If you are using slurm and want to train across, for example, 2 nodes (16 GPUs total), run: +```shell +sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/sft.py" \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 2e-5 +``` + + +#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) +We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): + +```shell +# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training) +PYTHONPATH=. 
python dllm/tools/preprocess_sft_dataset.py \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --sft_map_fn_path "examples.dream.sft.sft_map_fn" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "data/sft/dream/tulu-3-sft-mixture" \ + --num_proc 64 + +# train on 24*8=192 A100s with FSDP, take about 8 hours +sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/sft.py" \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --dataset_args "data/sft/dream/tulu-3-sft-mixture" \ + --load_preprocessed_data True \ + --output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \ + --max_length 2048 \ + --truncation "right" \ + --group_by_length True \ + --num_train_epochs 5 \ + --learning_rate 1e-5 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 2 \ + --per_device_eval_batch_size 2 \ + --eval_on_start False \ + --eval_steps 0.1 \ + --save_steps 0.05 +``` + + +### Pretraining + +Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP: +```shell +sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/pt.py" \ + --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ + --dataset_args "mlfoundations/dclm-baseline-1.0" \ + --output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \ + --max_length 1024 \ + --max_steps 2000 \ + --learning_rate 3e-4 +``` + +## Inference +We support batch inference for standard generation and infilling: + +```shell +python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" +``` +We also support interactive multi-turn dialogue with visualization: +```shell +python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" +``` + +## Evaluation +> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation. + +For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run: +```shell +# Use model_args to adjust the generation arguments for evalution. +accelerate launch --num_processes 4 \ + dllm/pipelines/dream/eval.py \ + --tasks "mmlu_pro" \ + --model "dream" \ + --apply_chat_template \ + --num_fewshot 0 \ + --model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" +``` + +To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run: +```shell +bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True +bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False +``` + +### Evaluation results + +> Results (evaluated) are evaluated using our framework, while results (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon. 
+ +|                 | MMLU | BBH | ARC‑C | ARC‑E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip planning | +|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:| +| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 | +| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | – | – | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | – | 35.5 | 45.8 | – | 43.0 | – | – | – | + + +

+Table 1. Evaluation results of Dream-7B-Base.

+ +| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval | +|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:| +| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 | +| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(evaluated) | – | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | – | 62.3 | + +

+Table 2. Evaluation results of Dream-7B-Instruct.
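+
+To reproduce these numbers, note that `--model_args` in the commands above bundles the model path and generation settings as a comma-separated list of `key=value` pairs. As a rough illustration only (the `parse_model_args` helper below is hypothetical, not the actual dllm / lm-eval API, which parses these strings internally), such a string maps to keyword arguments like this:
+
+```python
+def parse_model_args(s: str) -> dict:
+    """Toy parser: split "k=v,k=v" pairs and coerce booleans and numbers."""
+    def coerce(v: str):
+        if v.lower() in ("true", "false"):
+            return v.lower() == "true"
+        for cast in (int, float):
+            try:
+                return cast(v)
+            except ValueError:
+                pass
+        return v
+    return {k: coerce(v) for k, v in (kv.split("=", 1) for kv in s.split(",") if kv)}
+
+kwargs = parse_model_args(
+    "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,"
+    "steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
+)
+# {'pretrained': 'Dream-org/Dream-v0-Instruct-7B', 'mc_num': 1, 'max_new_tokens': 128, ...,
+#  'add_bos_token': True, 'escape_until': True}
+```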

+ + diff --git a/dllm/examples/dream/chat.py b/dllm/examples/dream/chat.py new file mode 100644 index 0000000..1872a6e --- /dev/null +++ b/dllm/examples/dream/chat.py @@ -0,0 +1,75 @@ +""" +Interactive chat / generation script for Dream models. + +Examples +-------- +# Chat mode (multi-turn, chat template) +python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True + +# Raw single-turn generation +python -u examples/dream/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False +""" + +import sys +from dataclasses import dataclass +import transformers + +import dllm +from dllm.pipelines import dream +from dllm.tools.chat import multi_turn_chat, single_turn_generate + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B" + seed: int = 42 + chat: bool = True + visualize: bool = True + + def __post_init__(self): + # same base-path resolution logic as in generate.py + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(dream.DreamGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + temperature: float = 0.2 + top_p: float = 0.95 + alg: str = "entropy" + alg_temp: float = 0.0 + + +def main(): + parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) + script_args, gen_config = parser.parse_args_into_dataclasses() + transformers.set_seed(script_args.seed) + + model = dllm.utils.get_model(model_args=script_args).eval() + tokenizer = dllm.utils.get_tokenizer(model_args=script_args) + generator = dream.DreamGenerator(model=model, tokenizer=tokenizer) + + if script_args.chat: + multi_turn_chat( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + else: + print("\nSingle-turn generation (no chat template).") + single_turn_generate( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted. 
Bye!") + sys.exit(0) diff --git a/dllm/examples/dream/eval.sh b/dllm/examples/dream/eval.sh new file mode 100644 index 0000000..6989ea7 --- /dev/null +++ b/dllm/examples/dream/eval.sh @@ -0,0 +1,139 @@ +#!/usr/bin/env bash +# ===== Mandatory for proper import and evaluation ===== +export PYTHONPATH=.:$PYTHONPATH +export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation +export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset + + +# ===== Optional but recommended for stability and debugging ===== +export PYTHONBREAKPOINT=0 # Disable interactive breakpoints +export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks +export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs +export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging + + +# ===== Input Arguments ===== +model_name_or_path="Dream-org/Dream-v0-Instruct-7B" +instruct=True +num_gpu=4 +while [[ $# -gt 0 ]]; do + case "$1" in + --model_name_or_path) + model_name_or_path="$2"; shift 2 ;; + --instruct) + instruct="$2"; shift 2 ;; + --num_gpu) + num_gpu="$2"; shift 2 ;; + esac +done + + +# ===== Conditional Configurations ===== +if [ "$instruct" = "True" ]; then + echo ">>> Running in INSTRUCT mode" + common_args="--model dream --apply_chat_template" +else + echo ">>> Running in BASE mode" + common_args="--model dream" +fi + + +# ======================= +# Generation / Instruct Tasks +# ======================= + +if [ "$instruct" = "True" ]; then + # Instruct Tasks + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks mmlu_generative --num_fewshot 4 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks mmlu_pro --num_fewshot 4 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks gsm8k_cot --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=256,max_length=256,steps=256,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks minerva_math --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.0,top_p=1.0,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks humaneval_instruct_dream --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=768,max_length=768,steps=768,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks mbpp_instruct --num_fewshot 0 ${common_args} \ + --model_args 
"pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1024,max_length=1024,steps=1024,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks ifeval --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},mc_num=1,max_new_tokens=1280,max_length=1280,steps=1280,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true" + +else + # Base Generation Tasks + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks humaneval --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks gsm8k_cot --num_fewshot 8 ${common_args} \ + --model_args "pretrained=${model_name_or_path},max_new_tokens=256,steps=256,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks mbpp --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks minerva_math --num_fewshot 4 ${common_args} \ + --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks bbh --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},max_new_tokens=512,steps=512,temperature=0.0,top_p=0.95,add_bos_token=true,escape_until=true" +fi + + +# ======================= +# Likelihood Tasks (Base Only) +# ======================= + +if [ "$instruct" != "True" ]; then + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks mmlu --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks arc_easy --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks arc_challenge --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks hellaswag --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks piqa --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks winogrande --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},add_bos_token=true" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/dream/eval.py \ + --tasks race --num_fewshot 0 ${common_args} \ + --model_args 
"pretrained=${model_name_or_path},add_bos_token=true" +fi diff --git a/dllm/examples/dream/generate.py b/dllm/examples/dream/generate.py new file mode 100644 index 0000000..d841b35 --- /dev/null +++ b/dllm/examples/dream/generate.py @@ -0,0 +1,117 @@ +""" +python -u examples/dream/generate.py --model_name_or_path "YOUR_MODEL_PATH" +""" + +from dataclasses import dataclass + +import transformers + +import dllm +from dllm.tools.chat import decode_trim +from dllm.pipelines import dream + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B" + seed: int = 42 + visualize: bool = True + + def __post_init__(self): + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(dream.DreamGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + temperature: float = 0.2 + top_p: float = 0.95 + alg: str = "entropy" + alg_temp: float = 0.0 + + +parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) +script_args, gen_config = parser.parse_args_into_dataclasses() +transformers.set_seed(script_args.seed) + +# Load model & tokenizer +model = dllm.utils.get_model(model_args=script_args).eval() +tokenizer = dllm.utils.get_tokenizer(model_args=script_args) +generator = dream.DreamGenerator(model=model, tokenizer=tokenizer) +terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer( + tokenizer=tokenizer +) + +# --- Example 1: Batch generation --- +print("\n" + "=" * 80) +print("TEST: dream.generate()".center(80)) +print("=" * 80) + +messages = [ + [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}], + [{"role": "user", "content": "Please write an educational python function."}], +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, +) + +outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) + +for iter, s in enumerate(sequences): + print("\n" + "-" * 80) + print(f"[Case {iter}]") + print("-" * 80) + print(s.strip() if s.strip() else "") +print("\n" + "=" * 80 + "\n") + +if script_args.visualize: + terminal_visualizer.visualize(outputs.histories, rich=True) + +# --- Example 2: Batch fill-in-the-blanks --- +print("\n" + "=" * 80) +print("TEST: dream.infilling()".center(80)) +print("=" * 80) + +masked_messages = [ + [ + {"role": "user", "content": tokenizer.mask_token * 20}, + { + "role": "assistant", + "content": "Sorry, I do not have answer to this question.", + }, + ], + [ + {"role": "user", "content": "Please write an educational python function."}, + { + "role": "assistant", + "content": "def hello_" + tokenizer.mask_token * 20 + " return", + }, + ], +] + +inputs = tokenizer.apply_chat_template( + masked_messages, + add_generation_prompt=False, + tokenize=True, +) + +outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) + +for iter, (i, s) in enumerate(zip(inputs, sequences)): + print("\n" + "-" * 80) + print(f"[Case {iter}]") + print("-" * 80) + print("[Masked]:\n" + tokenizer.decode(i)) + print("\n[Filled]:\n" + (s.strip() if s.strip() else "")) +print("\n" + "=" * 80 + "\n") + +if script_args.visualize: + terminal_visualizer.visualize(outputs.histories, rich=True) diff --git a/dllm/examples/dream/pt.py b/dllm/examples/dream/pt.py new file mode 100644 index 
0000000..9f27f8f --- /dev/null +++ b/dllm/examples/dream/pt.py @@ -0,0 +1,162 @@ +""" +Local users +------------ +- 1 GPU (4bit quant & LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/dream/pt.py \ + --load_in_4bit True --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/dream/pt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. +------------ +- 24 Nodes, 192 GPUs (FSDP): + sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/pt.py" +""" + +import os +import functools +from dataclasses import dataclass, field + +import torch +import transformers +import accelerate + +import dllm +from dllm.pipelines import dream +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "Dream-org/Dream-v0-Base-7B" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]" + text_field: str = "text" + streaming: bool = True + drop_tail: bool = True + insert_eos: bool = field( + default=True, + metadata={ + "help": "False when adjacent samples from the datasets are semantically coherent." + }, + ) + random_length_ratio: float = field( + default=0.01, + metadata={ + "help": ( + "The probability of randomly cut sequences during training. " + "See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md." + ) + }, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = ( + "models/Dream-7B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]" + ) + learning_rate: float = 3e-4 + max_steps: int = 2_000 + per_device_train_batch_size: int = 4 + gradient_accumulation_steps: int = 4 + eval_steps: float = 0.05 + save_steps: float = 0.05 + # Dream PT specific args + # Note: Since Dream’s pretraining recipe is not public, + # this is only a reference implementation following LLaDA’s data processing approach. + loss_weight_type: str = field( + default="cart[geo_p:0.3]", + metadata={ + "help": ( + "The loss weight type. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." 
+ ) + }, + ) + + +def train(): + # ----- Parse & setup -------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + # necessary for streaming dataset + if data_args.streaming: + training_args.accelerator_config.dispatch_batches = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model --------------------------------------------------------------- + # initialize model weights from scratch + config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config(config, dtype=torch.bfloat16) + + # ----- Tokenizer ----------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + # ----- Optional PEFT: LoRA ------------------------------------------------- + model = dllm.utils.load_peft(model=model, model_args=model_args) + + # ----- Dataset ------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_pt_dataset( + data_args.dataset_args, + streaming=data_args.streaming, + ) + dataset = dataset.map( + functools.partial( + dllm.utils.tokenize_and_group, + tokenizer=tokenizer, + text_field=data_args.text_field, + seq_length=data_args.max_length, + insert_eos=data_args.insert_eos, + drop_tail=data_args.drop_tail, + ), + batched=True, + remove_columns=dataset["train"].column_names, + **({} if data_args.streaming else {"num_proc": data_args.num_proc}), + **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}), + ) + if data_args.streaming: + dataset = dataset.shuffle(seed=training_args.seed) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = dream.DreamTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + loss_weight_type=training_args.loss_weight_type, + data_collator=dream.utils.DreamPTCollator( + tokenizer, + return_tensors="pt", + padding=True, + random_length_ratio=data_args.random_length_ratio, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/dream/sft.py b/dllm/examples/dream/sft.py new file mode 100644 index 0000000..f8354cd --- /dev/null +++ b/dllm/examples/dream/sft.py @@ -0,0 +1,192 @@ +""" +Local users +------------ +- 1 GPU (4bit quant & LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/dream/sft.py \ + --load_in_4bit True --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/dream/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/sft.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/dream/sft.py" +""" + +import os +from dataclasses import dataclass, field +from functools import partial + +import transformers +import accelerate + +import dllm +from dllm.pipelines import dream + +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "Dream-org/Dream-v0-Base-7B" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + load_preprocessed_data: bool = False + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + # Dream SFT specific args + perbatch_cutoff: bool = field( + default=True, + metadata={ + "help": ( + "Randomly pick a response length from batch and trim other responses. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." + ) + }, + ) + resp_cutoff_ratio: float = field( + default=0.0, + metadata={ + "help": ( + "The probability of randomly cutting sequences during training. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." + ) + }, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/Dream-7B-SFT" + group_by_length: bool = True + # Dream SFT specific args + loss_weight_type: str = field( + default="cart[geo_p:0.3]", + metadata={ + "help": ( + "The loss weight type. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." + ) + }, + ) + + +# ------------------------------------------------------------------------------ +# SFT mapping function +# ------------------------------------------------------------------------------ +def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool) -> dict: + """ + Build Dream SFT features from a chat-format row. + + Returns: + dict with input_ids, labels, attention_mask, prompt_len + """ + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], tokenize=True, add_generation_prompt=True + ) + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], tokenize=True, add_generation_prompt=False + ) + labels = prompt_response_tokens.copy() + + if mask_prompt_loss: + labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens) + else: + # When training on all tokens, prepend a BOS token (if missing) + # so the model can predict the first token. 
+ if prompt_response_tokens[0] != tokenizer.bos_token_id: + bos = [tokenizer.bos_token_id] + prompt_response_tokens = bos + prompt_response_tokens + prompt_tokens = bos + prompt_tokens + labels = bos + labels + labels[0] = -100 # ignore loss on BOS + + return { + "input_ids": prompt_response_tokens, + "labels": labels, + "prompt_len": len(prompt_tokens), + } + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + # necessary when batch contains customized fields + training_args.remove_unused_columns = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + model = dllm.utils.get_model(model_args=model_args) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset( + data_args.dataset_args, + load_preprocessed_data=data_args.load_preprocessed_data, + ) + if not data_args.load_preprocessed_data: + map_fn = partial( + sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=data_args.mask_prompt_loss, + ) + dataset = dataset.map( + map_fn, + num_proc=data_args.num_proc, + desc="Mapping dataset to SFT format", + ) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = dream.DreamTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + loss_weight_type=training_args.loss_weight_type, + data_collator=dream.utils.DreamSFTCollator( + tokenizer, + return_tensors="pt", + padding=True, + perbatch_cutoff=data_args.perbatch_cutoff, + resp_cutoff_ratio=data_args.resp_cutoff_ratio, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/editflow/README.md b/dllm/examples/editflow/README.md new file mode 100644 index 0000000..efd9635 --- /dev/null +++ b/dllm/examples/editflow/README.md @@ -0,0 +1,3 @@ +Work in progress. + +Please see [`examples/editflow/bert/README.md`](/examples/editflow/bert/README.md) for examples of finetuning BERT with EditFlow. diff --git a/dllm/examples/editflow/_README.md b/dllm/examples/editflow/_README.md new file mode 100644 index 0000000..e882554 --- /dev/null +++ b/dllm/examples/editflow/_README.md @@ -0,0 +1,162 @@ +# Edit Flows + +> **Reference** +> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018) + +This directory provides an educational reference for training EditFlow models. 
It demonstrates how to adapt open-weight DLLMs such as [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487) to support *insertion* and *deletion* in addition to the standard *substitution* (`mask` -> `tokens`) operation. It also includes examples for training (pretraining and finetuning) EditFlow models from scratch.
+
+> [!NOTE]
+> - Examples are available for both LLaDA and Dream, but this README focuses on adapting open-weight LLaDA for edit operations ([`adapt_llada.py`](/examples/editflow/adapt_llada.py)) and reusing its architecture for training from scratch ([`pt_llada.py`](/examples/editflow/pt_llada.py) -> [`sft_llada.py`](/examples/editflow/sft_llada.py)).
+> - While `EditFlowCollator` supports custom `x0`, this README uses fixed-length (128) masks as `x0`. The trained model generates text by replacing masks, deleting redundant ones, and inserting tokens as needed. To change the default `x0` distribution (e.g., empty sequences for [OneFlow](https://arxiv.org/abs/2510.03506)-like insertion-only generation), pass `--x0_sampler "empty"`.
+
+## Table of Contents
+- [Setup](#setup)
+- [Files overview](#files-overview)
+- [Training](#training)
+  - [Adapting LLaDA-8B-Instruct to support insertion and deletion](#adapting-llada-8b-instruct-to-support-insertion-and-deletion)
+  - [Pretraining & Finetuning from scratch](#pretraining--finetuning-from-scratch)
+- [Sampling](#sampling)
+- [Acknowledgement](#acknowledgement)
+
+## Setup
+> [!IMPORTANT]
+> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
+
+## Files overview
+```
+dllm/pipelines/editflow
+├── __init__.py                  # Package initialization
+├── models
+│   ├── dream
+│   │   └── modelling_dream.py   # EditFlowDream: architecture based on Dream
+│   └── llada
+│       └── modelling_llada.py   # EditFlowLLaDA: architecture based on LLaDA
+├── trainer.py
+└── utils.py
+
+# example entry point for training / sampling
+examples/editflow
+├── adapt_dream.py               # Example of adapting Dream for EditFlow directly
+├── adapt_llada.py               # Example of adapting LLaDA for EditFlow directly
+├── generate.py                  # Generation example
+├── pt_dream.py                  # EditFlowDream pretraining example
+├── pt_llada.py                  # EditFlowLLaDA pretraining example
+├── pt.py                        # Pretraining function
+├── README.md                    # Documentation (you are here)
+├── sft_dream.py                 # EditFlowDream SFT example
+├── sft_llada.py                 # EditFlowLLaDA SFT example
+└── sft.py                       # Supervised finetuning function
+```
+
+## Training
+
+### Adapting [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) to support *insertion* and *deletion*
+
+The original LLaDA model generates text by iteratively substituting the given `` tokens with real tokens.
+

+[image: LLaDA demo]
+
Figure: Example Gradio demo for LLaDA.
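+
+Concretely, this fixed-length mask decoding starts from the prompt followed by a block of mask placeholders. A minimal sketch of how such an input is built (the checkpoint path is a placeholder, and the tokenizer is assumed to expose `mask_token_id`, as `generate.py` in this directory also requires):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("YOUR_MODEL_PATH")  # placeholder path
+prompt_ids = tokenizer.apply_chat_template(
+    [{"role": "user", "content": "write a romantic story"}],
+    tokenize=True,
+    add_generation_prompt=True,
+)
+# x0 = prompt + 128 masks, i.e. what --x0_sampler "masks[length:128]" trains against
+x0 = prompt_ids + [tokenizer.mask_token_id] * 128
+```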

+ +However, LLaDA supports only substitution. This example shows how to adapt it so that, during decoding, the model can not only replace fixed-length masks (e.g., 128 tokens) with real text but also insert new tokens and delete unnecessary masks adaptively: + +```shell +accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/editflow/adapt_llada.py \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \ + --lm_head_key "model.transformer.ff_out" \ + --init_editflow_from_src True \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \ + --x0_sampler "masks[length:128]" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 5e-5 +``` + +If you are using slurm and want to train across, for example, four nodes (32 GPUs total), run: +```shell +sbatch --nodes=4 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/adapt_llada.py" \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" \ + --lm_head_key "model.transformer.ff_out" \ + --init_editflow_from_src True \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture" \ + --x0_sampler "masks[length:128]" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 5e-5 +``` + +After training, you can use the [generate.py](/examples/editflow/generate.py) scripts to provide a visualized decoding trace to see how the model performs *insertion* and *deletion* beyond regular mask *substitutions*. See [Sampling](#sampling) for details. + + +### Pretraining & Finetuning from scratch +You can also train an EditFlow model from scratch (pretrain → SFT) without adapting an existing DLLM. + +Pretrain on a subset of [mlfoundations/dclm-baseline-1.0](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) using 192 GPUs (24x8) and FSDP: + +```shell +sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/pt_llada.py" \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --dataset_args "mlfoundations/dclm-baseline-1.0" \ + --output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \ + --x0_sampler "masks[length:128]" \ + --max_length 1024 \ + --max_steps 2000 \ + --learning_rate 3e-4 +``` + +Finetune on a subset of [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) using 8 GPUS and FSDP for better instruction following: + +```shell +# you can also run locally with `accelerate ...` +sbatch --nodes=1 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/sft_llada.py" \ + --model_name_or_path "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0/checkpoint-final" \ + --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \ + --output_dir "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0" \ + --x0_sampler "masks[length:128]" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 5e-5 +``` + +## Sampling + +After training, you can visualize how the model performs mask substitution, insertion, and deletion during generation with [generate.py](/examples/editflow/generate.py). Inserted tokens appear blue, and tokens substituted from `` appear black, and deleted tokens are shown with a strikethrough before they disappear. 
+ +```shell +# Generate a long sequence to visualize insertions after 128 tokens +python examples/editflow/generate.py \ + --model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \ + --tau 0.02 --mask_length 128 --seed 7070 \ + --prompt "write a romantic story" --make_gif + +# Generate a short sequence to visualize deletions after 128 tokens +python examples/editflow/generate.py \ + --model_name_or_path "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture/checkpoint-final" \ + --tau 0.02 --mask_length 128 --seed 7070 \ + --prompt "write a single-sentence romantic story" --make_gif +``` + +

+[image: EditFlow deletion demo]
+
Figure: Deletion & Substitution trace

+ +

+[image: LLaDA demo]
+
Figure: Insertion & Substitution trace
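+
+The `--tau` flag in the sampling commands above is the τ-leap step size: the sampler runs roughly `ceil(1/τ)` steps, and at each step every candidate edit (deletion, substitution, insertion) fires independently with probability about `rate × τ`, clamped to a valid probability. A minimal sketch of that draw, mirroring `_bernoulli_from_rate` in [`generate.py`](/examples/editflow/generate.py) (the example rates below are made up):
+
+```python
+import torch
+
+def bernoulli_from_rate(rate: torch.Tensor, tau: float) -> torch.Tensor:
+    # First-order tau-leap: an event with rate `rate` fires within a step of size
+    # tau with probability ~ rate * tau, clamped so it stays a valid probability.
+    p = (rate.float() * float(tau)).clamp_(0.0, 1.0 - 1e-6)
+    return torch.bernoulli(p)
+
+# Illustrative per-position deletion rates for a length-5 sequence:
+del_rate = torch.tensor([0.0, 0.4, 2.0, 0.1, 0.0])
+fires = bernoulli_from_rate(del_rate, tau=0.02)  # 0/1 indicator of deletions this step
+```
+
+Smaller `--tau` therefore gives finer-grained (but slower) decoding, since more steps are taken with smaller per-step edit probabilities.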

+ +## Acknowledgement + +This Edit Flows implementation is inspired by https://github.com/TheMatrixMaster/edit-flows-demo. diff --git a/dllm/examples/editflow/assets/all.gif b/dllm/examples/editflow/assets/all.gif new file mode 100644 index 0000000..5267b4d Binary files /dev/null and b/dllm/examples/editflow/assets/all.gif differ diff --git a/dllm/examples/editflow/assets/deletion.gif b/dllm/examples/editflow/assets/deletion.gif new file mode 100644 index 0000000..d2264db Binary files /dev/null and b/dllm/examples/editflow/assets/deletion.gif differ diff --git a/dllm/examples/editflow/assets/insertion.gif b/dllm/examples/editflow/assets/insertion.gif new file mode 100644 index 0000000..94b118b Binary files /dev/null and b/dllm/examples/editflow/assets/insertion.gif differ diff --git a/dllm/examples/editflow/bert/README.md b/dllm/examples/editflow/bert/README.md new file mode 100644 index 0000000..2c01e30 --- /dev/null +++ b/dllm/examples/editflow/bert/README.md @@ -0,0 +1,77 @@ +# Edit Flows - BERT + +> 📄 Paper: [Edit Flows: Flow Matching with Edit Operations](https://arxiv.org/abs/2506.09018) + + +## Warmup + +In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text with EditFlow. +You can use any BERT model instead for example, by `--model_name_or_path "FacebookAI/roberta-large"`. + +### Pretrain + +To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run: +```shell +PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/bert/pt.py \ + --model_name_or_path "answerdotai/ModernBERT-large" \ + --dataset_args "Trelis/tiny-shakespeare" \ + --text_field "Text" \ + --insert_eos False \ + --max_length 128 \ + --num_train_epochs 20 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --save_steps 0.1 \ + --x0_sampler "masks[length:64]" \ + --output_dir "models/EditFlow/ModernBERT-large/tiny-shakespeare" +``` + +To run inference with the model: +```shell +PYTHONPATH=. python examples/editflow/generate.py \ + --model_name_or_path "models/EditFlow/ModernBERT-large/tiny-shakespeare/checkpoint-final" \ + --tau 0.01 --mask_length 64 --seed 42 --make_gif + +# see `decode_trace.gif` +``` + + +### SFT +To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run: +```shell +PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \ + examples/editflow/bert/sft.py \ + --model_name_or_path "answerdotai/ModernBERT-large" \ + --dataset_args "tatsu-lab/alpaca" \ + --max_length 512 \ + --num_train_epochs 20 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 64 \ + --save_steps 0.1 \ + --x0_sampler "masks[length:64]" \ + --output_dir "models/EditFlow/ModernBERT-large/alpaca" +``` + +To run inference with the model: +```shell +PYTHONPATH=. python examples/editflow/generate.py \ + --model_name_or_path "models/EditFlow/ModernBERT-large/alpaca/checkpoint-final" \ + --prompt "Could you please write a poem for me?" 
--tau 0.01 --mask_length 64 --seed 42 --make_gif + +# see `decode_trace.gif` +``` + + diff --git a/dllm/examples/editflow/bert/pt.py b/dllm/examples/editflow/bert/pt.py new file mode 100644 index 0000000..a32961d --- /dev/null +++ b/dllm/examples/editflow/bert/pt.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass + +import transformers + +import dllm +from examples.editflow import pt as editflow_pt + + +@dataclass +class ModelArguments(editflow_pt.ModelArguments): + model_name_or_path: str = "answerdotai/ModernBERT-large" + lm_head_key: str = "decoder" + + +@dataclass +class DataArguments(editflow_pt.DataArguments): + dataset_args: str = "Trelis/tiny-shakespeare" + text_field: str = "Text" + max_length: int = 128 + streaming: bool = False + drop_tail: bool = True + insert_eos: bool = False + + +@dataclass +class TrainingArguments(editflow_pt.TrainingArguments): + output_dir: str = "models/EditFlow/ModernBERT-large/tiny-shakespeare" + num_train_epochs: float = 20 + learning_rate: float = 3e-4 + per_device_train_batch_size: int = 64 + per_device_eval_batch_size: int = 64 + eval_steps: float = 0.1 + save_steps: float = 0.1 + x0_sampler: str = "masks[length:64]" + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_pt.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig, + ) diff --git a/dllm/examples/editflow/bert/sft.py b/dllm/examples/editflow/bert/sft.py new file mode 100644 index 0000000..c410e69 --- /dev/null +++ b/dllm/examples/editflow/bert/sft.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass + +import transformers + +import dllm +from examples.editflow import sft as editflow_sft + + +@dataclass +class ModelArguments(editflow_sft.ModelArguments): + model_name_or_path: str = "answerdotai/ModernBERT-large" + lm_head_key: str = "decoder" + + +@dataclass +class DataArguments(editflow_sft.DataArguments): + dataset_args: str = "tatsu-lab/alpaca" + max_length: int = 512 + + +@dataclass +class TrainingArguments(editflow_sft.TrainingArguments): + output_dir: str = "models/EditFlow/ModernBERT-large/alpaca" + num_train_epochs: float = 20 + learning_rate: float = 3e-4 + per_device_train_batch_size: int = 64 + per_device_eval_batch_size: int = 64 + eval_steps: float = 0.1 + save_steps: float = 0.1 + x0_sampler: str = "masks[length:64]" + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_sft.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ef_config_cls=dllm.pipelines.editflow.EditFlowModernBertConfig, + ) diff --git a/dllm/examples/editflow/dream/adapt.py b/dllm/examples/editflow/dream/adapt.py new file mode 100644 index 0000000..735dd43 --- /dev/null +++ b/dllm/examples/editflow/dream/adapt.py @@ -0,0 +1,88 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/dream/adapt.py \ + --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + 
--config_file scripts/accelerate_configs/fsdp.yaml \ + examples/editflow/dream/adapt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. +------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/adapt.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/adapt.py" +""" + +from dataclasses import dataclass + +import torch +import transformers + +import dllm +from examples.editflow import sft as editflow_sft + + +@dataclass +class ModelArguments(editflow_sft.ModelArguments): + model_name_or_path: str = "Dream-org/Dream-v0-Instruct-7B" + lm_head_key: str = "lm_head" + init_editflow_from_src: bool = True + + +@dataclass +class DataArguments(editflow_sft.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + + +@dataclass +class TrainingArguments(editflow_sft.TrainingArguments): + output_dir: str = ( + "models/EditFlow-Dream-7B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dllm.utils.initial_training_setup(model_args, data_args, training_args) + # Create EditFlow model (bf16 init on CUDA) + ef_cfg = dllm.pipelines.editflow.EditFlowDreamConfig.from_pretrained( + model_args.model_name_or_path + ) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16) + # Initialize EditFlow model from the src model: copies backbone & clones lm_head + if model_args.init_editflow_from_src: + src_model = transformers.AutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, dtype=torch.bfloat16 + ) + dllm.pipelines.editflow.utils.init_editflow_from_src( + model, src_model, lm_head_key=model_args.lm_head_key + ) + del src_model + model = dllm.utils.load_peft(model, model_args) + + editflow_sft.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + model=model, + ) diff --git a/dllm/examples/editflow/dream/pt.py b/dllm/examples/editflow/dream/pt.py new file mode 100644 index 0000000..b37d20c --- /dev/null +++ b/dllm/examples/editflow/dream/pt.py @@ -0,0 +1,67 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/dream/pt.py \ + --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/editflow/dream/pt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/pt.py" + +- 24 Nodes, 192 GPUs (FSDP): + sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/pt.py" +""" + +from dataclasses import dataclass + +import transformers + +import dllm +from examples.editflow import pt as editflow_pt + + +@dataclass +class ModelArguments(editflow_pt.ModelArguments): + model_name_or_path: str = "Dream-org/Dream-v0-Base-7B" + lm_head_key: str = "lm_head" + + +@dataclass +class DataArguments(editflow_pt.DataArguments): + dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]" + + +@dataclass +class TrainingArguments(editflow_pt.TrainingArguments): + output_dir: str = ( + "models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_pt.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ef_config_cls=dllm.pipelines.editflow.EditFlowDreamConfig, + ) diff --git a/dllm/examples/editflow/dream/sft.py b/dllm/examples/editflow/dream/sft.py new file mode 100644 index 0000000..e8eb9d5 --- /dev/null +++ b/dllm/examples/editflow/dream/sft.py @@ -0,0 +1,66 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/dream/sft.py \ + --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/editflow/dream/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/sft.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/dream/sft.py" +""" + +from dataclasses import dataclass + +import transformers + +from examples.editflow import sft as editflow_sft + + +@dataclass +class ModelArguments(editflow_sft.ModelArguments): + model_name_or_path: str = ( + "models/EditFlow-Dream-7B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final" + ) + + +@dataclass +class DataArguments(editflow_sft.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + + +@dataclass +class TrainingArguments(editflow_sft.TrainingArguments): + output_dir: str = ( + "models/EditFlow-Dream-7B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_sft.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ) diff --git a/dllm/examples/editflow/generate.py b/dllm/examples/editflow/generate.py new file mode 100644 index 0000000..723c2ee --- /dev/null +++ b/dllm/examples/editflow/generate.py @@ -0,0 +1,418 @@ +""" +Minimal EditFlow τ-leap generator for EditBase-Dream with diffusion-style visualization. + +What changed vs. your original: +- tau_leap_step_minimal returns (x_next, any_edit, step_trace) preserving all intermediates. +- generate_editflow_minimal returns (final_text, trace). +- render_consecutive_trace_gif(trace, tokenizer, ...) draws a GIF where each frame shows + ONLY the current output (like the Gemini diffusion page shows progressive refinement): + * SUB tokens in the current frame are orange + * INS tokens in the current frame are blue + * KEEP tokens are black + * If any deletions happened in the step, the title shows ⌫N (red) +""" + +# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --time=03:00:000 python examples/editflow/generate.py --model_name_or_path "models/EditFlow-Dream-Instruct-7B/tulu-3-sft-mixture/checkpoint-final" --tau 0.02 --mask_length 128 --seed 7070 --prompt "write a romantic story" --make_gif + +import math +from dataclasses import dataclass +from typing import Annotated + +import tyro +import torch +from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer + +from dllm.core.schedulers import BaseKappaScheduler, LinearKappaScheduler + + +# ------------------------------- Small utilities -------------------------------- + + +def _bernoulli_from_rate(rate: torch.Tensor, tau: float) -> torch.Tensor: + """First-order τ-leap Bernoulli with p ≈ rate * τ (clamped).""" + p = (rate.float() * float(tau)).clamp_(0.0, 1.0 - 1e-6) + return torch.bernoulli(p) + + +def _sample_from_logits(logits_row: torch.Tensor, temperature: float) -> int: + """Sample one token id from a 1D logits row with temperature. + temperature <= 0 -> greedy (argmax). 
+ """ + if temperature <= 0.0: + return int(torch.argmax(logits_row).item()) + return int( + torch.distributions.Categorical(logits=(logits_row / temperature)) + .sample() + .item() + ) + + +@dataclass +class GenCfg: + tau: float = 0.02 # τ step + device: str = "cuda" if torch.cuda.is_available() else "cpu" + seed: int = 1234 + edit_prompt: bool = False # allow editing inside prompt region? + temperature: float = 0.7 # token sampling temperature (sub/ins) + verbose: bool = True # whether to show intermediate decoding traces + time_independent: bool = True + + +# -------------------------------- τ-leap one step -------------------------------- + + +@torch.no_grad() +def tau_leap_step_minimal( + x: torch.Tensor, # [T] + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + prompt_len: int, # number of initial prompt tokens (including BOS) + t: float, + sched: BaseKappaScheduler, + cfg: GenCfg, + prev_out: dict | None = None, # <-- pass prior step's model outputs + reuse_prev: bool = False, # <-- if True, reuse prev_out instead of forward() +) -> tuple[torch.Tensor, bool, dict, dict]: + """ + Single τ-leap step with deletion/substitution conflict resolution + and right-insert policy. + + Reuse semantics: + • If cfg.time_independent == True and reuse_prev == True and prev_out is not None, + we reuse `prev_out` tensors instead of calling model() again. + • Otherwise we run a fresh forward(). + + Viz-only convention: + • Any local annotated as _Ann[*, "viz-only"] is used only for human-visible + tracing / debugging (console logs, GIFs) and does not affect generation. + • Such variables are also prefixed with '_' for quick visual scanning. + + Returns: + x_next, any_edit, _step_trace, out_for_next (the freshly used model outputs) + """ + device = x.device + T = x.numel() + + # Decide whether to reuse the previous forward results + use_reuse = bool(cfg.time_independent and reuse_prev and (prev_out is not None)) + if use_reuse: + out = prev_out + else: + attn = torch.ones(1, T, dtype=torch.long, device=device) + t_tensor = torch.full((1, 1), float(t), device=device) + out = model(input_ids=x.unsqueeze(0), attention_mask=attn, t=t_tensor) + + del_rate_h = out["del_rate_hat"] # [1, T] + sub_rate_h = out["sub_rate_hat"] # [1, T] + ins_rate_h = out["ins_rate_hat"] # [1, T] + sub_logits = out["sub_logits"] # [1, T, V] + ins_logits = out["ins_logits"] # [1, T, V] + + # Scale normalized rates to true rates + tt = torch.tensor([[t]], device=device) + w = sched.weight(tt) + del_rate = del_rate_h * w + sub_rate = sub_rate_h * w + ins_rate = ins_rate_h * w + + # Clamp prompt_len within current T (robustness) + prompt_len_clamped = int(max(1, min(prompt_len, T))) + + if not cfg.edit_prompt: + # Protect the entire prompt span from del/sub + del_rate[:, :prompt_len_clamped] = 0.0 + sub_rate[:, :prompt_len_clamped] = 0.0 + # Disallow insertions inside the prompt EXCEPT at the last prompt token + if prompt_len_clamped >= 2: + ins_rate[:, : prompt_len_clamped - 1] = 0.0 + + # Combined "edit" (delete or substitute) event + comb_rate = (del_rate + sub_rate).squeeze(0) # [T] + comb_fire = _bernoulli_from_rate(comb_rate, cfg.tau).bool() # [T] + + # If an edit fires at i, choose deletion with prob λ_del/(λ_del+λ_sub) + p_del = (del_rate.squeeze(0) / (comb_rate + 1e-8)).clamp(0, 1) # [T] + choose_del = (torch.rand_like(p_del) < p_del) & comb_fire # [T] + choose_sub = comb_fire & (~choose_del) # [T] + + # Insertions (right of token i) + ins_fire = _bernoulli_from_rate(ins_rate.squeeze(0), cfg.tau).bool() # [T] + + # 
Token draws (algorithmic, not viz-only) + sub_samples: list[int | None] = [ + ( + _sample_from_logits(sub_logits[0, i], cfg.temperature) + if choose_sub[i] + else None + ) + for i in range(T) + ] + ins_samples: list[int | None] = [ + _sample_from_logits(ins_logits[0, i], cfg.temperature) if ins_fire[i] else None + for i in range(T) + ] + + # Build new sequence left→right (apply insertions to the RIGHT) + new_ids: list[int] = [] + + # --- viz-only per-position labels (for trace/GIF) --- + _before_ops: Annotated[list[str], "viz-only"] = ( + [] + ) # per 'before' position: DEL/SUB/KEEP + _after_ops: Annotated[list[str], "viz-only"] = ( + [] + ) # per 'after' token aligned to new_ids: INS/SUB/KEEP + + for i in range(T): + if choose_del[i]: + _before_ops.append("DEL") + # deletion -> no token appended + elif choose_sub[i]: + _before_ops.append("SUB") + new_tok = sub_samples[i] + new_ids.append(int(new_tok)) + _after_ops.append("SUB") + else: + _before_ops.append("KEEP") + new_ids.append(int(x[i].item())) + _after_ops.append("KEEP") + + if ins_samples[i] is not None: + new_ids.append(int(ins_samples[i])) + _after_ops.append("INS") + + x_next = torch.tensor(new_ids, dtype=torch.long, device=device) + any_edit = bool(comb_fire.any().item() or ins_fire.any().item()) + # Provide the exact outputs we used this step for the caller to pass forward + out_for_next = out + + # --- (vis) used only for verbose console trace --- + if cfg.verbose and (comb_fire.any() or ins_fire.any()): + + def _tok_str(tok_id: int) -> str: # viz-only helper + try: + s = tokenizer.decode([int(tok_id)]) + return s if s.strip() else f"<{int(tok_id)}>" + except Exception: + return f"<{int(tok_id)}>" + + _ops_strs: Annotated[list[str], "viz-only"] = [] + for i in range(T): + if choose_del[i]: + _ops_strs.append(f"DEL@{i}:{_tok_str(int(x[i]))}") + elif choose_sub[i]: + _ops_strs.append( + f"SUB@{i}:{_tok_str(int(x[i]))}->{_tok_str(sub_samples[i])}" + ) + if ins_samples[i] is not None: + _ops_strs.append(f"INS@{i}->{i+1}:{_tok_str(ins_samples[i])}") + print("[time]", f"{t:.4f}") + print("[events]", "; ".join(_ops_strs)) + print("[decode]\n", tokenizer.decode(new_ids, skip_special_tokens=False)) + print() + + # --- (vis) step trace payload (returned; used only for visualization downstream) --- + _step_trace: Annotated[dict, "viz-only"] = { + "t": float(t), + "x_before_ids": [int(i) for i in x.tolist()], + "x_after_ids": [int(i) for i in new_ids], + "before_ops": _before_ops, # viz-only labels + "after_ops": _after_ops, # viz-only labels + # below are algorithmic signals copied for visualization/analysis + "choose_del": [bool(v) for v in choose_del.tolist()], + "choose_sub": [bool(v) for v in choose_sub.tolist()], + "ins_fire": [bool(v) for v in ins_fire.tolist()], + "sub_samples": [int(s) if s is not None else None for s in sub_samples], + "ins_samples": [int(s) if s is not None else None for s in ins_samples], + "prompt_len": prompt_len_clamped, + "used_reuse": bool(use_reuse), + } + + return x_next, any_edit, _step_trace, out_for_next + + +# -------------------------------- top-level generate ------------------------------- + + +@torch.no_grad() +def generate_editflow_minimal( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + args, + cfg: GenCfg, +) -> tuple[str, dict]: + """ + Returns: + final_text, trace + + Notes on annotations: + • Any local annotated with Annotated[..., "viz-only"] is only used to build + the decode trace for visualization (e.g., GIF rendering) and has no effect + on the actual generation. 
Such variables are also prefixed with '_' to make + this visually obvious in code. + """ + torch.manual_seed(cfg.seed) + + # If prompt is None, start from BOS alone; otherwise ALWAYS prefix BOS + bos = getattr(tokenizer, "bos_token_id", None) + if bos is None: + raise ValueError("Tokenizer must have a BOS token for this sampler.") + + prompt = args.prompt + if prompt is None: + ids = [bos] # BOS alone + else: + ids = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=True, + add_generation_prompt=True, + ) + # ids = tokenizer.encode(prompt, add_special_tokens=False) + # ids = [bos] + enc["input_ids"] # ALWAYS prefix BOS + + prompt_len = len(ids) + + if args.mask_length: + if getattr(tokenizer, "mask_token_id", None) is None: + raise ValueError( + "Tokenizer must define mask_token_id when --mask_length > 0." + ) + ids = ids + [tokenizer.mask_token_id] * args.mask_length + + x = torch.tensor(ids, dtype=torch.long, device=model.device) + + sched = LinearKappaScheduler() + tau = cfg.tau + steps = math.ceil(1.0 / max(tau, 1e-9)) + + _trace: Annotated[dict, "viz-only: full decode trace for GIF/inspection"] = { + "steps": [], + "init": { + "t": 0.0, + "x_ids": [int(i) for i in x.tolist()], + "prompt_len": int(prompt_len), + }, + "end_t": 0.0, + } + + # Local-only reuse: if previous iteration had no edits, reuse its forward. + prev_out: dict | None = None + prev_had_edits = True # first iteration must run a forward + + t = 0.0 + for _ in range(steps): + # We can reuse prev_out only if the model is declared time-independent + # and the previous step had NO edits (sequence unchanged). + reuse_prev = ( + cfg.time_independent and not prev_had_edits and (prev_out is not None) + ) + + x, edited, _step_trace, prev_out = tau_leap_step_minimal( + x=x, + model=model, + tokenizer=tokenizer, + prompt_len=prompt_len, + t=t, + sched=sched, + cfg=cfg, + prev_out=prev_out, + reuse_prev=reuse_prev, + ) + + _step_trace: Annotated[dict, "viz-only: per-step intermediates for trace"] + _trace["steps"].append(_step_trace) + + prev_had_edits = edited + + t = min(1.0, t + tau) + if t >= 1.0 - args.time_epsilon: + break + + _trace["end_t"] = float(t) + + final_text = tokenizer.decode(x.tolist(), skip_special_tokens=False) + print("[final]") + return final_text, _trace + + +# ---------------------------------------- CLI ------------------------------------- + + +def main(): + @dataclass + class ScriptArgs: + # Required (no default) + model_name_or_path: Annotated[str, "Path or hub id for the model"] + time_independent: Annotated[ + bool, "Whether model is conditioned on time step" + ] = True + + prompt: Annotated[str | None, "Text prompt. 
If None, start from BOS alone."] = ( + None + ) + # Boolean flag: tyro exposes --edit-prompt / --no-edit-prompt automatically for bools + edit_prompt: Annotated[ + bool, + "Allow delete/substitute and insertions in the prompt region (BOS+prompt).", + ] = False + + # Generation-related args + tau: Annotated[float, "τ-leap size"] = 0.01 + time_epsilon: Annotated[ + float, "Match this with the `time_epsilon` arg used in your EditFlowTrainer" + ] = 1e-3 + mask_length: Annotated[ + int, + "Number of tokens appended after the prompt.\n" + "EditFlow will iteratively substitute, insert, or delete masks to form the output.", + ] = 128 + temperature: Annotated[float, "Token sampling temperature; 0 for greedy."] = 0.7 + + seed: Annotated[int, "Random seed"] = 1234 + verbose: Annotated[bool, "Whether to show intermediate decoding traces"] = True + + # Visualization + make_gif: Annotated[bool, "Render a decoding trace GIF after generation."] = ( + False + ) + gif_path: Annotated[ + str | None, "Output GIF path (default: decode_trace.gif)" + ] = None + frame_ms: Annotated[int, "Per-frame duration in ms"] = 120 + + args = tyro.cli(ScriptArgs) + + cfg = GenCfg( + tau=args.tau, + seed=args.seed, + edit_prompt=args.edit_prompt, + temperature=args.temperature, + verbose=args.verbose, + time_independent=args.time_independent, + ) + + model = AutoModel.from_pretrained( + args.model_name_or_path, + dtype=torch.bfloat16, + device_map="auto", + ).eval() + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + + final_text, trace = generate_editflow_minimal(model, tokenizer, args, cfg) + print(final_text) + + if args.make_gif: + from examples.editflow.viz import render_consecutive_trace_gif + + out = args.gif_path or "decode_trace.gif" + path = render_consecutive_trace_gif( + trace, + tokenizer, + out_path=out, + frame_ms=args.frame_ms, + ) + print(f"[gif saved] {path}") + + +if __name__ == "__main__": + main() diff --git a/dllm/examples/editflow/llada/adapt.py b/dllm/examples/editflow/llada/adapt.py new file mode 100644 index 0000000..23fb199 --- /dev/null +++ b/dllm/examples/editflow/llada/adapt.py @@ -0,0 +1,88 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/llada/adapt.py \ + --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/editflow/llada/adapt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/adapt.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/adapt.py" +""" + +from dataclasses import dataclass + +import torch +import transformers + +import dllm +from examples.editflow import sft as editflow_sft + + +@dataclass +class ModelArguments(editflow_sft.ModelArguments): + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct" + lm_head_key: str = "model.transformer.ff_out" + init_editflow_from_src: bool = True + + +@dataclass +class DataArguments(editflow_sft.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + + +@dataclass +class TrainingArguments(editflow_sft.TrainingArguments): + output_dir: str = ( + "models/EditFlow-LLaDA-8B-Instruct-Adapt/tulu-3-sft-mixture[train:10000,test:1000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + dllm.utils.initial_training_setup(model_args, data_args, training_args) + # Create EditFlow model (bf16 init on CUDA) + ef_cfg = dllm.pipelines.editflow.EditFlowLLaDAConfig.from_pretrained( + model_args.model_name_or_path + ) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config(ef_cfg, dtype=torch.bfloat16) + # Initialize EditFlow model from the src model: copies backbone & clones lm_head + if model_args.init_editflow_from_src: + src_model = transformers.AutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, dtype=torch.bfloat16 + ) + dllm.pipelines.editflow.utils.init_editflow_from_src( + model, src_model, lm_head_key=model_args.lm_head_key + ) + del src_model + model = dllm.utils.load_peft(model, model_args) + + editflow_sft.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + model=model, + ) diff --git a/dllm/examples/editflow/llada/pt.py b/dllm/examples/editflow/llada/pt.py new file mode 100644 index 0000000..da5f1c4 --- /dev/null +++ b/dllm/examples/editflow/llada/pt.py @@ -0,0 +1,67 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/llada/pt.py \ + --lora True + +- 8 GPUs (DeepSpeed FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/editflow/llada/pt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/pt.py" + +- 24 Nodes, 192 GPUs (FSDP): + sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/pt.py" +""" + +from dataclasses import dataclass + +import transformers + +import dllm +from examples.editflow import pt as editflow_pt + + +@dataclass +class ModelArguments(editflow_pt.ModelArguments): + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base" + lm_head_key: str = "model.transformer.ff_out" + + +@dataclass +class DataArguments(editflow_pt.DataArguments): + dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]" + + +@dataclass +class TrainingArguments(editflow_pt.TrainingArguments): + output_dir: str = ( + "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_pt.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ef_config_cls=dllm.pipelines.editflow.EditFlowLLaDAConfig, + ) diff --git a/dllm/examples/editflow/llada/sft.py b/dllm/examples/editflow/llada/sft.py new file mode 100644 index 0000000..1990ded --- /dev/null +++ b/dllm/examples/editflow/llada/sft.py @@ -0,0 +1,66 @@ +""" +Local users +------------ +- 1 GPU (LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/editflow/llada/sft.py \ + --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/editflow/llada/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/sft.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/editflow/llada/sft.py" +""" + +from dataclasses import dataclass + +import transformers + +from examples.editflow import sft as editflow_sft + + +@dataclass +class ModelArguments(editflow_sft.ModelArguments): + model_name_or_path: str = ( + "models/EditFlow-LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]/checkpoint-final" + ) + + +@dataclass +class DataArguments(editflow_sft.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + + +@dataclass +class TrainingArguments(editflow_sft.TrainingArguments): + output_dir: str = ( + "models/EditFlow-LLaDA-8B-Instruct-SFT/tulu-3-sft-mixture[train:10000,test:1000]" + ) + + +if __name__ == "__main__": + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + editflow_sft.train( + model_args=model_args, + data_args=data_args, + training_args=training_args, + ) diff --git a/dllm/examples/editflow/pt.py b/dllm/examples/editflow/pt.py new file mode 100644 index 0000000..20320d3 --- /dev/null +++ b/dllm/examples/editflow/pt.py @@ -0,0 +1,176 @@ +import os +import functools +from dataclasses import dataclass, field + +import transformers +import accelerate + +import dllm +from dllm.pipelines import editflow + +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = None # overwrite this + lm_head_key: str = field( + default=None, + metadata={ + "help": ( + "The key to the `lm_head` in the source model for initializing operation heads in the EditFlow model. " + "Overwrite this when `init_editflow_from_src` = True" + ) + }, + ) + init_editflow_from_src: bool = field( + default=True, + metadata={ + "help": "Whether to initialize EditFlow model from the source model." + }, + ) + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]" + text_field: str = "text" + streaming: bool = False + drop_tail: bool = True + insert_eos: bool = field( + default=True, + metadata={ + "help": "False when adjacent samples from the datasets are semantically coherent." + }, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = None # overwrite this + num_train_epochs: float = 20 + learning_rate: float = 3e-4 + # max_steps: int = 2_000 + per_device_train_batch_size: int = 3 + per_device_eval_batch_size: int = 3 + eval_steps: float = 0.1 + save_steps: float = 0.1 + # EditFlow specific args + scheduler_cls: str = field( + default="LinearKappaScheduler", + metadata={ + "help": ( + "The scheduler class controlling κ(t). 
" + "Available options: see `dllm/utils/schedulers/kappa.py`" + ) + }, + ) + normalize_per_position: bool = field( + default=True, + metadata={"help": "Whether to normalize the loss per position."}, + ) + max_w: float = field( + default=20.0, + metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."}, + ) + x0_sampler: str = field( + default="masks[length:128]", + metadata={ + "help": ( + "Choose the x0 sampler. " + "Available options: see `dllm/pipelines/editflow/utils.py`" + ) + }, + ) + + +def train( + model_args: ModelArguments, + data_args: DataArguments, + training_args: TrainingArguments, + ef_config_cls: type[transformers.PretrainedConfig], +): + # necessary when batch does not contain "labels" field + training_args.label_names = [] + # necessary when batch contains customized fields + training_args.remove_unused_columns = False + # necessary for streaming dataset + training_args.accelerator_config.dispatch_batches = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Load base Model and initialize EditFlow Model --------------------------- + # Create EditFlow model (bf16 init on CUDA) + ef_cfg = ef_config_cls.from_pretrained( + model_args.model_name_or_path, + dtype=model_args.dtype, + attn_implementation=model_args.attn_implementation, + ) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config(ef_cfg) + if model_args.init_editflow_from_src: + # Load src model config & weights (bf16 on CUDA) for intializing EditFlow model + src_model = transformers.AutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, dtype=model_args.dtype + ) + # Initialize EditFlow model from the src model: copies backbone & clones lm_head + editflow.utils.init_editflow_from_src( + model, src_model, lm_head_key=model_args.lm_head_key + ) + del src_model + model = dllm.utils.load_peft(model, model_args) + + def _no_flops(*args, **kwargs): + return 0.0 + + model.floating_point_ops = _no_flops + + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_pt_dataset( + data_args.dataset_args, + streaming=data_args.streaming, + ) + dataset = dataset.map( + functools.partial( + dllm.utils.tokenize_and_group, + tokenizer=tokenizer, + text_field=data_args.text_field, + seq_length=data_args.max_length, + insert_eos=data_args.insert_eos, + drop_tail=data_args.drop_tail, + ), + batched=True, + remove_columns=dataset["train"].column_names, + **({} if data_args.streaming else {"num_proc": data_args.num_proc}), + **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}), + ) + if data_args.streaming: + dataset = dataset.shuffle(seed=training_args.seed) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = editflow.EditFlowTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=editflow.utils.EditFlowCollator( + tokenizer=tokenizer, x0_sampler=training_args.x0_sampler + ), + scheduler=dllm.core.schedulers.make_kappa_scheduler( + 
training_args.scheduler_cls + ), + normalize_per_position=training_args.normalize_per_position, + max_w=training_args.max_w, + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) diff --git a/dllm/examples/editflow/sft.py b/dllm/examples/editflow/sft.py new file mode 100644 index 0000000..f0718c9 --- /dev/null +++ b/dllm/examples/editflow/sft.py @@ -0,0 +1,192 @@ +import os +from functools import partial +from dataclasses import dataclass, field + +import transformers +import accelerate + +import dllm +from dllm.pipelines import editflow +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = None # overwrite this + lm_head_key: str = field( + default=None, + metadata={ + "help": ( + "The key to the `lm_head` in the source model for initializing operation heads in the EditFlow model. " + "Overwrite this when `init_editflow_from_src` = True" + ) + }, + ) + init_editflow_from_src: bool = field( + default=True, + metadata={ + "help": "Whether to initialize EditFlow model from the source model." + }, + ) + init_editflow_from_editflow: bool = False + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "tatsu-lab/alpaca" + load_preprocessed_data: bool = False + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = None # overwrite this + per_device_train_batch_size: int = 2 + per_device_eval_batch_size: int = 2 + learning_rate: float = 5e-5 + # EditFlow specific args + scheduler_cls: str = field( + default="LinearKappaScheduler", + metadata={ + "help": ( + "The scheduler class controlling κ(t). " + "Available options: see `dllm/utils/schedulers/kappa.py`" + ) + }, + ) + normalize_per_position: bool = field( + default=True, + metadata={"help": "Whether to normalize the loss per position."}, + ) + max_w: float = field( + default=20.0, + metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."}, + ) + x0_sampler: str = field( + default="masks[length:128]", + metadata={ + "help": ( + "Choose the x0 sampler. " + "Available options: see `dllm/pipelines/editflow/utils.py`" + ) + }, + ) + + +def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict: + # - `input_ids`` = prompt + response + # - `prompt_len` marks the prompt span to EXCLUDE from loss. + # (Remove prompt_len to train on all tokens—if so, ensure a BOS is prepended.) + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], + tokenize=True, + add_generation_prompt=False, + ) + if mask_prompt_loss: + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], + tokenize=True, + add_generation_prompt=True, + ) + return { + "input_ids": prompt_response_tokens, + "prompt_len": len(prompt_tokens), + } + else: + # When training on all tokens, prepend a BOS token (if missing) + # so the model can insert to the left of the very first token. 
+ if prompt_response_tokens[0] != tokenizer.bos_token_id: + prompt_response_tokens = [tokenizer.bos_token_id] + prompt_response_tokens + return {"input_ids": prompt_response_tokens} + + +def train( + model_args: ModelArguments, + data_args: DataArguments, + training_args: TrainingArguments, + ef_config_cls: type[transformers.PretrainedConfig], +): + # necessary when batch does not contain "labels" field + training_args.label_names = [] + # necessary when batch contains customized fields + training_args.remove_unused_columns = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Load EditFlow Model ---------------------------------------------------- + if model_args.init_editflow_from_editflow: + model = dllm.utils.get_model(model_args=model_args) + else: + ef_cfg = ef_config_cls.from_pretrained( + model_args.model_name_or_path, + dtype=model_args.dtype, + attn_implementation=model_args.attn_implementation, + ) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config(ef_cfg) + if model_args.init_editflow_from_src: + # Load src model config & weights (bf16 on CUDA) for intializing EditFlow model + src_model = transformers.AutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, dtype=model_args.dtype + ) + # Initialize EditFlow model from the src model: copies backbone & clones lm_head + editflow.utils.init_editflow_from_src( + model, src_model, lm_head_key=model_args.lm_head_key + ) + del src_model + model = dllm.utils.load_peft(model, model_args) + + def _no_flops(*args, **kwargs): + return 0.0 + + model.floating_point_ops = _no_flops + + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset( + data_args.dataset_args, + load_preprocessed_data=data_args.load_preprocessed_data, + ) + if not data_args.load_preprocessed_data: + map_fn = partial( + sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=data_args.mask_prompt_loss, + ) + dataset = dataset.map( + map_fn, + num_proc=data_args.num_proc, + desc="Mapping dataset to SFT format", + ) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = editflow.EditFlowTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=editflow.utils.EditFlowCollator( + tokenizer=tokenizer, x0_sampler=training_args.x0_sampler + ), + scheduler=dllm.core.schedulers.make_kappa_scheduler( + training_args.scheduler_cls + ), + normalize_per_position=training_args.normalize_per_position, + max_w=training_args.max_w, + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) diff --git a/dllm/examples/editflow/viz.py b/dllm/examples/editflow/viz.py new file mode 100644 index 0000000..5df7cca --- /dev/null +++ b/dllm/examples/editflow/viz.py @@ -0,0 +1,489 @@ +# 
------------------------------ Visualization (NEW) ------------------------------ +# Diffusion-style consecutive output: only show the CURRENT output per frame. +# ------------------ Visualization (sanitized, masks stripped) ------------------ +from PIL import Image, ImageDraw, ImageFont + +import re +import unicodedata +from typing import Optional, List, Tuple, Annotated + + +def render_consecutive_trace_gif( + trace: dict, + tokenizer, + out_path: str = "decode_trace.gif", + font_size: int = 30, + line_spacing: int = 12, + frame_ms: int = 250, + final_ms: int = 5000, # final clean frame duration (ms) + max_width: int = 1400, + max_height: int = 3000, + margin: int = 32, + title_color=(80, 80, 80), + text_color=(0, 0, 0), # base black + mask_color=(150, 150, 150), + sub_nonmask_color=(200, 0, 0), # persistent red + ins_color=(0, 0, 200), # persistent blue + del_strike_color=(120, 120, 120), + events_color=(30, 30, 30), + box_color=(120, 120, 120), + bg_color=(255, 255, 255), +): + """ + Persistent coloring keyed by token *instance* (not token id): + - Inserted tokens -> BLUE across frames (until deleted/substituted again). + - Substitution nonmask→nonmask -> RED across frames (until deleted/substituted again). + - Substitution mask→nonmask -> stays BLACK (no extra color). + Adds a final clean frame (5s) with no events box. + """ + from PIL import Image, ImageDraw, ImageFont + import unicodedata + + # ---------- font ---------- + try: + font = ImageFont.truetype( + "assets/JetBrainsMono-VariableFont_wght.ttf", font_size + ) + except Exception: + print(f"fail to load target font") + font = ImageFont.load_default() + + # ---------- helpers ---------- + def _sanitize_token(s: str) -> str: + vis_mask_token = "[m]" + s = unicodedata.normalize("NFKC", s) + s = s.replace("Ċ", "\n").replace("▁", " ").replace("Ġ", " ") + s = s.replace("\t", " ") + s = s.replace("\u00a0", " ").replace("\u2007", " ").replace("\u202f", " ") + + # replace mask variants + if "mdm_mask" in s.lower(): + s = re.sub(r"<[\|]?\s*mdm_mask\s*[\|]?>", "[m]", s, flags=re.IGNORECASE) + s = s.replace("mdm_mask", "[m]") + if "mask" in s.lower(): + s = re.sub(r"<[\|]?\s*mask\s*[\|]?>", "[m]", s, flags=re.IGNORECASE) + s = s.replace("mask", "[m]") + + # replace <|...|> format tokens with bracketed form + s = re.sub(r"<\|\s*(.*?)\s*\|>", r"[\1]", s) + return s + + def _tok_str(tok_id: int) -> str: + try: + s = tokenizer.decode([int(tok_id)], skip_special_tokens=False) + s = s if s.strip() else f"<{int(tok_id)}>" + except Exception: + s = f"<{int(tok_id)}>" + return s.replace("\n", "\\n") + + TOKEN_RE = re.compile(r"\s+|\S+") + + def _wrap_text(draw: ImageDraw.ImageDraw, text: str, width_px: int) -> List[str]: + if text == "": + return [""] + lines: List[str] = [] + for para in text.split("\n"): + tokens = TOKEN_RE.findall(para) + cur = "" + for tok in tokens: + candidate = cur + tok + if draw.textlength(candidate, font=font) <= width_px: + cur = candidate + else: + if cur: + lines.append(cur) + cur = tok + while ( + draw.textlength(cur, font=font) > width_px and len(cur) > 0 + ): + lo, hi, fit = 1, len(cur), 1 + while lo <= hi: + mid = (lo + hi) // 2 + if draw.textlength(cur[:mid], font=font) <= width_px: + fit, lo = mid, mid + 1 + else: + hi = mid - 1 + lines.append(cur[:fit]) + cur = cur[fit:] + else: + t = tok + while draw.textlength(t, font=font) > width_px and len(t) > 0: + lo, hi, fit = 1, len(t), 1 + while lo <= hi: + mid = (lo + hi) // 2 + if draw.textlength(t[:mid], font=font) <= width_px: + fit, lo = mid, mid + 1 + else: + hi 
= mid - 1 + lines.append(t[:fit]) + t = t[fit:] + cur = t + lines.append(cur) + return lines or [""] + + tmp_img = Image.new("RGB", (10, 10), bg_color) + tmp_draw = ImageDraw.Draw(tmp_img) + text_width_budget = max_width - 2 * margin + + # mask detection + MASK_IDS = set() + if getattr(tokenizer, "mask_token_id", None) is not None: + MASK_IDS.add(int(tokenizer.mask_token_id)) + MASK_STRINGS = set() + mt = getattr(tokenizer, "mask_token", None) + if mt is not None: + MASK_STRINGS.add(str(mt)) + MASK_STRINGS.add("") + + def _is_mask_token(tok_id: int, tok_str_exact: str) -> bool: + return (int(tok_id) in MASK_IDS) or (tok_str_exact in MASK_STRINGS) + + def _wrap_tokens_with_index(tokens, deleted_flags): + lines, cur, cur_w = [], [], 0 + for i, tok in enumerate(tokens): + t = _sanitize_token(tok) + parts = t.split("\n") + for j, seg in enumerate(parts): + seg_rest = seg + while seg_rest: + w = tmp_draw.textlength(seg_rest, font=font) + if cur_w + w <= text_width_budget or not cur: + cur.append((seg_rest, i, deleted_flags[i])) + cur_w += w + seg_rest = "" + else: + lines.append(cur) + cur, cur_w = [], 0 + if j != len(parts) - 1: + lines.append(cur) + cur, cur_w = [], 0 + if cur: + lines.append(cur) + return lines + + def _draw_dashed_rectangle( + draw, xy, dash=8, gap=6, width=2, outline=(120, 120, 120) + ): + x0, y0, x1, y1 = xy + x = x0 + while x < x1: + x2 = min(x + dash, x1) + draw.line([(x, y0), (x2, y0)], fill=outline, width=width) + draw.line([(x, y1), (x2, y1)], fill=outline, width=width) + x += dash + gap + y = y0 + while y < y1: + y2 = min(y + dash, y1) + draw.line([(x0, y), (x0, y2)], fill=outline, width=width) + draw.line([(x1, y), (x1, y2)], fill=outline, width=width) + y += dash + gap + + def _ops_lines_for_step(st: dict): + if st is None: + return ["(no events)"] + lines = [] + x_before = st["x_before_ids"] + choose_del = st["choose_del"] + choose_sub = st["choose_sub"] + sub_samples = st["sub_samples"] + ins_samples = st["ins_samples"] + T = len(x_before) + for i in range(T): + if choose_del[i]: + lines.append(f"DEL@{i}:{_tok_str(int(x_before[i]))}") + elif choose_sub[i]: + lines.append( + f"SUB@{i}:{_tok_str(int(x_before[i]))}->{_tok_str(int(sub_samples[i]))}" + ) + if ins_samples[i] is not None: + lines.append(f"INS@{i}->{i+1}:{_tok_str(int(ins_samples[i]))}") + if not lines: + lines.append("(no events)") + return lines + + # ---- Instance-id machinery ---- + next_instance_id = 0 + + def _new_inst(): + nonlocal next_instance_id + val = next_instance_id + next_instance_id += 1 + return val + + # Current sequence at the *start* (ids + instance_ids) + curr_ids = list(trace["init"]["x_ids"]) + curr_inst = [_new_inst() for _ in curr_ids] + + # Persistent color by instance_id: {"blue", "red"} + color_by_inst = {} + + # ---------- PASS 1: measure required heights per frame ---------- + measurement_payload = [] + + for step_idx, st in enumerate([None] + trace["steps"]): + # build augmented view + if st is None: + aug_ids = list(curr_ids) + deleted_flags = [False] * len(aug_ids) + else: + x_before = st["x_before_ids"] + choose_del = st["choose_del"] + after_ids = st["x_after_ids"] + deleted_positions = [i for i, d in enumerate(choose_del) if d] + + aug_ids = list(after_ids) + deleted_flags = [False] * len(after_ids) + for i in sorted(deleted_positions, reverse=True): + aug_ids.insert(i, x_before[i]) + deleted_flags.insert(i, True) + + tokens = tokenizer.convert_ids_to_tokens(aug_ids) + wrapped_lines = _wrap_tokens_with_index(tokens, deleted_flags) + + # estimate ops lines for this 
step + if st: + ops_text = " • " + " • ".join(_ops_lines_for_step(st)) + else: + ops_text = "(no events)" + ops_lines = _wrap_text(tmp_draw, ops_text, text_width_budget) + + # compute height needed + body_h = len(wrapped_lines) * (font_size + line_spacing) + ops_h = len(ops_lines) * (font_size + line_spacing) + font_size # + 20 + required_h = margin + (font_size + line_spacing) + body_h + 20 + + measurement_payload.append( + { + "step_idx": step_idx, + "st": st, + "aug_ids": aug_ids, + "tokens": tokens, + "deleted_flags": deleted_flags, + "wrapped_lines": wrapped_lines, + "ops_lines": ops_lines, + "required_h": required_h, + } + ) + + # Measure clean final frame (no events) + final_text_ids = ( + trace["steps"][-1]["x_after_ids"] if trace["steps"] else trace["init"]["x_ids"] + ) + final_tokens = tokenizer.convert_ids_to_tokens(final_text_ids) + wrapped_clean = _wrap_tokens_with_index(final_tokens, [False] * len(final_tokens)) + clean_body_h = len(wrapped_clean) * (font_size + line_spacing) + clean_required_h = margin + (font_size + line_spacing) + clean_body_h + + # Pick a single uniform canvas height + max_required_h = max( + [p["required_h"] for p in measurement_payload] + [clean_required_h] + ) # + 20 + H = min(max_required_h, max_height) + W = max_width + + # For each frame we need an augmented view (with deleted placeholders) to draw + frames = [] + + # Iterate steps; for step_idx==0 we still draw "initial state" + steps_with_initial = [None] + trace["steps"] + + for step_idx, st in enumerate(steps_with_initial): + if st is None: + # initial frame: augmented is just current tokens + aug_ids = list(curr_ids) + aug_inst = list(curr_inst) + aug_deleted = [False] * len(aug_ids) + ops_lines = ["(no events)"] + title = "initial state" + else: + title = f"t = {st['t']:.3f}" + x_before = list(st["x_before_ids"]) + choose_del = list(st["choose_del"]) + choose_sub = list(st["choose_sub"]) + sub_samples = list(st["sub_samples"]) + ins_samples = list(st["ins_samples"]) + assert ( + len(x_before) == len(curr_ids) == len(curr_inst) + ), "trace 'x_before' must match current sequence." 
+ + # Build augmented (drawn) and next (state-after) in one pass + aug_ids, aug_inst, aug_deleted = [], [], [] + next_ids, next_inst = [], [] + + for i in range(len(x_before)): + before_id = int(curr_ids[i]) + before_inst = curr_inst[i] + + if choose_del[i]: + # show deleted placeholder (strike-through) + aug_ids.append(before_id) + aug_inst.append(None) + aug_deleted.append(True) + # remove from next; also clear any persistent color + color_by_inst.pop(before_inst, None) + else: + if choose_sub[i]: + after_id = int(sub_samples[i]) + # in augmented we show the *after* token at same instance + aug_ids.append(after_id) + aug_inst.append(before_inst) + aug_deleted.append(False) + next_ids.append(after_id) + next_inst.append(before_inst) + + # update persistence by source type + if int(before_id) in MASK_IDS: + # mask → nonmask: no extra color (ensure cleared) + color_by_inst.pop(before_inst, None) + else: + # nonmask → nonmask: mark RED + color_by_inst[before_inst] = "red" + else: + # keep + aug_ids.append(before_id) + aug_inst.append(before_inst) + aug_deleted.append(False) + next_ids.append(before_id) + next_inst.append(before_inst) + + # insertion AFTER position i + if ins_samples[i] is not None: + ins_id = int(ins_samples[i]) + ins_inst = _new_inst() + aug_ids.append(ins_id) + aug_inst.append(ins_inst) + aug_deleted.append(False) + next_ids.append(ins_id) + next_inst.append(ins_inst) + # mark persistent BLUE for this *instance only* + color_by_inst[ins_inst] = "blue" + + # commit next state + curr_ids, curr_inst = next_ids, next_inst + ops_text = " • " + " • ".join(_ops_lines_for_step(st)) + ops_lines = _wrap_text(tmp_draw, ops_text, text_width_budget) + + # ----- render this frame ----- + tokens = tokenizer.convert_ids_to_tokens(aug_ids) + wrapped_lines = _wrap_tokens_with_index(tokens, aug_deleted) + + img = Image.new("RGB", (W, H), bg_color) + draw = ImageDraw.Draw(img) + + y = margin + draw.text((margin, y), title, fill=title_color, font=font) + y += font_size + line_spacing + + for line in wrapped_lines: + x = margin + for seg_text, tok_idx, is_deleted in line: + tok_id = int(aug_ids[tok_idx]) + tok_str_exact = tokens[tok_idx] + inst = aug_inst[tok_idx] + + if is_deleted: + # strike deleted — grey masks slightly different if desired + strike_color = ( + mask_color + if _is_mask_token(tok_id, tok_str_exact) + else del_strike_color + ) + strike = "".join(ch + "\u0336" for ch in seg_text) + draw.text((x, y), strike, fill=strike_color, font=font) + x += tmp_draw.textlength(strike, font=font) + else: + # choose color by *instance* + color = text_color + if inst is not None and inst in color_by_inst: + color = ( + ins_color + if color_by_inst[inst] == "blue" + else sub_nonmask_color + ) + elif _is_mask_token(tok_id, tok_str_exact): + color = mask_color + draw.text((x, y), seg_text, fill=color, font=font) + x += tmp_draw.textlength(seg_text, font=font) + y += font_size + line_spacing + + # draw events box for all but the extra final-clean frame we'll add later + # if step_idx != len(steps_with_initial) - 1: + # y += 20 + # x0, y0 = margin, y + # x1 = max_width - margin + # box_h = len(ops_lines) * (font_size + line_spacing) + font_size + 20 + # y1 = y0 + box_h + # _draw_dashed_rectangle(draw, (x0, y0, x1, y1), outline=box_color) + # draw.text((x0 + 10, y0 + 10), "events", fill=events_color, font=font) + # yy = y0 + font_size + 20 + # for l in ops_lines: + # draw.text((x0 + 10, yy), l, fill=events_color, font=font) + # yy += font_size + line_spacing + # y += 10 + frames.append(img) + + # 
----- extra final clean frame (no events box), 5s ----- + final_ids = list(curr_ids) + final_inst = list(curr_inst) + final_tokens = tokenizer.convert_ids_to_tokens(final_ids) + + # wrap without deleted flags + def _wrap_clean(tokens): + lines, cur, cur_w = [], [], 0 + for i, tok in enumerate(tokens): + t = _sanitize_token(tok) + parts = t.split("\n") + for j, seg in enumerate(parts): + seg_rest = seg + while seg_rest: + w = tmp_draw.textlength(seg_rest, font=font) + if cur_w + w <= text_width_budget or not cur: + cur.append((seg_rest, i)) + cur_w += w + seg_rest = "" + else: + lines.append(cur) + cur, cur_w = [], 0 + if j != len(parts) - 1: + lines.append(cur) + cur, cur_w = [], 0 + if cur: + lines.append(cur) + return lines + + wrapped_clean = _wrap_clean(final_tokens) + + clean_img = Image.new("RGB", (W, H), bg_color) + draw = ImageDraw.Draw(clean_img) + draw.text((margin, margin), "final text", fill=title_color, font=font) + y = margin + font_size + line_spacing + for line in wrapped_clean: + x = margin + for seg_text, tok_idx in line: + tok_id = int(final_ids[tok_idx]) + tok_str_exact = final_tokens[tok_idx] + inst = final_inst[tok_idx] + color = text_color + if inst in color_by_inst: + color = ( + ins_color if color_by_inst[inst] == "blue" else sub_nonmask_color + ) + elif _is_mask_token(tok_id, tok_str_exact): + color = mask_color + draw.text((x, y), seg_text, fill=color, font=font) + x += tmp_draw.textlength(seg_text, font=font) + y += font_size + line_spacing + frames.append(clean_img) + + # save GIF + durations = [frame_ms] * (len(frames) - 1) + [final_ms] + frames[0].save( + out_path, + save_all=True, + append_images=frames[1:], + duration=durations, + loop=0, + disposal=2, + optimize=True, + ) + return out_path diff --git a/dllm/examples/llada/README.md b/dllm/examples/llada/README.md new file mode 100644 index 0000000..9dc62dd --- /dev/null +++ b/dllm/examples/llada/README.md @@ -0,0 +1,206 @@ +# LLaDA + +> 📄 Paper: [Large Language Diffusion Models](https://arxiv.org/abs/2502.09992) | 💻 Code: [github.com/ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA) + +Resources and examples for training (finetuning & pretraining) and evaluating diffusion language models **LLaDA**. + +## Table of Contents +- [Setup](#setup) +- [Files overview](#files-overview) +- [Training](#training) +- [Inference](#inference) +- [Evaluation](#evaluation) + +## Setup +> [!IMPORTANT] +> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logps`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details. 
+> +> **MoE checkpoints:** For models like [`LLaDA-MoE-7B-A1B-Base`](https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base), set `"model_type"` to `"lladamoe"` in the checkpoint’s `config.json`: +> ```diff +> - "model_type": "llada", +> + "model_type": "lladamoe", +> ``` +> + + +## Files overview +``` +# tools relevant with LLaDA +dllm/pipelines/llada +├── __init__.py # Package initialization +├── models/ +│ ├── configuration_lladamoe.py # LLaDA-MoE model configuration +│ ├── configuration_llada.py # LLaDA model configuration +│ ├── modeling_lladamoe.py # LLaDA-MoE model architecture +│ └── modeling_llada.py # LLaDA model architecture +├── generator.py # Inference logic +└── trainer.py # Training logic (pretraining and finetuning) + +# example entry points for training / inference / evaluation +examples/llada +├── chat.py # Interactive inference example +├── eval.sh # Automatic evaluation script +├── generate.py # Inference example +├── pt.py # Pretraining example +├── README.md # Documentation (you are here) +└── sft.py # Supervised finetuning example +``` + + + + + +## Training +### Finetuning + +For example, to SFT [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) for instruction following on 8 GPUs, run: +```shell +accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/llada/sft.py \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 2e-5 +``` +If you are using slurm and want to train across, for example, 2 nodes (16 GPUs total), run: +```shell +sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "models/LLaDA-8B-SFT/tulu-3-sft-mixture" \ + --max_length 1024 \ + --num_train_epochs 4 \ + --learning_rate 2e-5 +``` + + + +#### Reproducing [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) +Though LLaDA is trained on proprietary data, we tried our best to reproduce [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) by finetuning [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) using our training pipeline on public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): + +```shell +# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training) +python dllm/tools/preprocess_sft_dataset.py \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "data/sft/llada/tulu-3-sft-mixture" \ + --num_proc 64 + +# train on 24*8=192 A100s with FSDP, take about 8 hours +sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --dataset_args "data/sft/llada/tulu-3-sft-mixture" \ + --load_preprocessed_data True \ + --output_dir "models/LLaDA-8B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \ + --max_length 2048 \ + --truncation "right" \ + --group_by_length True \ + --num_train_epochs 5 \ + --learning_rate 1e-5 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --eval_on_start False \ + --eval_steps 0.1 \ + 
--save_steps 0.05 +``` + + + +### Pretraining + +Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP: +```shell +sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/pt.py" \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ + --dataset_args "mlfoundations/dclm-baseline-1.0" \ + --output_dir "models/LLaDA-8B-PT/dclm-baseline-1.0" \ + --max_length 1024 \ + --max_steps 2000 \ + --learning_rate 3e-4 +``` + +## Inference +We support batch inference for standard generation and infilling: + +```shell +python examples/llada/generate.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" +``` +We also support interactive multi-turn dialogue with visualization: + +```shell +python examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" +``` + +## Evaluation +> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation. + +For example, to evaluate [LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [MMLU-Pro](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run: +```shell +# Use model_args to adjust the generation arguments for evalution. +accelerate launch --num_processes 4 \ + dllm/pipelines/llada/eval.py \ + --tasks "mmlu_pro" \ + --model "llada" \ + --apply_chat_template \ + --num_fewshot 0 \ + --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0" +``` + +To automatically evaluate [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) and [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on all benchmarks, run: +```shell +bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Instruct --instruct True +bash examples/llada/eval.sh --model_name_or_path GSAI-ML/LLaDA-8B-Base --instruct False +``` + +### Evaluation results + + + +> Results (evaluated) are evaluated using our framework, while results (reported) come from the original paper. All evaluation settings follow the configurations in the [LLaDA](https://github.com/ML-GSAI/LLaDA) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon. + + + +|               | MMLU | BBH | ARC‑C | Hellaswag | TruthfulQA | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | CEval | CMMLU | +|:----------------|:----:|:---:|:-----:|:-----------:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:------:| +| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base)(reported)| 65.9 | 49.7 | 45.9 | 70.5 | 46.1 | 74.8 | 73.6 | 70.3 | 31.4 | 25.2 | 35.4 | 40.0 | 70.5 | 69.9 | +| [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base)(evaluated)| 65.8 | – | 45.7 | 69.3 | 45.6 | 70.7 | 70.6 | 70.4 | – | – | 32.3 | 38.8 | 70.2 | 69.9 | + + +

+Table 1. Evaluation results of [`LLaDA-8B-Base`](https://huggingface.co/GSAI-ML/LLaDA-8B-Base).
+

+ +|                 | MMLU | MMLU‑Pro | ARC‑C | Hellaswag | GSM8K | Math | GPQA | HumanEval | MBPP | +|:----------------|:----:|:---------:|:-----:|:-----------:|:-----:|:----:|:----:|:-----------:|:----:| +| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)(reported) | 65.5 | 37.0 | 88.5 | 74.6 | 69.4 | 31.9 | 33.3 | 49.4 | 41.0 | +| [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)(evaluated) | 67.3 | 36.2 | 86.6 | 76.7 | 81.1 | – | – | 65.0 | 70.2 | + +

+Table 2. Evaluation results of [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct).
+
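+For quick sanity checks of a checkpoint outside the evaluation harness, generation can also be driven programmatically. The snippet below is a minimal sketch condensed from `examples/llada/generate.py` in this diff: the bare `Args` holder is an assumption for brevity (the full `ScriptArguments` in that script is the reference), while `LLaDAGenerator`, `LLaDAGeneratorConfig`, `decode_trim`, `dllm.utils.get_model`, and `dllm.utils.get_tokenizer` come from the `llada` pipeline introduced above.
+
+```python
+# Minimal programmatic-generation sketch (condensed from examples/llada/generate.py).
+from dataclasses import dataclass
+
+import dllm
+from dllm.pipelines import llada
+from dllm.tools.chat import decode_trim
+
+
+@dataclass
+class Args:
+    # Illustrative stand-in for ScriptArguments in generate.py.
+    model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct"
+
+
+@dataclass
+class GenCfg(llada.LLaDAGeneratorConfig):
+    # Same defaults as GeneratorConfig in generate.py.
+    steps: int = 128
+    max_new_tokens: int = 128
+    block_length: int = 32
+    temperature: float = 0.0
+    remasking: str = "low_confidence"
+
+
+args = Args()
+model = dllm.utils.get_model(model_args=args).eval()
+tokenizer = dllm.utils.get_tokenizer(model_args=args)
+generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
+
+# Batch of chat-formatted prompts, tokenized with the chat template.
+messages = [[{"role": "user", "content": "Briefly explain diffusion language models."}]]
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+
+# Generate, then strip the prompt/special tokens from the decoded output.
+outputs = generator.generate(inputs, GenCfg(), return_dict_in_generate=True)
+print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
+```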

diff --git a/dllm/examples/llada/chat.py b/dllm/examples/llada/chat.py new file mode 100644 index 0000000..f7d311c --- /dev/null +++ b/dllm/examples/llada/chat.py @@ -0,0 +1,74 @@ +""" +Interactive chat / generation script for LLaDA models. + +Examples +-------- +# Chat mode (multi-turn, chat template) +python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True + +# Raw single-turn generation +python -u examples/llada/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat False +""" + +import sys +from dataclasses import dataclass +import transformers + +import dllm +from dllm.pipelines import llada +from dllm.tools.chat import multi_turn_chat, single_turn_generate + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct" + seed: int = 42 + chat: bool = True + visualize: bool = True + + def __post_init__(self): + # same base-path resolution logic as in generate.py + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(llada.LLaDAGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + block_length: int = 32 + temperature: float = 0.0 + remasking: str = "low_confidence" + + +def main(): + parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) + script_args, gen_config = parser.parse_args_into_dataclasses() + transformers.set_seed(script_args.seed) + + model = dllm.utils.get_model(model_args=script_args).eval() + tokenizer = dllm.utils.get_tokenizer(model_args=script_args) + generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer) + + if script_args.chat: + multi_turn_chat( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + else: + print("\nSingle-turn generation (no chat template).") + single_turn_generate( + generator=generator, + gen_config=gen_config, + visualize=script_args.visualize, + ) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted. 
Bye!") + sys.exit(0) diff --git a/dllm/examples/llada/eval.sh b/dllm/examples/llada/eval.sh new file mode 100644 index 0000000..0c3a7c7 --- /dev/null +++ b/dllm/examples/llada/eval.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash +# ===== Mandatory for proper import and evaluation ===== +export PYTHONPATH=.:$PYTHONPATH +export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation +export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset + + +# ===== Optional but recommended for stability and debugging ===== +export PYTHONBREAKPOINT=0 # Disable interactive breakpoints +export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks +export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs +export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging + + +# ===== Input Arguments ===== +model_name_or_path="GSAI-ML/LLaDA-8B-Instruct" +instruct=True +num_gpu=4 +while [[ $# -gt 0 ]]; do + case "$1" in + --model_name_or_path) + model_name_or_path="$2"; shift 2 ;; + --instruct) + instruct="$2"; shift 2 ;; + --num_gpu) + num_gpu="$2"; shift 2 ;; + esac +done + +# ===== Conditional Configurations ===== +if [ "$instruct" = "True" ]; then + echo ">>> Running in INSTRUCT mode" + common_args="--model llada --apply_chat_template" +else + echo ">>> Running in BASE mode" + common_args="--model llada" +fi + + +# ======================= +# Generation Tasks +# ======================= + +if [ "$instruct" = "True" ]; then + # Instruct Generation Tasks + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks gsm8k_cot --num_fewshot 8 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks bbh --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks minerva_math --num_fewshot 4 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks humaneval_instruct --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks mbpp_llada_instruct --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + +else + # Base Generation Tasks + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks gsm8k --num_fewshot 8 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks bbh --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} 
dllm/pipelines/llada/eval.py \ + --tasks minerva_math --num_fewshot 4 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks humaneval --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks mbpp --num_fewshot 3 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0" +fi + + +# ======================= +# Likelihood Tasks +# ======================= + +if [ "$instruct" = "True" ]; then + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks mmlu_generative --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks mmlu_pro --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks hellaswag_gen --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=3,steps=3,block_length=3,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks arc_challenge_chat --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=5,steps=5,block_length=5,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks gpqa_n_shot_gen --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=32,steps=32,block_length=32,cfg=0.0" + +else + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks gpqa_main_n_shot --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks truthfulqa_mc2 --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=2.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks arc_challenge --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks hellaswag --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks winogrande --num_fewshot 5 ${common_args} \ + --model_args 
"pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks piqa --num_fewshot 0 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=128,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.5" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks mmlu --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks cmmlu --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0" + + accelerate launch --num_processes ${num_gpu} dllm/pipelines/llada/eval.py \ + --tasks ceval-valid --num_fewshot 5 ${common_args} \ + --model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=1024,cfg=0.0" +fi diff --git a/dllm/examples/llada/generate.py b/dllm/examples/llada/generate.py new file mode 100644 index 0000000..afa131c --- /dev/null +++ b/dllm/examples/llada/generate.py @@ -0,0 +1,116 @@ +""" +python -u examples/llada/generate.py --model_name_or_path "YOUR_MODEL_PATH" +""" + +from dataclasses import dataclass + +import transformers + +import dllm +from dllm.tools.chat import decode_trim +from dllm.pipelines import llada + + +@dataclass +class ScriptArguments: + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Instruct" + seed: int = 42 + visualize: bool = True + + def __post_init__(self): + self.model_name_or_path = dllm.utils.resolve_with_base_env( + self.model_name_or_path, "BASE_MODELS_DIR" + ) + + +@dataclass +class GeneratorConfig(llada.LLaDAGeneratorConfig): + steps: int = 128 + max_new_tokens: int = 128 + block_length: int = 32 + temperature: float = 0.0 + remasking: str = "low_confidence" + + +parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig)) +script_args, gen_config = parser.parse_args_into_dataclasses() +transformers.set_seed(script_args.seed) + +# Load model & tokenizer +model = dllm.utils.get_model(model_args=script_args).eval() +tokenizer = dllm.utils.get_tokenizer(model_args=script_args) +generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer) +terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer( + tokenizer=tokenizer +) + +# --- Example 1: Batch generation --- +print("\n" + "=" * 80) +print("TEST: llada.generate()".center(80)) +print("=" * 80) + +messages = [ + [{"role": "user", "content": "Lily runs 12 km/h for 4 hours. 
How far in 8 hours?"}], + [{"role": "user", "content": "Please write an educational python function."}], +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, +) + +outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) + +for iter, s in enumerate(sequences): + print("\n" + "-" * 80) + print(f"[Case {iter}]") + print("-" * 80) + print(s.strip() if s.strip() else "") +print("\n" + "=" * 80 + "\n") + +if script_args.visualize: + terminal_visualizer.visualize(outputs.histories, rich=True) + +# --- Example 2: Batch fill-in-the-blanks --- +print("\n" + "=" * 80) +print("TEST: llada.infilling()".center(80)) +print("=" * 80) + +masked_messages = [ + [ + {"role": "user", "content": tokenizer.mask_token * 20}, + { + "role": "assistant", + "content": "Sorry, I do not have answer to this question.", + }, + ], + [ + {"role": "user", "content": "Please write an educational python function."}, + { + "role": "assistant", + "content": "def hello_" + tokenizer.mask_token * 20 + " return", + }, + ], +] + +inputs = tokenizer.apply_chat_template( + masked_messages, + add_generation_prompt=False, + tokenize=True, +) + +outputs = generator.infill(inputs, gen_config, return_dict_in_generate=True) +sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs) + +for iter, (i, s) in enumerate(zip(inputs, sequences)): + print("\n" + "-" * 80) + print(f"[Case {iter}]") + print("-" * 80) + print("[Masked]:\n" + tokenizer.decode(i)) + print("\n[Filled]:\n" + (s.strip() if s.strip() else "")) +print("\n" + "=" * 80 + "\n") + +if script_args.visualize: + terminal_visualizer.visualize(outputs.histories, rich=True) diff --git a/dllm/examples/llada/pt.py b/dllm/examples/llada/pt.py new file mode 100644 index 0000000..3fc7659 --- /dev/null +++ b/dllm/examples/llada/pt.py @@ -0,0 +1,174 @@ +""" +Local users +------------ +- 1 GPU (4bit quant & LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/llada/pt.py \ + --load_in_4bit True --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/llada/pt.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. +------------ +- 24 Nodes, 192 GPUs (FSDP): + sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/pt.py" +""" + +import os +import functools +from dataclasses import dataclass, field + +import torch +import transformers +import accelerate + +import dllm + +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + # Uses only the configuration from model_name_or_path to initialize the model from scratch + model_name_or_path: str = ( + "GSAI-ML/LLaDA-8B-Base" # "inclusionAI/LLaDA-MoE-7B-A1B-Base" + ) + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "mlfoundations/dclm-baseline-1.0[train:10_000_000,test:10_000]" + text_field: str = "text" + streaming: bool = True + drop_tail: bool = True + insert_eos: bool = field( + default=True, + metadata={ + "help": "False when adjacent samples from the datasets are semantically coherent." 
+ }, + ) + random_length_ratio: float = field( + default=0.01, + metadata={ + "help": ( + "The probability of randomly cut sequences during training. " + "See https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training for reference." + ) + }, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = ( + "models/LLaDA-8B-PT/dclm-baseline-1.0[train:10_000_000,test:10_000]" + ) + learning_rate: float = 3e-4 + max_steps: int = 2_000 + per_device_train_batch_size: int = 4 + gradient_accumulation_steps: int = 4 + eval_steps: float = 0.05 + save_steps: float = 0.05 + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + # necessary for streaming dataset + if data_args.streaming: + training_args.accelerator_config.dispatch_batches = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + # initialize model weights from scratch + config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) + with dllm.utils.init_device_context_manager(): + model = transformers.AutoModel.from_config( + config, dtype=torch.bfloat16, init_params=True + ) + + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + # ----- Optional PEFT: LoRA ---------------------------------------------------- + model = dllm.utils.load_peft(model=model, model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_pt_dataset( + data_args.dataset_args, + streaming=data_args.streaming, + ) + dataset = dataset.map( + functools.partial( + dllm.utils.tokenize_and_group, + tokenizer=tokenizer, + text_field=data_args.text_field, + seq_length=data_args.max_length, + insert_eos=data_args.insert_eos, + drop_tail=data_args.drop_tail, + ), + batched=True, + remove_columns=dataset["train"].column_names, + **({} if data_args.streaming else {"num_proc": data_args.num_proc}), + **({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}), + ) + if data_args.streaming: + dataset = dataset.shuffle(seed=training_args.seed) + + # ----- Training -------------------------------------------------------------- + @dataclass + class LLaDAPTCollator(transformers.DataCollatorForSeq2Seq): + # Reference: https://github.com/ML-GSAI/LLaDA/blob/main/GUIDELINES.md#pre-training + # By default, 1% of the pre-training data are truncated to a random length + random_length_ratio: float = 0.01 + + def __call__(self, features, return_tensors=None): + outputs = super().__call__(features, return_tensors) + if torch.rand(1) < self.random_length_ratio: + random_length = torch.randint( + 1, outputs["input_ids"].shape[1] + 1, (1,) + ) + for key in ["input_ids", "labels", "attention_mask"]: + if key in outputs: + outputs[key] = outputs[key][:, :random_length] + # Check if attention_mask is all ones and set it to None + if torch.all(outputs["attention_mask"] == 1): + outputs.pop("attention_mask") + return outputs + + accelerate.PartialState().wait_for_everyone() + logger.info("Start 
training...") + trainer = dllm.core.trainers.MDLMTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=LLaDAPTCollator( + tokenizer, + return_tensors="pt", + padding=True, + random_length_ratio=data_args.random_length_ratio, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/llada/sft.py b/dllm/examples/llada/sft.py new file mode 100644 index 0000000..6f3e028 --- /dev/null +++ b/dllm/examples/llada/sft.py @@ -0,0 +1,120 @@ +""" +Local users +------------ +- 1 GPU (4bit quant & LoRA, useful for testing): + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/llada/sft.py \ + --load_in_4bit True --lora True + +- 8 GPUs (FSDP): + accelerate launch \ + --config_file scripts/accelerate_configs/fsdp.yaml \ + examples/llada/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. +------------ +- 1 Node, 8 GPUs (FSDP): + sbatch --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" + +- 2 Nodes, 16 GPUs (FSDP): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "fsdp" \ + --script_path "examples/llada/sft.py" +""" + +import os +from dataclasses import dataclass, field +from functools import partial + +import transformers +import accelerate + +import dllm +logger = dllm.utils.get_default_logger(__name__) + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "allenai/tulu-3-sft-mixture[train:10000,test:1000]" + load_preprocessed_data: bool = False + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/LLaDA-8B-SFT/tulu-3-sft-mixture[train:10000,test:1000]" + group_by_length: bool = True + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + model = dllm.utils.get_model(model_args=model_args) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + + # ----- Dataset ---------------------------------------------------------------- + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset( + data_args.dataset_args, + load_preprocessed_data=data_args.load_preprocessed_data, + ) + if not data_args.load_preprocessed_data: + map_fn = partial( + dllm.utils.default_sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=data_args.mask_prompt_loss, + ) + 
dataset = dataset.map( + map_fn, + num_proc=data_args.num_proc, + desc="Mapping dataset to SFT format", + ) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + accelerate.PartialState().wait_for_everyone() + logger.info("Start training...") + trainer = dllm.core.trainers.MDLMTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset.get("test", None), + args=training_args, + data_collator=dllm.utils.NoAttentionMaskCollator( + tokenizer, + return_tensors="pt", + padding=True, + label_pad_token_id=tokenizer.pad_token_id, # finetune on padding + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/rnd/README.md b/dllm/examples/rnd/README.md new file mode 100644 index 0000000..00d7bdd --- /dev/null +++ b/dllm/examples/rnd/README.md @@ -0,0 +1 @@ +WIP diff --git a/dllm/examples/rnd/preprocess.py b/dllm/examples/rnd/preprocess.py new file mode 100644 index 0000000..fa83161 --- /dev/null +++ b/dllm/examples/rnd/preprocess.py @@ -0,0 +1,114 @@ +# """ +# srun -p $PARTITION --quotatype=$QUOTATYPE --gres=gpu:1 --cpus-per-task=12 --time=03:00:000 + +# python examples/rnd/preprocess.py --dataset_args "HuggingFaceTB/smoltalk" --output_dir "data/sft_proc/rnd/smoltalk" +# """ +# import os +# from dataclasses import dataclass +# from typing import Dict, Any + +# import datasets +# import transformers +# import accelerate +# import tyro + +# import dllm + + +# # --- tyro: define dataclass for CLI args --- +# @dataclass +# class ScriptArguments: +# """Preprocess SFT dataset (batch_size=1 only)""" +# model_name_or_path: str = "radicalnumerics/RND1-Base-0910" +# dataset_args: str = "HuggingFaceTB/smoltalk" # required +# output_dir: str = "data/sft_proc/rnd/smoltalk" # required +# mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100 +# # TODO: strip_cols + +# def __post_init__(self): +# self.model_name_or_path = dllm.utils.resolve_with_base_env( +# self.model_name_or_path, "BASE_MODELS_DIR" +# ) + + +# def dataset_offline_preprocess(dataset: datasets.DatasetDict, map_fn: callable, output_dir: str): +# # Map with batch_size=1 and num_proc=1 (no batching, single process). +# state = accelerate.PartialState() +# with state.local_main_process_first(): +# processed = dataset.map( +# map_fn, +# batched=False, +# num_proc=16, +# load_from_cache_file=True, +# writer_batch_size=512, +# desc="offline preprocessing", +# ) + +# # # Keep only the three required columns to save space. 
+# # keep = {"input_ids", "labels", "prompt_len"} +# # def strip_cols(ds: datasets.Dataset) -> datasets.Dataset: +# # drop = [c for c in ds.column_names if c not in keep] +# # return ds.remove_columns(drop) if drop else ds + +# # if isinstance(processed, datasets.DatasetDict): +# # for split in list(processed.keys()): +# # processed[split] = strip_cols(processed[split]) +# # else: +# # processed = strip_cols(processed) + +# os.makedirs(output_dir, exist_ok=True) +# processed.save_to_disk(output_dir) +# print(f"[OK] Saved to: {output_dir}") + + +# def main(): +# # Parse with tyro +# args = tyro.cli(ScriptArguments) + +# # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) +# tokenizer = dllm.utils.get_tokenizer(args) + +# # Load your raw dataset (must contain a "messages" field per example). +# dataset = dllm.data.load_sft_dataset(args.dataset_args) + +# dataset_offline_preprocess(dataset=dataset, map_fn=None, output_dir=args.output_dir) + + +# if __name__ == "__main__": +# main() + + +from functools import partial +import tyro + +import dllm +from dllm.tools.preprocess_sft_dataset import ScriptArguments, preprocess_sft_dataset + + +def main(): + from examples.rnd.sft import sft_map_fn + + # Parse with tyro + args = tyro.cli(ScriptArguments) + + # tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer = dllm.utils.get_tokenizer(args) + + # Load your raw dataset (must contain a "messages" field per example). + dataset = dllm.data.load_sft_dataset(args.dataset_args) + + map_fn = partial( + sft_map_fn, + tokenizer=tokenizer, + mask_prompt_loss=args.mask_prompt_loss, + ) + preprocess_sft_dataset( + dataset=dataset, + map_fn=map_fn, + output_dir=args.output_dir, + remove_columns=args.remove_columns, + ) + + +if __name__ == "__main__": + main() diff --git a/dllm/examples/rnd/sft.py b/dllm/examples/rnd/sft.py new file mode 100644 index 0000000..3a3ce4f --- /dev/null +++ b/dllm/examples/rnd/sft.py @@ -0,0 +1,199 @@ +""" +Local users +------------ +- 1 GPU: + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/rnd/sft.py + +- 8 GPUs (DeepSpeed ZeRO-2): + accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/rnd/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 GPU: + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "single_gpu" \ + --script_path "examples/rnd/sft.py" + +- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "zero2" \ + --script_path "examples/rnd/sft.py" +""" + +import os +from dataclasses import dataclass, field + +import transformers +import accelerate +import peft +import datasets + +import dllm +from dllm.pipelines import rnd + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "radicalnumerics/RND1-Base-0910" + moe_backend: str = "hf" + attn_implementation: str = "sdpa" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]" + truncation: str = "right" + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]" + # rnd specific + group_by_length: bool = True + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + freeze_gate: bool = field( + default=True, + metadata={ + "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)." + }, + ) + freeze_embedding: bool = field( + default=False, + metadata={"help": "If True, freeze embedding parameters."}, + ) + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + moe_backend=model_args.moe_backend, + attn_implementation=model_args.attn_implementation, + ) + model = dllm.utils.get_model(model_args=model_args, config=config) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + # ----- Optionally freeze modules ---------------------------------------------- + if not isinstance(model, peft.PeftModel): + if getattr(training_args, "freeze_gate", False): + for n, m in model.named_modules(): + if n.endswith(".gate"): # only router gate, not gate_proj + for p in m.parameters(recurse=False): + p.requires_grad_(False) + + if getattr(training_args, "freeze_embedding", False): + # model.model.embed_tokens.requires_grad_(False) + model.model.embed_tokens.weight.requires_grad_(False) + + # ----- Dataset ---------------------------------------------------------------- + def sft_map_fn(row) -> dict: + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + ) + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], tokenize=True, add_generation_prompt=False + ) + labels = prompt_response_tokens.copy() + if training_args.mask_prompt_loss: + # use -100 in labels to indicate positions where tokens should not be masked + # and loss is ignored; all other positions match `input_ids` + labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens) + else: + # When training on all tokens, 
prepend a BOS token (if missing) + # so the model can make predictions for the first mask token. + if prompt_response_tokens[0] != tokenizer.bos_token_id: + bos = [tokenizer.bos_token_id] + prompt_response_tokens = bos + prompt_response_tokens + prompt_tokens = bos + prompt_tokens + labels = bos + labels + labels[0] = -100 # ignore loss on the BOS token + # `prompt_len` helps `post_process_dataset` truncate long sequences properly + return { + "input_ids": prompt_response_tokens, + "labels": labels, + # "attention_mask": [1.0] * len(prompt_response_tokens), + "prompt_len": len(prompt_tokens), + } + + if not data_args.load_from_disk: + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset(data_args.dataset_args) + dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + else: + from datasets import disable_caching + + disable_caching() + dataset = datasets.load_from_disk(data_args.dataset_args) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + @dataclass + class RNDSFTCollator(transformers.DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + outputs = super().__call__(features, return_tensors) + # RND is finetuned on padding + outputs.pop("attention_mask") + # temp fix here (`group_by_length=True` leads to shape mismatch) + # clip seq_len (second dim) to the same for outputs `input_ids, labels` + import torch + + keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs] + if keys_to_clip: + # Get smallest seq_len to avoid out-of-bounds + min_len = min( + outputs[k].size(1) + for k in keys_to_clip + if isinstance(outputs[k], torch.Tensor) + ) + for k in keys_to_clip: + t = outputs[k] + if isinstance(t, torch.Tensor) and t.size(1) != min_len: + outputs[k] = t[:, :min_len] + return outputs + + trainer = rnd.RNDTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + args=training_args, + data_collator=RNDSFTCollator( + tokenizer, + # pad_to_multiple_of=8, + return_tensors="pt", + padding=True, + label_pad_token_id=tokenizer.pad_token_id, # RND is finetuned on padding + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/rnd/sft_v2.py b/dllm/examples/rnd/sft_v2.py new file mode 100644 index 0000000..78e43a9 --- /dev/null +++ b/dllm/examples/rnd/sft_v2.py @@ -0,0 +1,199 @@ +""" +Local users +------------ +- 1 GPU: + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/rnd/sft.py + +- 8 GPUs (DeepSpeed ZeRO-2): + accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/rnd/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 GPU: + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "single_gpu" \ + --script_path "examples/rnd/sft.py" + +- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "zero2" \ + --script_path "examples/rnd/sft.py" +""" + +import os +from dataclasses import dataclass, field + +import transformers +import accelerate +import peft +import datasets + +import dllm +from dllm.pipelines import rnd + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "radicalnumerics/RND1-Base-0910" + moe_backend: str = "hf" + attn_implementation: str = "sdpa" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]" + truncation: str = "right" + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]" + # rnd specific + group_by_length: bool = True + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + freeze_gate: bool = field( + default=True, + metadata={ + "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)." + }, + ) + freeze_embedding: bool = field( + default=False, + metadata={"help": "If True, freeze embedding parameters."}, + ) + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + moe_backend=model_args.moe_backend, + attn_implementation=model_args.attn_implementation, + ) + model = dllm.utils.get_model(model_args=model_args, config=config) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + # ----- Optionally freeze modules ---------------------------------------------- + if not isinstance(model, peft.PeftModel): + if getattr(training_args, "freeze_gate", False): + for n, m in model.named_modules(): + if n.endswith(".gate"): # only router gate, not gate_proj + for p in m.parameters(recurse=False): + p.requires_grad_(False) + + if getattr(training_args, "freeze_embedding", False): + # model.model.embed_tokens.requires_grad_(False) + model.model.embed_tokens.weight.requires_grad_(False) + + # ----- Dataset ---------------------------------------------------------------- + def sft_map_fn(row) -> dict: + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + ) + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], tokenize=True, add_generation_prompt=False + ) + labels = prompt_response_tokens.copy() + if training_args.mask_prompt_loss: + # use -100 in labels to indicate positions where tokens should not be masked + # and loss is ignored; all other positions match `input_ids` + labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens) + else: + # When training on all tokens, 
prepend a BOS token (if missing) + # so the model can make predictions for the first mask token. + if prompt_response_tokens[0] != tokenizer.bos_token_id: + bos = [tokenizer.bos_token_id] + prompt_response_tokens = bos + prompt_response_tokens + prompt_tokens = bos + prompt_tokens + labels = bos + labels + labels[0] = -100 # ignore loss on the BOS token + # `prompt_len` helps `post_process_dataset` truncate long sequences properly + return { + "input_ids": prompt_response_tokens, + "labels": labels, + # "attention_mask": [1.0] * len(prompt_response_tokens), + "prompt_len": len(prompt_tokens), + } + + if not data_args.load_from_disk: + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset(data_args.dataset_args) + dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + else: + dataset = datasets.load_from_disk(data_args.dataset_args) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + @dataclass + class RNDSFTCollator(transformers.DataCollatorForSeq2Seq): + def __call__(self, features, return_tensors=None): + outputs = super().__call__(features, return_tensors) + # RND is finetuned on padding + outputs.pop("attention_mask") + # temp fix here (`group_by_length=True` leads to shape mismatch) + # clip seq_len (second dim) to the same for outputs `input_ids, labels` + # TODO -> FIXED: clip all relevant tensors to a common seq_len + # Determine common length across present tensors + import torch + + keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs] + if keys_to_clip: + # Get smallest seq_len to avoid out-of-bounds + min_len = min( + outputs[k].size(1) + for k in keys_to_clip + if isinstance(outputs[k], torch.Tensor) + ) + for k in keys_to_clip: + t = outputs[k] + if isinstance(t, torch.Tensor) and t.size(1) != min_len: + outputs[k] = t[:, :min_len] + return outputs + + tokenizer.pad_token_id = tokenizer.mask_token_ids + trainer = rnd.RNDTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + args=training_args, + data_collator=RNDSFTCollator( + tokenizer, + # pad_to_multiple_of=8, + return_tensors="pt", + padding=True, + label_pad_token_id=-100, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/examples/rnd/sft_v3.py b/dllm/examples/rnd/sft_v3.py new file mode 100644 index 0000000..ca048fd --- /dev/null +++ b/dllm/examples/rnd/sft_v3.py @@ -0,0 +1,219 @@ +""" +Local users +------------ +- 1 GPU: + accelerate launch \ + --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \ + examples/rnd/sft.py + +- 8 GPUs (DeepSpeed ZeRO-2): + accelerate launch \ + --config_file scripts/accelerate_configs/zero2.yaml \ + examples/rnd/sft.py + +Slurm users +# Note: run `mkdir logs` before running sbatch; and adjust +# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster. 
+------------ +- 1 GPU: + sbatch --gres=gpu:1 scripts/train.slurm.sh \ + --accelerate_config "ddp" \ + --script_path "examples/rnd/sft.py" + +- 2 Nodes, 16 GPUs (DeepSpeed ZeRO-2): + sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \ + --accelerate_config "zero2" \ + --script_path "examples/rnd/sft.py" +""" + +import os +from dataclasses import dataclass, field + +import transformers +import accelerate +import peft +import datasets + +import dllm +from dllm.pipelines import rnd + + +@dataclass +class ModelArguments(dllm.utils.ModelArguments): + model_name_or_path: str = "radicalnumerics/RND1-Base-0910" + moe_backend: str = "hf" + attn_implementation: str = "sdpa" + + +@dataclass +class DataArguments(dllm.utils.DataArguments): + dataset_args: str = "HuggingFaceTB/smoltalk[train:10000,test:1000]" + truncation: str = "right" + + +@dataclass +class TrainingArguments(dllm.utils.TrainingArguments): + output_dir: str = "models/RND1-SFT-0910/smoltalk[train:10000,test:1000]" + # rnd specific + # group_by_length: bool = True + mask_prompt_loss: bool = field( + default=True, + metadata={"help": "Whether to mask the loss on the prompt tokens"}, + ) + freeze_gate: bool = field( + default=True, + metadata={ + "help": "If True, freeze routing gate parameters (e.g., MoE router/gating layers)." + }, + ) + freeze_embedding: bool = field( + default=False, + metadata={"help": "If True, freeze embedding parameters."}, + ) + perbatch_cutoff: bool = field( + default=True, + metadata={ + "help": ( + "Randomly pick a response length from batch and trim other responses. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." + ) + }, + ) + resp_cutoff_ratio: float = field( + default=0.0, + metadata={ + "help": ( + "The probability of randomly cutting sequences during training. " + "See https://github.com/DreamLM/Dream/blob/main/src/trainer/config/sft_trainer.yaml." 
+ ) + }, + ) + + +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + # necessary when batch contains customized fields + training_args.remove_unused_columns = False + dllm.utils.print_args_main(model_args, data_args, training_args) + dllm.utils.initial_training_setup(model_args, data_args, training_args) + + # ----- Model ------------------------------------------------------------------ + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + moe_backend=model_args.moe_backend, + attn_implementation=model_args.attn_implementation, + ) + model = dllm.utils.get_model(model_args=model_args, config=config) + # ----- Tokenizer -------------------------------------------------------------- + tokenizer = dllm.utils.get_tokenizer(model_args=model_args) + # ----- Optionally freeze modules ---------------------------------------------- + if not isinstance(model, peft.PeftModel): + if getattr(training_args, "freeze_gate", False): + for n, m in model.named_modules(): + if n.endswith(".gate"): # only router gate, not gate_proj + for p in m.parameters(recurse=False): + p.requires_grad_(False) + + if getattr(training_args, "freeze_embedding", False): + # model.model.embed_tokens.requires_grad_(False) + model.model.embed_tokens.weight.requires_grad_(False) + + # ----- Dataset ---------------------------------------------------------------- + def sft_map_fn(row) -> dict: + prompt_tokens = tokenizer.apply_chat_template( + row["messages"][:-1], + tokenize=True, + add_generation_prompt=True, + enable_thinking=False, + ) + prompt_response_tokens = tokenizer.apply_chat_template( + row["messages"], tokenize=True, add_generation_prompt=False + ) + labels = prompt_response_tokens.copy() + if training_args.mask_prompt_loss: + # use -100 in labels to indicate positions where tokens should not be masked + # and loss is ignored; all other positions match `input_ids` + labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens) + else: + # When training on all tokens, prepend a BOS token (if missing) + # so the model can make predictions for the first mask token. 
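+ # Illustrative effect of the BOS prepend below (placeholder token ids, not real vocab ids): + # input_ids = [bos, p1, p2, r1, r2] -> labels = [-100, p1, p2, r1, r2], + # so every position except BOS keeps a prediction target; compare mask_prompt_loss=True above, + # where input_ids = [p1, p2, r1, r2] -> labels = [-100, -100, r1, r2] and only response tokens contribute loss.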
+ if prompt_response_tokens[0] != tokenizer.bos_token_id: + bos = [tokenizer.bos_token_id] + prompt_response_tokens = bos + prompt_response_tokens + prompt_tokens = bos + prompt_tokens + labels = bos + labels + labels[0] = -100 # ignore loss on the BOS token + # `prompt_len` helps `post_process_dataset` truncate long sequences properly + return { + "input_ids": prompt_response_tokens, + "labels": labels, + "attention_mask": [1] * len(prompt_response_tokens), + "prompt_len": len(prompt_tokens), + } + + if not data_args.load_from_disk: + with accelerate.PartialState().local_main_process_first(): + dataset = dllm.data.load_sft_dataset(data_args.dataset_args) + dataset = dataset.map(sft_map_fn, num_proc=data_args.num_proc) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + else: + dataset = datasets.load_from_disk(data_args.dataset_args) + # truncate / filter long sequences if needed + dataset = dllm.utils.post_process_dataset(dataset, data_args) + + # ----- Training -------------------------------------------------------------- + # @dataclass + # class RNDSFTCollator(transformers.DataCollatorForSeq2Seq): + # def __call__(self, features, return_tensors=None): + # outputs = super().__call__(features, return_tensors) + # # RND is finetuned on padding + # outputs.pop("attention_mask") + # # temp fix here (`group_by_length=True` leads to shape mismatch) + # # clip seq_len (second dim) to the same for outputs `input_ids, labels` + # import torch + # keys_to_clip = [k for k in ("input_ids", "labels") if k in outputs] + # if keys_to_clip: + # # Get smallest seq_len to avoid out-of-bounds + # min_len = min(outputs[k].size(1) for k in keys_to_clip if isinstance(outputs[k], torch.Tensor)) + # for k in keys_to_clip: + # t = outputs[k] + # if isinstance(t, torch.Tensor) and t.size(1) != min_len: + # outputs[k] = t[:, :min_len] + # return outputs + trainer = rnd.RNDTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + args=training_args, + # data_collator=RNDSFTCollator( + # tokenizer, + # # pad_to_multiple_of=8, + # return_tensors="pt", + # padding=True, + # label_pad_token_id=-100, # RND is finetuned on padding + # ), + data_collator=dllm.pipelines.dream.utils.DreamSFTCollator( + tokenizer, + # pad_to_multiple_of=8, + return_tensors="pt", + padding=True, + label_pad_token_id=-100, + perbatch_cutoff=training_args.perbatch_cutoff, + resp_cutoff_ratio=training_args.resp_cutoff_ratio, + ), + ) + trainer.train() + trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) + trainer.processing_class.save_pretrained( + os.path.join(training_args.output_dir, "checkpoint-final") + ) + + +if __name__ == "__main__": + train() diff --git a/dllm/pyproject.toml b/dllm/pyproject.toml new file mode 100644 index 0000000..61f1c2c --- /dev/null +++ b/dllm/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["dllm"] + +[project] +name = "dllm" +version = "0.1.0" +description = "dLLM: Simple Diffusion Language Modeling" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "transformers==4.57.0", + "accelerate==1.11.0", + "deepspeed==0.18.0", + "peft==0.17.1", + "datasets==4.2.0", + "sentencepiece==0.2.0", + "tyro", + "wandb", + "omegaconf", + "tqdm", + "matplotlib", + "pytest", + "rich", +] + +[project.optional-dependencies] +optional = [ + 
"bitsandbytes==0.48.1", + "vllm==0.8.5.post1", + "flash-attn==2.8.3", +] + +[tool.black] +line-length = 88 +exclude = ''' +( + (^|/)dllm/pipelines/llada/models(/|$)| + (^|/)dllm/pipelines/dream/models(/|$)| + (^|/)dllm/pipelines/rnd/models(/|$)| + (^|/)lm-evaluation-harness(/|$) +) +''' + +[tool.pytest.ini_options] +testpaths = ["scripts/tests"] +python_files = ["test_*.py"] +addopts = "-v -ra" diff --git a/dllm/scripts/accelerate_configs/cpu.yaml b/dllm/scripts/accelerate_configs/cpu.yaml new file mode 100644 index 0000000..254420c --- /dev/null +++ b/dllm/scripts/accelerate_configs/cpu.yaml @@ -0,0 +1,7 @@ +compute_environment: LOCAL_MACHINE +distributed_type: NO +mixed_precision: "no" +num_processes: 1 +machine_rank: 0 +num_machines: 1 +downcast_bf16: "no" diff --git a/dllm/scripts/accelerate_configs/ddp.yaml b/dllm/scripts/accelerate_configs/ddp.yaml new file mode 100644 index 0000000..90506c9 --- /dev/null +++ b/dllm/scripts/accelerate_configs/ddp.yaml @@ -0,0 +1,6 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +machine_rank: 0 +num_machines: 1 +num_processes: 8 diff --git a/dllm/scripts/accelerate_configs/fsdp.yaml b/dllm/scripts/accelerate_configs/fsdp.yaml new file mode 100644 index 0000000..ef1ad29 --- /dev/null +++ b/dllm/scripts/accelerate_configs/fsdp.yaml @@ -0,0 +1,56 @@ +# compute_environment: LOCAL_MACHINE +# debug: false +# distributed_type: FSDP +# downcast_bf16: 'no' +# enable_cpu_affinity: false +# fsdp_config: +# fsdp_activation_checkpointing: true # Need fix from: https://github.com/huggingface/transformers/pull/36610 +# fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP +# fsdp_backward_prefetch: BACKWARD_PRE +# fsdp_cpu_ram_efficient_loading: true +# fsdp_forward_prefetch: true +# fsdp_offload_params: false +# fsdp_sharding_strategy: FULL_SHARD +# fsdp_state_dict_type: FULL_STATE_DICT +# fsdp_sync_module_states: true +# fsdp_use_orig_params: true +# machine_rank: 0 +# main_training_function: main +# mixed_precision: bf16 +# num_machines: 1 +# num_processes: 8 +# rdzv_backend: static +# same_network: true +# tpu_env: [] +# tpu_use_cluster: false +# tpu_use_sudo: false +# use_cpu: false + +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false # Need fix from: https://github.com/huggingface/transformers/pull/36610 + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_POST + fsdp_forward_prefetch: false + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/dllm/scripts/accelerate_configs/zero1.yaml b/dllm/scripts/accelerate_configs/zero1.yaml new file mode 100644 index 0000000..acd303a --- /dev/null +++ b/dllm/scripts/accelerate_configs/zero1.yaml @@ -0,0 +1,19 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +# mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 
+rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/dllm/scripts/accelerate_configs/zero2.yaml b/dllm/scripts/accelerate_configs/zero2.yaml new file mode 100644 index 0000000..9595d1f --- /dev/null +++ b/dllm/scripts/accelerate_configs/zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +# mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/dllm/scripts/accelerate_configs/zero3.yaml b/dllm/scripts/accelerate_configs/zero3.yaml new file mode 100644 index 0000000..98eb040 --- /dev/null +++ b/dllm/scripts/accelerate_configs/zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +# mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/dllm/scripts/accelerate_configs/zero3_moe.yaml b/dllm/scripts/accelerate_configs/zero3_moe.yaml new file mode 100644 index 0000000..5ae063b --- /dev/null +++ b/dllm/scripts/accelerate_configs/zero3_moe.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_moe_layer_cls_names: RND1DecoderLayer # LLaDAMoEDecoderLayer + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/dllm/scripts/eval.slurm.sh b/dllm/scripts/eval.slurm.sh new file mode 100644 index 0000000..b9ebaf6 --- /dev/null +++ b/dllm/scripts/eval.slurm.sh @@ -0,0 +1,292 @@ +#!/usr/bin/env bash +#SBATCH --job-name=model-eval +#SBATCH --partition=mllm_safety +#SBATCH --quotatype=spot +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=8 +#SBATCH --time=20:00:00 +#SBATCH --output=logs/%x-%j.out +#SBATCH --error=logs/%x-%j.err +#SBATCH --requeue + +# ============================================================ +# Unified Evaluation Configuration + Execution Script +# ============================================================ + + +# ------------------------------------------------------------ +# Declare associative arrays +# ------------------------------------------------------------ +declare -A eval_llada_base_configs +declare -A eval_llada_instruct_configs + +declare -A eval_dream_base_configs +declare -A eval_dream_instruct_configs + +declare -A eval_bert_configs + + +# ============================================================ +# 
==================== LLaDA CONFIGS ======================== +# ============================================================ +# eval_llada_configs["<task>"]="num_fewshot|max_new_tokens|steps|block_length|seed|mc_num|cfg" +# ============================================================ + +# ---------- Base Generation ---------- +eval_llada_base_configs["gsm8k"]="8|1024|1024|32|1234|128|0.0" +eval_llada_base_configs["bbh"]="3|1024|1024|32|1234|128|0.0" +eval_llada_base_configs["minerva_math"]="4|1024|1024|32|1234|128|0.0" +eval_llada_base_configs["humaneval"]="0|1024|1024|32|1234|128|0.0" +eval_llada_base_configs["mbpp"]="3|1024|1024|32|1234|128|0.0" + +# ---------- Base Likelihood ---------- +eval_llada_base_configs["gpqa_main_n_shot"]="5|1024|1024|1024|1234|128|0.5" +eval_llada_base_configs["truthfulqa_mc2"]="0|1024|1024|1024|1234|128|2.0" +eval_llada_base_configs["arc_challenge"]="0|1024|1024|1024|1234|128|0.5" +eval_llada_base_configs["hellaswag"]="0|1024|1024|1024|1234|128|0.5" +eval_llada_base_configs["winogrande"]="5|1024|1024|1024|1234|128|0.0" +eval_llada_base_configs["piqa"]="0|1024|1024|1024|1234|128|0.5" +eval_llada_base_configs["mmlu"]="5|1024|1024|1024|1234|1|0.0" +eval_llada_base_configs["cmmlu"]="5|1024|1024|1024|1234|1|0.0" +eval_llada_base_configs["ceval-valid"]="5|1024|1024|1024|1234|1|0.0" + +# ---------- Instruct Generation ---------- +eval_llada_instruct_configs["gsm8k_cot"]="8|1024|1024|32|1234|1|0.0" +eval_llada_instruct_configs["bbh"]="3|1024|1024|32|1234|1|0.0" +eval_llada_instruct_configs["minerva_math"]="4|1024|1024|32|1234|1|0.0" +eval_llada_instruct_configs["humaneval_instruct"]="0|1024|1024|32|1234|1|0.0" +eval_llada_instruct_configs["mbpp_llada_instruct"]="3|1024|1024|32|1234|1|0.0" + +eval_llada_instruct_configs["mmlu_generative"]="0|3|3|3|1234|1|0.0" +eval_llada_instruct_configs["mmlu_pro"]="0|256|256|256|1234|1|0.0" +eval_llada_instruct_configs["hellaswag_gen"]="0|3|3|3|1234|1|0.0" +eval_llada_instruct_configs["arc_challenge_chat"]="0|5|5|5|1234|1|0.0" +eval_llada_instruct_configs["gpqa_n_shot_gen"]="5|32|32|32|1234|1|0.0" + +# ============================================================ +# ==================== DREAM CONFIGS ======================== +# ============================================================ +# eval_dream_configs["<task>"]="num_fewshot|max_new_tokens|steps|temperature|top_p|seed|mc_num" +# ============================================================ + +# ---------- Base Generation ---------- +eval_dream_base_configs["humaneval_dream"]="0|512|512|0.2|0.95|1234|1" +eval_dream_base_configs["gsm8k_cot"]="8|256|256|0.0|0.95|1234|1" +eval_dream_base_configs["mbpp"]="3|512|512|0.2|0.95|1234|1" +eval_dream_base_configs["minerva_math"]="4|512|512|0.0|0.95|1234|1" +eval_dream_base_configs["bbh"]="3|512|512|0.0|0.95|1234|1" + +# ---------- Base Likelihood ---------- +eval_dream_base_configs["mmlu"]="5|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["arc_easy"]="0|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["arc_challenge"]="0|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["hellaswag"]="0|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["piqa"]="0|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["gpqa_main_n_shot"]="5|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["winogrande"]="5|512|512|0.0|0.95|1234|128" +eval_dream_base_configs["race"]="0|512|512|0.0|0.95|1234|128" + +# ---------- Instruct Generation ---------- +eval_dream_instruct_configs["mmlu_generative"]="4|128|128|0.1|0.9|1234|1"
+eval_dream_instruct_configs["mmlu_generative_dream"]="4|128|128|0.1|0.9|1234|1" +eval_dream_instruct_configs["mmlu_pro"]="4|128|128|0.1|0.9|1234|1" +eval_dream_instruct_configs["gsm8k_cot"]="0|256|256|0.1|0.9|1234|1" +eval_dream_instruct_configs["minerva_math"]="0|512|512|0.1|0.9|1234|1" +eval_dream_instruct_configs["gpqa_main_n_shot"]="5|128|128|0.0|1.0|1234|1" +eval_dream_instruct_configs["humaneval_instruct"]="0|768|768|0.1|0.9|1234|1" +eval_dream_instruct_configs["mbpp_instruct"]="0|1024|1024|0.1|0.9|1234|1" +eval_dream_instruct_configs["mbpp_instruct_dream"]="0|1024|1024|0.1|0.9|1234|1" +eval_dream_instruct_configs["ifeval"]="0|1280|1280|0.1|0.9|1234|1" + +# ============================================================ +# ==================== BERT CONFIGS ========================= +# ============================================================ +# eval_bert_configs["<task>"]="num_fewshot|max_new_tokens|steps|block_length|seed|mc_num" +# ============================================================ + +eval_bert_configs["mmlu"]="5|512|512|32|1234|128" +eval_bert_configs["ceval-valid"]="5|1024|1024|32|1234|128" +eval_bert_configs["cmmlu"]="5|1024|1024|32|1234|128" +eval_bert_configs["hellaswag"]="0|1024|1024|1024|1234|128" +eval_bert_configs["winogrande"]="0|128|128|128|1234|128" + +eval_bert_configs["gsm8k_bert"]="8|256|256|32|1234|128" +eval_bert_configs["minerva_math"]="4|256|256|32|1234|128" +eval_bert_configs["humaneval"]="0|256|256|32|1234|128" +eval_bert_configs["bbh"]="3|256|256|32|1234|128" + + +eval_bert_configs["hellaswag_gen"]="0|128|128|128|1234|1" +eval_bert_configs["mmlu_generative"]="0|128|128|128|1234|1" +eval_bert_configs["mmlu_pro"]="0|256|256|256|1234|1" +eval_bert_configs["arc_challenge_chat"]="0|128|128|128|1234|1" + +# ============================================================ +# ====================== END CONFIGS ======================== +# ============================================================ + + +# ===== Derived variables ===== +NUM_NODES=${SLURM_NNODES} +GPUS_PER_NODE=$(echo "${SLURM_JOB_GPUS}" | tr ',' '\n' | wc -l) +WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) +MASTER_PORT=$((20000 + SLURM_JOB_ID % 10000)) +NODELIST=($(scontrol show hostnames "${SLURM_JOB_NODELIST}")) +MASTER_ADDR=${NODELIST[0]} +TRAIN_NODES=("${NODELIST[@]}") + +echo "============================" +echo "JOB NAME: ${SLURM_JOB_NAME}" +echo "JOB ID: ${SLURM_JOB_ID}" +echo "NUM_NODES: ${NUM_NODES}" +echo "WORLD_SIZE: ${WORLD_SIZE}" +echo "MASTER: ${MASTER_ADDR}:${MASTER_PORT}" +echo "============================" + +# ===== Environment ===== +export PYTHONBREAKPOINT=0 +export NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=warn +export TORCH_DISTRIBUTED_DEBUG=DETAIL +export PYTHONPATH=.:$PYTHONPATH +export HF_ALLOW_CODE_EVAL=1 +export HF_DATASETS_TRUST_REMOTE_CODE=1 # For cmmlu dataset +export MASTER_ADDR MASTER_PORT WORLD_SIZE + + +MODEL_CLASS=${1,,} # "llada", "dream", or "bert" +TASK=${2:-"gsm8k"} # dataset name +MODEL_NAME=${3} # model path or name (required) +INSTRUCT=${4:-"False"} # whether to evaluate instruct model +BATCH_SIZE=${5:-"1"} # batch size +USE_LOG=${6:-"False"} # optional: enable logging +LIMIT=${7:-"None"} # optional: limit number of test samples (default None) + + +if [[ -z "${MODEL_NAME}" ]]; then + echo "❌ Missing model name/path argument!"
+ echo "Usage: sbatch eval_model.sh [instruct] [batch_size]" + exit 1 +fi + +if [[ "${MODEL_NAME}" == /* ]]; then + MODEL_PATH="${MODEL_NAME}" +else + MODEL_PATH="${BASE_MODELS_DIR}/${MODEL_NAME}" +fi + +case "${MODEL_CLASS}" in + llada) + if [[ "${INSTRUCT,,}" == "true" ]]; then + CONFIG="${eval_llada_instruct_configs[$TASK]}" + CONFIG_SET="instruct" + else + CONFIG="${eval_llada_base_configs[$TASK]}" + CONFIG_SET="base" + fi + + if [[ -z "${CONFIG}" ]]; then + echo "❌ Unknown task '${TASK}' for LLaDA (${CONFIG_SET} mode)." + echo "Available tasks (base): ${!eval_llada_base_configs[@]}" + echo "Available tasks (instruct): ${!eval_llada_instruct_configs[@]}" + exit 1 + fi + + IFS="|" read -r NUM_FEWSHOT MAX_NEW_TOKENS STEPS BLOCK_LENGTH SEED MC_NUM CFG <<< "${CONFIG}" + + MODEL_TYPE="llada" + SCRIPT_PATH="dllm/pipelines/llada/eval.py" + MODEL_ARGS="pretrained=${MODEL_PATH},is_check_greedy=False,mc_num=${MC_NUM},max_new_tokens=${MAX_NEW_TOKENS},steps=${STEPS},block_length=${BLOCK_LENGTH},cfg=${CFG}" + ;; + + dream) + if [[ "${INSTRUCT,,}" == "true" ]]; then + CONFIG="${eval_dream_instruct_configs[$TASK]}" + CONFIG_SET="instruct" + else + CONFIG="${eval_dream_base_configs[$TASK]}" + CONFIG_SET="base" + fi + + if [[ -z "${CONFIG}" ]]; then + echo "❌ Unknown task '${TASK}' for Dream (${CONFIG_SET} mode)." + echo "Available tasks (base): ${!eval_dream_base_configs[@]}" + echo "Available tasks (instruct): ${!eval_dream_instruct_configs[@]}" + exit 1 + fi + + IFS="|" read -r NUM_FEWSHOT MAX_NEW_TOKENS STEPS TEMPERATURE TOP_P SEED MC_NUM <<< "${CONFIG}" + + MODEL_TYPE="dream" + SCRIPT_PATH="dllm/pipelines/dream/eval.py" + MODEL_ARGS="pretrained=${MODEL_PATH},mc_num=${MC_NUM},max_new_tokens=${MAX_NEW_TOKENS},steps=${STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},add_bos_token=true,escape_until=true" + ;; + + bert) + CONFIG="${eval_bert_configs[$TASK]}" + if [[ -z "${CONFIG}" ]]; then + echo "❌ Unknown task '${TASK}' for BERT." + echo "Available tasks: ${!eval_bert_configs[@]}" + exit 1 + fi + + IFS="|" read -r NUM_FEWSHOT MAX_NEW_TOKENS STEPS BLOCK_LENGTH SEED MC_NUM <<< "${CONFIG}" + + MODEL_TYPE="bert" + SCRIPT_PATH="dllm/pipelines/bert/eval.py" + MODEL_ARGS="pretrained=${MODEL_PATH},is_check_greedy=False,mc_num=${MC_NUM},max_new_tokens=${MAX_NEW_TOKENS},steps=${STEPS},block_length=${BLOCK_LENGTH}" + ;; + + *) + echo "❌ Invalid model_class '${MODEL_CLASS}'. Must be 'llada' or 'dream' or 'bert'." 
+ exit 1 + ;; +esac + + +[[ "${INSTRUCT}" == "True" ]] && APPLY_CHAT_TEMPLATE_ARG="--apply_chat_template True" || APPLY_CHAT_TEMPLATE_ARG="" +[[ "${LIMIT}" == "None" ]] && LIMIT_ARG="" || LIMIT_ARG="--limit ${LIMIT}" +[[ "${USE_LOG}" == "True" ]] && \ + LOG_ARG="--log_samples --output_path ./logs/${MODEL_CLASS}_${TASK}_${SLURM_JOB_ID}_samples.json" \ + || LOG_ARG="--output_path ./logs/${MODEL_CLASS}_${TASK}_${SLURM_JOB_ID}_samples.json" + +echo -e "\nLaunching ${MODEL_CLASS} on ${TASK} using ${MODEL_PATH}" +echo "============================" +echo "Few-shot: ${NUM_FEWSHOT}" +echo "Seed: ${SEED}" +echo "Batch size: ${BATCH_SIZE}" +echo "Use chat template: ${USE_CHAT_TEMPLATE}" +echo "============================" + +RUN_CMD="accelerate launch \ + --num_processes ${WORLD_SIZE} \ + --num_machines ${NUM_NODES} \ + --main_process_ip ${MASTER_ADDR} \ + --main_process_port ${MASTER_PORT} \ + --machine_rank ${SLURM_PROCID} \ + ${SCRIPT_PATH} \ + --num_fewshot ${NUM_FEWSHOT} \ + --batch_size ${BATCH_SIZE} \ + --model ${MODEL_TYPE} \ + --model_args \"${MODEL_ARGS}\" \ + --tasks ${TASK} \ + --seed ${SEED} \ + ${LOG_ARG} \ + --confirm_run_unsafe_code \ + ${LIMIT_ARG} \ + ${APPLY_CHAT_TEMPLATE_ARG}" + +if [[ "${NUM_NODES}" -eq 1 ]]; then + echo "Single-node execution" + eval ${RUN_CMD} +else + echo "Multi-node execution" + srun --nodes="${NUM_NODES}" --ntasks="${NUM_NODES}" --nodelist="${SLURM_JOB_NODELIST}" ${RUN_CMD} +fi diff --git a/dllm/scripts/tests/test_attention_mask.py b/dllm/scripts/tests/test_attention_mask.py new file mode 100644 index 0000000..725748f --- /dev/null +++ b/dllm/scripts/tests/test_attention_mask.py @@ -0,0 +1,144 @@ +""" +LLaDA / MoE / Dream / RND attention mask invariance tests (compact version) +""" + +import gc + +import torch +import transformers +import dllm +import pytest + +ERROR_THRESHOLD = 1e-3 + + +def _cuda_cleanup(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Reclaim interprocess memory blocks (useful after large model del) + try: + torch.cuda.ipc_collect() + except Exception: + # Not all PyTorch builds expose ipc_collect on all platforms + pass + + +def _forward_variants(model): + """ + Run the 5 padding/mask variants and return tensors sliced to the 'real' tokens [1,2,3,4]. 
+ Returns dict: {'A','B','C','D','E'} each [1, 4, H] + """ + device = model.device + + # A: no padding + a_ids = torch.tensor([[1, 2, 3, 4]], device=device) + a_mask = torch.tensor([[1, 1, 1, 1]], device=device) + + # B: left-pad a 0 + b_ids = torch.tensor([[0, 1, 2, 3, 4]], device=device) + b_mask = torch.tensor([[0, 1, 1, 1, 1]], device=device) + + # C: right-pad a 0 + c_ids = torch.tensor([[1, 2, 3, 4, 0]], device=device) + c_mask = torch.tensor([[1, 1, 1, 1, 0]], device=device) + + # D: same as A but attention_mask=None + d_ids = torch.tensor([[1, 2, 3, 4]], device=device) + d_mask = None + + # E: same as A but omit attention_mask entirely + e_ids = torch.tensor([[1, 2, 3, 4]], device=device) + + with torch.no_grad(): + out_A = model(input_ids=a_ids, attention_mask=a_mask).logits # [1,4,H] + out_B = model(input_ids=b_ids, attention_mask=b_mask).logits[:, 1:] # [1,4,H] + out_C = model(input_ids=c_ids, attention_mask=c_mask).logits[:, :-1] # [1,4,H] + out_D = model(input_ids=d_ids, attention_mask=d_mask).logits # [1,4,H] + out_E = model(input_ids=e_ids).logits # [1,4,H] + + return {"A": out_A, "B": out_B, "C": out_C, "D": out_D, "E": out_E} + + +def _assert_invariance(outs: dict, tag: str): + ref = outs["A"] + for k in ("B", "C", "D", "E"): + assert torch.allclose( + ref, outs[k], atol=ERROR_THRESHOLD, rtol=ERROR_THRESHOLD + ), f"[{tag}] Mismatch A vs {k}" + + +@pytest.mark.parametrize( + "repo, attn_impl, human_name", + [ + ("GSAI-ML/LLaDA-8B-Base", None, "LLaDA Base"), + ("inclusionAI/LLaDA-MoE-7B-A1B-Base", None, "LLaDA MoE"), + ("Dream-org/Dream-v0-Base-7B", None, "Dream Base"), + ("radicalnumerics/RND1-Base-0910", None, "RND Base (native)"), + ("radicalnumerics/RND1-Base-0910", "sdpa", "RND Base (SDPA)"), + ], +) +def test_attention_mask_invariance(repo, attn_impl, human_name): + """ + For each model/backend: + 1) Check padding/mask invariance across A..E on the 'real' tokens. + 2) Print a ✅ message for debug visibility (pytest still enforces assertions). + """ + model_path = dllm.utils.resolve_with_base_env(repo, "BASE_MODELS_DIR") + + if attn_impl is None: + model = transformers.AutoModel.from_pretrained( + model_path, dtype=torch.float32, device_map="auto" + ).eval() + else: + config = transformers.AutoConfig.from_pretrained( + model_path, attn_implementation=attn_impl + ) + model = transformers.AutoModel.from_pretrained( + model_path, config=config, dtype=torch.float32, device_map="auto" + ).eval() + + outs = _forward_variants(model) + _assert_invariance(outs, human_name) + + print(f"✅ {human_name} attention mask invariance passed within {ERROR_THRESHOLD}.") + del model + gc.collect() + _cuda_cleanup() + + +def test_rnd_native_vs_sdpa_equivalence(): + """ + Verify RND (native attention) and RND (SDPA) produce equivalent logits on the + same real tokens across A..E variants. 
+ """ + repo = "radicalnumerics/RND1-Base-0910" + model_path = dllm.utils.resolve_with_base_env(repo, "BASE_MODELS_DIR") + + # native + model_native = transformers.AutoModel.from_pretrained( + model_path, dtype=torch.float32, device_map="auto" + ).eval() + + # sdpa + config_sdpa = transformers.AutoConfig.from_pretrained( + model_path, attn_implementation="sdpa" + ) + model_sdpa = transformers.AutoModel.from_pretrained( + model_path, config=config_sdpa, dtype=torch.float32, device_map="auto" + ).eval() + + outs_native = _forward_variants(model_native) # expects helper from your file + outs_sdpa = _forward_variants(model_sdpa) + + for k in ("A", "B", "C", "D", "E"): + assert torch.allclose( + outs_native[k], outs_sdpa[k], atol=ERROR_THRESHOLD, rtol=ERROR_THRESHOLD + ), f"[RND cross-backend] native vs SDPA mismatch on {k}" + + print(f"✅ RND native vs SDPA equivalence passed within {ERROR_THRESHOLD}.") + # Explicitly drop model references + del model_native + del model_sdpa + # Collect Python garbage and release CUDA caches + gc.collect() + _cuda_cleanup() diff --git a/dllm/scripts/tests/test_dream_generation.py b/dllm/scripts/tests/test_dream_generation.py new file mode 100644 index 0000000..e69de29 diff --git a/dllm/scripts/train.slurm.sh b/dllm/scripts/train.slurm.sh new file mode 100644 index 0000000..9e51a26 --- /dev/null +++ b/dllm/scripts/train.slurm.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +#SBATCH --job-name=dllm +#SBATCH --nodes=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=24 +#SBATCH --ntasks-per-node=1 +#SBATCH --partition=mllm_safety +#SBATCH --quotatype=spot +#SBATCH --output=./logs/%x-%j.out +#SBATCH --err=./logs/%x-%j.err +#SBATCH --requeue +#SBATCH --time=3-00:00:00 + +# ===== Cluster variables ===== +NUM_NODES=${SLURM_NNODES} +GPUS_PER_NODE=$(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | wc -l) +WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) +NODELIST=($(scontrol show hostnames "${SLURM_JOB_NODELIST}")) +MASTER_ADDR=${NODELIST[0]} +MASTER_PORT=$((20000 + SLURM_JOB_ID % 10000)) +TRAIN_NODES=("${NODELIST[@]}") + +echo "===== System Variables =====" +{ + echo "NUM_NODES=$NUM_NODES" + echo "GPUS_PER_NODE=$GPUS_PER_NODE" + echo "WORLD_SIZE=$WORLD_SIZE" + echo "MASTER_ADDR=$MASTER_ADDR" + echo "MASTER_PORT=$MASTER_PORT" +} | column -t -s= + +echo "Nodes allocated:" +for node in "${TRAIN_NODES[@]}"; do + echo " - $node" +done +echo "============================" + +# ===== Environment ===== +export NCCL_ASYNC_ERROR_HANDLING=1 +export PYTHONPATH=. 
+
+# ===== Default options =====
+accelerate_config="zero2"
+script_path="scripts/examples/llada_sft.py"
+
+# ===== Parse arguments =====
+# Stop parsing known options as soon as we hit an unknown one
+FORWARD_ARGS=()
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --accelerate_config)
+            accelerate_config="$2"; shift 2 ;;
+        --script_path)
+            script_path="$2"; shift 2 ;;
+        *)
+            FORWARD_ARGS=("$@"); break ;;  # everything else goes to the training script
+    esac
+done
+
+echo "===== Script Variables ====="
+echo "--accelerate_config ${accelerate_config}"
+echo "--script_path ${script_path}"
+echo "--forwarded script args:"
+printf '%s\n' "${FORWARD_ARGS[@]}" | xargs -n 2
+echo "============================"
+
+# ===== Launch =====
+srun --nodes="${NUM_NODES}" --ntasks="${NUM_NODES}" --nodelist="${SLURM_JOB_NODELIST}" \
+    accelerate launch \
+    --config_file "scripts/accelerate_configs/${accelerate_config}.yaml" \
+    --num_machines "${NUM_NODES}" \
+    --num_processes "${WORLD_SIZE}" \
+    --main_process_ip "${MASTER_ADDR}" \
+    --main_process_port "${MASTER_PORT}" \
+    --machine_rank "${SLURM_PROCID}" \
+    --rdzv_backend c10d \
+    "${script_path}" "${FORWARD_ARGS[@]}"
diff --git a/generate-batch.py b/generate-batch.py
index e533482..0743401 100644
--- a/generate-batch.py
+++ b/generate-batch.py
@@ -34,8 +34,8 @@ def get_argument_parser():
         "-attr_list",
         type=str,
         # default="beat,duration,,instrument,tempo",
-        default="pitch",
-        # default='bar,position,velocity,duration,program,tempo,timesig',
+        # default="pitch",
+        default='bar,position,velocity,duration,program,tempo,timesig',
         help="attribute list for attribute-controlled generation",
     )
     parser.add_argument(
@@ -88,13 +88,13 @@ def get_argument_parser():
     parser.add_argument(
         "-num_processes",
         type=int,
-        default=4,
+        default=8,
         help="number of processes to use",
     )
     parser.add_argument(
         "-gpu_ids",
         type=str,
-        default="0,1,2,3,5",
+        default="0,1,2,3,4,5,6,7",
         help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
     )
     parser.add_argument(
diff --git a/len_tunes/Melody/len_oct8.png b/len_tunes/Melody/len_oct8.png
index f83eb5c..374d509 100644
Binary files a/len_tunes/Melody/len_oct8.png and b/len_tunes/Melody/len_oct8.png differ
diff --git a/midi_sim.py b/midi_sim.py
index 7ae08c2..1f4593c 100644
--- a/midi_sim.py
+++ b/midi_sim.py
@@ -99,9 +99,12 @@ def compare_pair(file_a: str, file_b: str):
 def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
     files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
-    files_a = files_a[:100]  # only compare the first 100 files to save time
+    # remove files ending with _prompt.mid
+    files_a = [f for f in files_a if not f.endswith("_prompt.mid")]
+    files_a = files_a
     files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
+
     results = []
     pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
     with ProcessPoolExecutor(max_workers=max_workers) as executor:
@@ -110,6 +113,8 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
             pbar.update(1)
             try:
                 results.append(fut.result())
+                if results[-1][2] == 0:
+                    print(f"Exact match found: {results[-1][0]} and {results[-1][1]}")
             except Exception as e:
                 print(fut.result())
                 print(f"Error comparing pair: {e}")
@@ -129,6 +134,6 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
 if __name__ == "__main__":
-    dir_a = "wandb/run-20251027_161354-f9j1mwp2/uncond_min_p_t0.05_temp1.25_epochch8"
+    dir_a = 
"wandb/run-20251124_104410-bjdyzt85ar_aux_melody/uncond_min_p_t0.2_temp1.25" dir_b = "dataset/Melody" - batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6) \ No newline at end of file + batch_compare(dir_a, dir_b, out_csv="midi_similarity_withbase_p0.6.csv", max_workers=24) \ No newline at end of file diff --git a/octuple_token_analysis_report.json b/octuple_token_analysis_report.json new file mode 100644 index 0000000..3216ee3 --- /dev/null +++ b/octuple_token_analysis_report.json @@ -0,0 +1,1421 @@ +{ + "summary": { + "total_tokens": 13138687744, + "num_columns": 8 + }, + "columns": { + "pitch": { + "total_tokens": 1642335968, + "unique_tokens": 152, + "token_counts": { + "1": 1565807, + "12": 2485816, + "14": 4642588, + "15": 2686135, + "16": 5510338, + "17": 5355563, + "18": 5007560, + "19": 11290973, + "20": 5073519, + "21": 13786057, + "22": 8807533, + "23": 12696765, + "24": 16371540, + "26": 23636568, + "27": 11493596, + "28": 23490048, + "29": 19017743, + "30": 17168854, + "31": 31916911, + "32": 13531850, + "33": 35363235, + "34": 19937689, + "35": 30067786, + "36": 33081225, + "38": 46745881, + "39": 21960177, + "40": 45106459, + "41": 33902224, + "42": 33982698, + "43": 55427133, + "44": 25052307, + "45": 62706678, + "46": 33327561, + "47": 53146812, + "48": 52065620, + "50": 67083713, + "51": 27134350, + "52": 60706253, + "53": 35952930, + "54": 40358128, + "55": 51312761, + "56": 22297017, + "57": 50097602, + "58": 20941091, + "59": 35424543, + "60": 27124835, + "62": 29041960, + "63": 10877943, + "64": 20530992, + "65": 11040402, + "66": 10578042, + "67": 11799034, + "69": 8744719, + "71": 5215768, + "72": 4334145, + "74": 3877881, + "76": 2598522, + "78": 1476047, + "81": 1379059, + "84": 802045, + "105": 1415919, + "106": 4381597, + "123": 2683622, + "49": 32901147, + "61": 17847569, + "68": 5028713, + "70": 4258518, + "73": 2406336, + "75": 1798894, + "77": 1881646, + "79": 1923634, + "83": 860356, + "25": 9091273, + "37": 19291681, + "114": 2663131, + "116": 5183840, + "119": 1105289, + "138": 1486373, + "101": 13123905, + "102": 15755386, + "103": 4568700, + "104": 24051465, + "112": 3903395, + "115": 2560809, + "120": 4264838, + "107": 2496966, + "109": 1275387, + "111": 1898760, + "82": 855807, + "141": 1062060, + "6": 256272, + "13": 1752286, + "80": 851954, + "86": 597366, + "126": 3813453, + "108": 23122121, + "113": 1254951, + "121": 641123, + "10": 781496, + "85": 425229, + "87": 218526, + "148": 774434, + "127": 1396884, + "4": 189108, + "5": 214984, + "7": 593095, + "8": 384595, + "9": 962880, + "11": 2023710, + "147": 1043169, + "118": 1157643, + "135": 2523758, + "142": 1630489, + "143": 851479, + "88": 268247, + "110": 3738941, + "117": 5925473, + "89": 182520, + "91": 125814, + "122": 812893, + "125": 1455498, + "90": 128227, + "95": 80800, + "139": 584476, + "96": 93397, + "130": 1731163, + "145": 152029, + "150": 156082, + "128": 1503679, + "129": 1857083, + "131": 310145, + "132": 161855, + "133": 444346, + "134": 460413, + "136": 2057129, + "140": 253111, + "151": 364625, + "149": 332127, + "92": 34748, + "152": 65470, + "137": 176674, + "124": 63879, + "144": 55077, + "146": 388347, + "94": 116801, + "93": 329142, + "153": 61331, + "154": 24740, + "99": 70053, + "98": 34663, + "97": 72838, + "100": 35080 + }, + "top_20": { + "50": 67083713, + "45": 62706678, + "52": 60706253, + "43": 55427133, + "47": 53146812, + "48": 52065620, + "55": 51312761, + "57": 50097602, + "38": 46745881, + "40": 45106459, + "54": 40358128, + "53": 
35952930, + "59": 35424543, + "33": 35363235, + "42": 33982698, + "41": 33902224, + "46": 33327561, + "36": 33081225, + "49": 32901147, + "31": 31916911 + }, + "bottom_20": { + "89": 182520, + "137": 176674, + "132": 161855, + "150": 156082, + "145": 152029, + "90": 128227, + "91": 125814, + "94": 116801, + "96": 93397, + "95": 80800, + "97": 72838, + "99": 70053, + "152": 65470, + "124": 63879, + "153": 61331, + "144": 55077, + "100": 35080, + "92": 34748, + "98": 34663, + "154": 24740 + } + }, + "position": { + "total_tokens": 1642335968, + "unique_tokens": 97, + "token_counts": { + "0": 1565807, + "4": 319618129, + "5": 6329946, + "6": 21992580, + "7": 6550897, + "8": 101716803, + "9": 10645784, + "10": 32622087, + "11": 4079939, + "12": 189425867, + "13": 4759473, + "14": 22994657, + "15": 7371652, + "16": 126741165, + "17": 11170620, + "18": 29551363, + "19": 4049653, + "20": 210069958, + "21": 4739571, + "22": 20404843, + "23": 5582797, + "24": 94202245, + "25": 10141857, + "26": 24841033, + "27": 3500048, + "28": 161135286, + "29": 3807011, + "30": 18121821, + "31": 6356785, + "32": 90779192, + "33": 8971306, + "34": 20609995, + "35": 3061988, + "36": 14308368, + "37": 169206, + "38": 521846, + "39": 224872, + "40": 3405531, + "41": 271635, + "42": 520585, + "43": 143819, + "44": 15095654, + "45": 103941, + "46": 373349, + "48": 2310470, + "49": 159627, + "50": 308131, + "47": 173340, + "52": 3720342, + "60": 1444865, + "68": 2293232, + "76": 3025606, + "84": 1867431, + "92": 2302442, + "51": 57329, + "56": 181817, + "72": 205118, + "96": 212692, + "64": 262463, + "80": 178585, + "54": 39425, + "58": 35623, + "62": 37887, + "66": 26784, + "70": 33582, + "74": 26813, + "78": 47694, + "82": 33505, + "86": 29669, + "88": 248958, + "90": 31553, + "94": 48620, + "98": 26234, + "53": 12147, + "59": 8373, + "61": 9445, + "63": 13897, + "65": 12397, + "69": 13830, + "71": 15735, + "75": 10098, + "77": 15166, + "79": 20522, + "81": 12917, + "83": 13645, + "85": 15437, + "87": 16162, + "91": 9702, + "93": 12597, + "95": 17921, + "57": 12162, + "73": 16146, + "89": 14388, + "97": 13515, + "99": 4324, + "55": 13774, + "67": 10867 + }, + "top_20": { + "4": 319618129, + "20": 210069958, + "12": 189425867, + "28": 161135286, + "16": 126741165, + "8": 101716803, + "24": 94202245, + "32": 90779192, + "10": 32622087, + "18": 29551363, + "26": 24841033, + "14": 22994657, + "6": 21992580, + "34": 20609995, + "22": 20404843, + "30": 18121821, + "44": 15095654, + "36": 14308368, + "17": 11170620, + "9": 10645784 + }, + "bottom_20": { + "71": 15735, + "85": 15437, + "77": 15166, + "89": 14388, + "63": 13897, + "69": 13830, + "55": 13774, + "83": 13645, + "97": 13515, + "81": 12917, + "93": 12597, + "65": 12397, + "57": 12162, + "53": 12147, + "67": 10867, + "75": 10098, + "91": 9702, + "61": 9445, + "59": 8373, + "99": 4324 + } + }, + "bar": { + "total_tokens": 1642335968, + "unique_tokens": 513, + "token_counts": { + "0": 1565807, + "4": 15323806, + "5": 17617738, + "6": 18893196, + "7": 18476765, + "9": 20134896, + "10": 20822505, + "11": 19790645, + "12": 20453874, + "13": 20446479, + "14": 21063949, + "15": 20412240, + "16": 20875994, + "17": 20777973, + "18": 21012358, + "19": 19894438, + "20": 19697630, + "21": 19362656, + "22": 19537676, + "23": 18986155, + "24": 19110405, + "25": 19050818, + "26": 19104068, + "27": 18268039, + "28": 18335709, + "29": 18105778, + "30": 18111411, + "31": 17601061, + "32": 17703054, + "33": 17526250, + "34": 17544896, + "35": 16710650, + "36": 16279538, + "37": 
15996939, + "38": 15964843, + "39": 15480569, + "40": 15369289, + "41": 15203162, + "42": 15127633, + "43": 14626784, + "44": 14468752, + "45": 14234280, + "46": 14204787, + "47": 13837793, + "48": 13754151, + "49": 13571015, + "50": 13514252, + "51": 13094105, + "52": 12919093, + "53": 12721729, + "54": 12614443, + "55": 12280475, + "56": 12228303, + "57": 12034709, + "58": 11958497, + "59": 11628726, + "60": 11462726, + "61": 11283469, + "62": 11178215, + "63": 10879468, + "64": 10808166, + "65": 10658452, + "66": 10528226, + "67": 10219224, + "68": 10044455, + "69": 9832377, + "70": 9697338, + "71": 9438013, + "72": 9359958, + "8": 19930196, + "73": 9202506, + "74": 9045536, + "75": 8861108, + "76": 8717191, + "77": 8547985, + "78": 8448585, + "79": 8222491, + "80": 8162992, + "81": 7991858, + "82": 7884306, + "83": 7582839, + "84": 7451908, + "85": 7258075, + "86": 7160715, + "87": 6958687, + "88": 6882533, + "89": 6752348, + "90": 6672595, + "91": 6500432, + "92": 6399913, + "93": 6273016, + "94": 6199414, + "95": 6046043, + "96": 5946712, + "97": 5851312, + "98": 5759038, + "99": 5619110, + "100": 5499385, + "101": 5392374, + "102": 5300130, + "103": 5171056, + "104": 5097248, + "105": 4999786, + "106": 4943844, + "107": 4820924, + "108": 4737737, + "109": 4623955, + "110": 4559372, + "111": 4440893, + "112": 4377173, + "113": 4276946, + "114": 4254798, + "115": 4137545, + "116": 4084600, + "117": 3978344, + "118": 3919894, + "119": 3827318, + "120": 3794540, + "121": 3696997, + "122": 3651955, + "123": 3568488, + "124": 3492862, + "125": 3414183, + "126": 3375323, + "127": 3307306, + "128": 3260111, + "129": 3199007, + "130": 3155738, + "131": 3075890, + "132": 3016922, + "133": 2945681, + "134": 2904782, + "135": 2856992, + "137": 2763199, + "138": 2728449, + "139": 2675231, + "140": 2640423, + "141": 2576873, + "142": 2536991, + "143": 2485374, + "144": 2455782, + "145": 2429811, + "146": 2404785, + "147": 2337015, + "148": 2271553, + "149": 2226155, + "150": 2218946, + "151": 2165900, + "152": 2122476, + "153": 2133605, + "154": 2107814, + "155": 2063505, + "156": 2035026, + "157": 1993644, + "158": 1975246, + "159": 1938389, + "160": 1924111, + "161": 1843236, + "162": 1814056, + "163": 1776581, + "164": 1745435, + "165": 1701830, + "166": 1671951, + "167": 1646641, + "168": 1623040, + "169": 1606575, + "170": 1589191, + "171": 1558402, + "172": 1542638, + "173": 1510861, + "174": 1488337, + "175": 1460665, + "176": 1439443, + "177": 1418019, + "178": 1402130, + "179": 1388475, + "180": 1369962, + "181": 1337214, + "182": 1319834, + "183": 1298105, + "184": 1275344, + "185": 1262340, + "186": 1246620, + "187": 1231544, + "188": 1225636, + "189": 1212845, + "190": 1212060, + "191": 1196025, + "192": 1183902, + "193": 1159144, + "136": 2816316, + "194": 1139812, + "195": 1118339, + "196": 1113002, + "197": 1093484, + "198": 1084804, + "199": 1061472, + "200": 1054456, + "201": 1040864, + "202": 1031487, + "203": 1015512, + "204": 1006144, + "205": 997545, + "206": 993152, + "207": 983572, + "208": 976632, + "209": 963697, + "210": 953747, + "211": 945141, + "212": 937761, + "213": 924343, + "214": 908414, + "215": 894453, + "216": 878582, + "217": 873995, + "218": 870652, + "219": 859898, + "220": 848690, + "221": 841836, + "222": 834782, + "223": 815018, + "224": 807684, + "225": 799987, + "226": 792840, + "227": 781252, + "228": 774756, + "229": 767929, + "230": 730542, + "231": 721880, + "232": 736313, + "233": 740394, + "234": 728468, + "235": 733359, + "236": 721853, + "237": 
704232, + "238": 704252, + "239": 698703, + "240": 693883, + "241": 674101, + "242": 682881, + "243": 674281, + "244": 667457, + "245": 663744, + "246": 647841, + "247": 648702, + "248": 639012, + "249": 637016, + "250": 628305, + "251": 624574, + "252": 589382, + "253": 577055, + "254": 592378, + "255": 592591, + "256": 579356, + "257": 589448, + "258": 586865, + "259": 566481, + "260": 570459, + "261": 564722, + "262": 558620, + "263": 541250, + "264": 549063, + "265": 545117, + "266": 539646, + "267": 532533, + "268": 519715, + "269": 527694, + "270": 518894, + "271": 512753, + "272": 509646, + "273": 501427, + "274": 466762, + "275": 461624, + "276": 455696, + "277": 449499, + "278": 449332, + "279": 446493, + "280": 445187, + "281": 442968, + "282": 445073, + "283": 435942, + "284": 434258, + "285": 429074, + "286": 424052, + "287": 419430, + "288": 420166, + "289": 414299, + "290": 402423, + "291": 397156, + "292": 396581, + "293": 394027, + "294": 391099, + "295": 385612, + "296": 381976, + "297": 381198, + "298": 380325, + "299": 378367, + "300": 376878, + "301": 374872, + "302": 366112, + "303": 361122, + "304": 358597, + "305": 361557, + "306": 358511, + "307": 355471, + "308": 350041, + "309": 348060, + "310": 349314, + "311": 342611, + "312": 338576, + "313": 340822, + "314": 337382, + "315": 330635, + "316": 327649, + "317": 331952, + "318": 329229, + "319": 328551, + "320": 324378, + "321": 322949, + "322": 325996, + "323": 320769, + "324": 322288, + "325": 315639, + "326": 309682, + "327": 308691, + "328": 305928, + "329": 305131, + "330": 301786, + "331": 300658, + "332": 302336, + "333": 302279, + "334": 303181, + "335": 295581, + "336": 292285, + "337": 293105, + "338": 288229, + "339": 287896, + "340": 285672, + "341": 285162, + "342": 283359, + "343": 280253, + "344": 277170, + "345": 277374, + "346": 274391, + "347": 267604, + "348": 266094, + "349": 262992, + "350": 261981, + "351": 256633, + "352": 256540, + "353": 255648, + "354": 255874, + "355": 251929, + "356": 257799, + "357": 254779, + "358": 256132, + "359": 248169, + "360": 252501, + "361": 253610, + "362": 249695, + "363": 248335, + "364": 249118, + "365": 245219, + "366": 242007, + "367": 233646, + "368": 234927, + "369": 234469, + "370": 230116, + "371": 225864, + "372": 225819, + "373": 224421, + "374": 222761, + "375": 218373, + "376": 214948, + "377": 215503, + "378": 210568, + "379": 206378, + "380": 207312, + "381": 208608, + "382": 207805, + "383": 205300, + "384": 205086, + "385": 204429, + "386": 202224, + "387": 200461, + "388": 198135, + "389": 198941, + "390": 197532, + "391": 200147, + "392": 197217, + "393": 198179, + "394": 198437, + "395": 199755, + "396": 192595, + "397": 194212, + "398": 188364, + "399": 188546, + "400": 183536, + "401": 183079, + "402": 181097, + "403": 180014, + "404": 178617, + "405": 178100, + "406": 176194, + "407": 176050, + "408": 176327, + "409": 175408, + "410": 172012, + "411": 169257, + "412": 167370, + "413": 168618, + "414": 170376, + "415": 169047, + "416": 168436, + "417": 170379, + "418": 166991, + "419": 163743, + "420": 162153, + "421": 162803, + "422": 159490, + "423": 158685, + "424": 157742, + "425": 156282, + "426": 153898, + "427": 155763, + "428": 153970, + "429": 155828, + "430": 152491, + "431": 152203, + "432": 149088, + "433": 149623, + "434": 145625, + "435": 144551, + "436": 142272, + "437": 144348, + "438": 143065, + "439": 141566, + "440": 140419, + "441": 136460, + "442": 136034, + "443": 135598, + "444": 135685, + "445": 135748, + "446": 
135775, + "447": 134907, + "448": 137522, + "449": 135414, + "450": 133798, + "451": 132647, + "452": 130311, + "453": 130279, + "454": 129329, + "455": 128841, + "456": 129603, + "457": 129010, + "458": 129024, + "459": 127104, + "460": 125324, + "461": 125282, + "462": 122738, + "463": 123250, + "464": 120890, + "465": 125019, + "466": 124461, + "467": 121412, + "468": 118898, + "469": 120947, + "470": 121169, + "471": 119435, + "472": 117471, + "473": 117713, + "474": 114912, + "475": 113900, + "476": 112790, + "477": 111842, + "478": 110916, + "479": 108206, + "480": 110872, + "481": 110845, + "482": 108084, + "483": 106857, + "484": 106939, + "485": 107448, + "486": 107026, + "487": 107089, + "488": 106862, + "489": 107147, + "490": 105644, + "491": 105305, + "492": 105905, + "493": 104218, + "494": 104186, + "495": 102637, + "496": 102957, + "497": 102978, + "498": 100815, + "499": 101363, + "500": 100561, + "501": 99161, + "502": 100998, + "503": 98269, + "504": 99121, + "505": 99135, + "506": 97698, + "507": 95445, + "508": 94625, + "509": 95657, + "510": 94061, + "511": 93300, + "512": 93659, + "513": 93949, + "514": 93234, + "515": 93534 + }, + "top_20": { + "14": 21063949, + "18": 21012358, + "16": 20875994, + "10": 20822505, + "17": 20777973, + "12": 20453874, + "13": 20446479, + "15": 20412240, + "9": 20134896, + "8": 19930196, + "19": 19894438, + "11": 19790645, + "20": 19697630, + "22": 19537676, + "21": 19362656, + "24": 19110405, + "26": 19104068, + "25": 19050818, + "23": 18986155, + "6": 18893196 + }, + "bottom_20": { + "496": 102957, + "495": 102637, + "499": 101363, + "502": 100998, + "498": 100815, + "500": 100561, + "501": 99161, + "505": 99135, + "504": 99121, + "503": 98269, + "506": 97698, + "509": 95657, + "507": 95445, + "508": 94625, + "510": 94061, + "513": 93949, + "512": 93659, + "515": 93534, + "511": 93300, + "514": 93234 + } + }, + "velocity": { + "total_tokens": 1642335968, + "unique_tokens": 33, + "token_counts": { + "0": 1565807, + "16": 106044879, + "19": 121520524, + "22": 17907084, + "23": 804640042, + "27": 240655729, + "31": 108584666, + "35": 66731277, + "12": 28070596, + "17": 4936957, + "18": 10781458, + "20": 8010499, + "24": 8397082, + "25": 13708838, + "26": 13413483, + "28": 19109874, + "29": 7454319, + "30": 5637544, + "34": 2423305, + "32": 11600968, + "7": 4642247, + "8": 1221063, + "10": 1446777, + "11": 2203636, + "13": 3500622, + "15": 3620571, + "21": 12580299, + "14": 2805850, + "33": 5792684, + "4": 1330633, + "5": 144791, + "9": 1219397, + "6": 632467 + }, + "top_20": { + "23": 804640042, + "27": 240655729, + "19": 121520524, + "31": 108584666, + "16": 106044879, + "35": 66731277, + "12": 28070596, + "28": 19109874, + "22": 17907084, + "25": 13708838, + "26": 13413483, + "21": 12580299, + "32": 11600968, + "18": 10781458, + "24": 8397082, + "20": 8010499, + "29": 7454319, + "33": 5792684, + "30": 5637544, + "17": 4936957 + }, + "bottom_20": { + "18": 10781458, + "24": 8397082, + "20": 8010499, + "29": 7454319, + "33": 5792684, + "30": 5637544, + "17": 4936957, + "7": 4642247, + "15": 3620571, + "13": 3500622, + "14": 2805850, + "34": 2423305, + "11": 2203636, + "0": 1565807, + "10": 1446777, + "4": 1330633, + "8": 1221063, + "9": 1219397, + "6": 632467, + "5": 144791 + } + }, + "duration": { + "total_tokens": 1642335968, + "unique_tokens": 65, + "token_counts": { + "0": 1565807, + "4": 119984912, + "5": 299694070, + "7": 538838927, + "8": 19136000, + "11": 349624100, + "14": 12318781, + "15": 29051651, + "18": 24332165, + "19": 
66064899, + "23": 3676616, + "27": 22562754, + "33": 8501228, + "35": 23195696, + "39": 1621335, + "67": 1782901, + "9": 32147354, + "38": 558300, + "26": 8296587, + "6": 52116530, + "17": 1026287, + "12": 886044, + "22": 1957457, + "30": 669089, + "13": 1723176, + "31": 1523183, + "37": 942258, + "43": 4633121, + "51": 2926756, + "45": 193355, + "47": 598926, + "41": 219388, + "42": 1618315, + "49": 124390, + "50": 909324, + "21": 594628, + "20": 248907, + "34": 741490, + "10": 1609766, + "32": 126822, + "55": 348310, + "16": 1151906, + "40": 83688, + "46": 188411, + "63": 70582, + "25": 280673, + "24": 247343, + "28": 108302, + "29": 176753, + "36": 445239, + "59": 170006, + "52": 28115, + "53": 173115, + "44": 73755, + "48": 40286, + "65": 267548, + "54": 11724, + "61": 33938, + "66": 18810, + "62": 3911, + "57": 54054, + "64": 4279, + "56": 2641, + "60": 2049, + "58": 7235 + }, + "top_20": { + "7": 538838927, + "11": 349624100, + "5": 299694070, + "4": 119984912, + "19": 66064899, + "6": 52116530, + "9": 32147354, + "15": 29051651, + "18": 24332165, + "35": 23195696, + "27": 22562754, + "8": 19136000, + "14": 12318781, + "33": 8501228, + "26": 8296587, + "43": 4633121, + "23": 3676616, + "51": 2926756, + "22": 1957457, + "67": 1782901 + }, + "bottom_20": { + "29": 176753, + "53": 173115, + "59": 170006, + "32": 126822, + "49": 124390, + "28": 108302, + "40": 83688, + "44": 73755, + "63": 70582, + "57": 54054, + "48": 40286, + "61": 33938, + "52": 28115, + "66": 18810, + "54": 11724, + "58": 7235, + "64": 4279, + "62": 3911, + "56": 2641, + "60": 2049 + } + }, + "program": { + "total_tokens": 1642335968, + "unique_tokens": 130, + "token_counts": { + "0": 1565807, + "4": 590202528, + "13": 7020453, + "15": 11568234, + "18": 1798627, + "51": 11245932, + "60": 51895506, + "61": 42179722, + "62": 31563208, + "64": 32513674, + "69": 32843971, + "70": 22681555, + "71": 14861448, + "72": 20317595, + "73": 2036256, + "74": 15757879, + "75": 53920981, + "76": 9170781, + "77": 53361941, + "132": 156060306, + "56": 53946714, + "29": 16869726, + "44": 58164622, + "23": 7895444, + "25": 4467168, + "43": 1770776, + "37": 14236216, + "52": 46809671, + "65": 1670471, + "131": 38496, + "28": 31972258, + "16": 19214717, + "36": 7272440, + "49": 6565376, + "57": 10228467, + "79": 1048044, + "46": 32141823, + "12": 1105576, + "17": 7309929, + "5": 7913663, + "8": 5425561, + "31": 13691387, + "38": 9419991, + "100": 675080, + "103": 495060, + "34": 5131422, + "68": 5536230, + "26": 2989169, + "47": 9812833, + "95": 790975, + "7": 1707791, + "78": 5426193, + "112": 277768, + "50": 8733491, + "19": 490269, + "45": 21634143, + "109": 1488631, + "10": 6436008, + "30": 3593273, + "53": 1551137, + "84": 3064410, + "114": 369590, + "6": 2257405, + "118": 1406534, + "99": 122909, + "63": 628074, + "14": 803067, + "58": 1198428, + "89": 894826, + "21": 543678, + "93": 413020, + "33": 4332158, + "110": 178985, + "42": 1516660, + "32": 1896817, + "66": 1031795, + "85": 2762339, + "123": 33251, + "98": 206211, + "54": 1320902, + "82": 421580, + "92": 677534, + "116": 564397, + "24": 1158967, + "40": 567304, + "102": 184337, + "48": 1363501, + "41": 323978, + "126": 73126, + "22": 917427, + "27": 439300, + "55": 550996, + "104": 453970, + "106": 147786, + "39": 1239473, + "9": 1178077, + "96": 189883, + "97": 206162, + "108": 215159, + "91": 829886, + "111": 1020805, + "83": 856538, + "121": 474742, + "11": 545499, + "20": 936159, + "94": 374895, + "59": 173628, + "80": 93518, + "90": 40354, + "105": 42011, + "113": 
359614, + "88": 365918, + "101": 76255, + "35": 203523, + "81": 112112, + "86": 310378, + "115": 54477, + "120": 218005, + "107": 217116, + "87": 127634, + "124": 11419, + "129": 10526, + "130": 12120, + "119": 93731, + "67": 212346, + "122": 79728, + "127": 12119, + "125": 57752, + "128": 10862, + "117": 33869 + }, + "top_20": { + "4": 590202528, + "132": 156060306, + "44": 58164622, + "56": 53946714, + "75": 53920981, + "77": 53361941, + "60": 51895506, + "52": 46809671, + "61": 42179722, + "69": 32843971, + "64": 32513674, + "46": 32141823, + "28": 31972258, + "62": 31563208, + "70": 22681555, + "45": 21634143, + "72": 20317595, + "16": 19214717, + "29": 16869726, + "74": 15757879 + }, + "bottom_20": { + "87": 127634, + "99": 122909, + "81": 112112, + "119": 93731, + "80": 93518, + "122": 79728, + "101": 76255, + "126": 73126, + "125": 57752, + "115": 54477, + "105": 42011, + "90": 40354, + "131": 38496, + "117": 33869, + "123": 33251, + "130": 12120, + "127": 12119, + "124": 11419, + "128": 10862, + "129": 10526 + } + }, + "tempo": { + "total_tokens": 1642335968, + "unique_tokens": 33, + "token_counts": { + "0": 1565807, + "7": 40255445, + "12": 55500781, + "14": 67397229, + "15": 50223556, + "16": 442514188, + "17": 75142988, + "9": 56688242, + "13": 88177737, + "4": 13060108, + "6": 19372161, + "8": 50803286, + "10": 91290939, + "22": 65178994, + "19": 70235288, + "18": 58517046, + "28": 25511675, + "11": 86042535, + "25": 50653417, + "5": 6210341, + "23": 37394478, + "35": 19365351, + "20": 55803768, + "21": 29373616, + "31": 9823593, + "24": 16882039, + "26": 21472940, + "32": 3784501, + "29": 9549287, + "34": 10705011, + "27": 7808256, + "30": 5037444, + "33": 993921 + }, + "top_20": { + "16": 442514188, + "10": 91290939, + "13": 88177737, + "11": 86042535, + "17": 75142988, + "19": 70235288, + "14": 67397229, + "22": 65178994, + "18": 58517046, + "9": 56688242, + "20": 55803768, + "12": 55500781, + "8": 50803286, + "25": 50653417, + "15": 50223556, + "7": 40255445, + "23": 37394478, + "21": 29373616, + "28": 25511675, + "26": 21472940 + }, + "bottom_20": { + "25": 50653417, + "15": 50223556, + "7": 40255445, + "23": 37394478, + "21": 29373616, + "28": 25511675, + "26": 21472940, + "6": 19372161, + "35": 19365351, + "24": 16882039, + "4": 13060108, + "34": 10705011, + "31": 9823593, + "29": 9549287, + "27": 7808256, + "5": 6210341, + "30": 5037444, + "32": 3784501, + "0": 1565807, + "33": 993921 + } + }, + "timesig": { + "total_tokens": 1642335968, + "unique_tokens": 10, + "token_counts": { + "0": 1565807, + "7": 15695389, + "12": 1218108996, + "4": 12280458, + "10": 92647775, + "11": 6676456, + "8": 16074543, + "9": 161413472, + "6": 84169807, + "5": 33703265 + }, + "top_20": { + "12": 1218108996, + "9": 161413472, + "10": 92647775, + "6": 84169807, + "5": 33703265, + "8": 16074543, + "7": 15695389, + "4": 12280458, + "11": 6676456, + "0": 1565807 + }, + "bottom_20": { + "12": 1218108996, + "9": 161413472, + "10": 92647775, + "6": 84169807, + "5": 33703265, + "8": 16074543, + "7": 15695389, + "4": 12280458, + "11": 6676456, + "0": 1565807 + } + } + } +} \ No newline at end of file diff --git a/octuple_token_analysis_report.txt b/octuple_token_analysis_report.txt new file mode 100644 index 0000000..e1e52a0 --- /dev/null +++ b/octuple_token_analysis_report.txt @@ -0,0 +1,467 @@ +================================================================================ +OCTUPLE分词结果统计分析报告 +================================================================================ + +总token数: 
13,138,687,744 +分析的列数: 8 + +-------------------------------------------------------------------------------- +列 0: pitch +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 152 + Token值范围: [1, 154] + 平均出现次数: 10804841.89 + + Top 20 最常见的token: + 1. Token 50: 67,083,713 次 ( 4.08%) + 2. Token 45: 62,706,678 次 ( 3.82%) + 3. Token 52: 60,706,253 次 ( 3.70%) + 4. Token 43: 55,427,133 次 ( 3.37%) + 5. Token 47: 53,146,812 次 ( 3.24%) + 6. Token 48: 52,065,620 次 ( 3.17%) + 7. Token 55: 51,312,761 次 ( 3.12%) + 8. Token 57: 50,097,602 次 ( 3.05%) + 9. Token 38: 46,745,881 次 ( 2.85%) + 10. Token 40: 45,106,459 次 ( 2.75%) + 11. Token 54: 40,358,128 次 ( 2.46%) + 12. Token 53: 35,952,930 次 ( 2.19%) + 13. Token 59: 35,424,543 次 ( 2.16%) + 14. Token 33: 35,363,235 次 ( 2.15%) + 15. Token 42: 33,982,698 次 ( 2.07%) + 16. Token 41: 33,902,224 次 ( 2.06%) + 17. Token 46: 33,327,561 次 ( 2.03%) + 18. Token 36: 33,081,225 次 ( 2.01%) + 19. Token 49: 32,901,147 次 ( 2.00%) + 20. Token 31: 31,916,911 次 ( 1.94%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 154: 24,740 次 ( 0.00%) + 2. Token 98: 34,663 次 ( 0.00%) + 3. Token 92: 34,748 次 ( 0.00%) + 4. Token 100: 35,080 次 ( 0.00%) + 5. Token 144: 55,077 次 ( 0.00%) + 6. Token 153: 61,331 次 ( 0.00%) + 7. Token 124: 63,879 次 ( 0.00%) + 8. Token 152: 65,470 次 ( 0.00%) + 9. Token 99: 70,053 次 ( 0.00%) + 10. Token 97: 72,838 次 ( 0.00%) + 11. Token 95: 80,800 次 ( 0.00%) + 12. Token 96: 93,397 次 ( 0.01%) + 13. Token 94: 116,801 次 ( 0.01%) + 14. Token 91: 125,814 次 ( 0.01%) + 15. Token 90: 128,227 次 ( 0.01%) + 16. Token 145: 152,029 次 ( 0.01%) + 17. Token 150: 156,082 次 ( 0.01%) + 18. Token 132: 161,855 次 ( 0.01%) + 19. Token 137: 176,674 次 ( 0.01%) + 20. Token 89: 182,520 次 ( 0.01%) + + 分布统计: + 最小出现次数: 24,740 + 最大出现次数: 67,083,713 + 中位数出现次数: 2,542,284 + 标准差: 15,554,192.13 + +-------------------------------------------------------------------------------- +列 1: position +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 97 + Token值范围: [0, 99] + 平均出现次数: 16931298.64 + + Top 20 最常见的token: + 1. Token 4: 319,618,129 次 (19.46%) + 2. Token 20: 210,069,958 次 (12.79%) + 3. Token 12: 189,425,867 次 (11.53%) + 4. Token 28: 161,135,286 次 ( 9.81%) + 5. Token 16: 126,741,165 次 ( 7.72%) + 6. Token 8: 101,716,803 次 ( 6.19%) + 7. Token 24: 94,202,245 次 ( 5.74%) + 8. Token 32: 90,779,192 次 ( 5.53%) + 9. Token 10: 32,622,087 次 ( 1.99%) + 10. Token 18: 29,551,363 次 ( 1.80%) + 11. Token 26: 24,841,033 次 ( 1.51%) + 12. Token 14: 22,994,657 次 ( 1.40%) + 13. Token 6: 21,992,580 次 ( 1.34%) + 14. Token 34: 20,609,995 次 ( 1.25%) + 15. Token 22: 20,404,843 次 ( 1.24%) + 16. Token 30: 18,121,821 次 ( 1.10%) + 17. Token 44: 15,095,654 次 ( 0.92%) + 18. Token 36: 14,308,368 次 ( 0.87%) + 19. Token 17: 11,170,620 次 ( 0.68%) + 20. Token 9: 10,645,784 次 ( 0.65%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 99: 4,324 次 ( 0.00%) + 2. Token 59: 8,373 次 ( 0.00%) + 3. Token 61: 9,445 次 ( 0.00%) + 4. Token 91: 9,702 次 ( 0.00%) + 5. Token 75: 10,098 次 ( 0.00%) + 6. Token 67: 10,867 次 ( 0.00%) + 7. Token 53: 12,147 次 ( 0.00%) + 8. Token 57: 12,162 次 ( 0.00%) + 9. Token 65: 12,397 次 ( 0.00%) + 10. Token 93: 12,597 次 ( 0.00%) + 11. Token 81: 12,917 次 ( 0.00%) + 12. Token 97: 13,515 次 ( 0.00%) + 13. Token 83: 13,645 次 ( 0.00%) + 14. Token 55: 13,774 次 ( 0.00%) + 15. Token 69: 13,830 次 ( 0.00%) + 16. Token 63: 13,897 次 ( 0.00%) + 17. Token 89: 14,388 次 ( 0.00%) + 18. Token 77: 15,166 次 ( 0.00%) + 19. 
Token 85: 15,437 次 ( 0.00%) + 20. Token 71: 15,735 次 ( 0.00%) + + 分布统计: + 最小出现次数: 4,324 + 最大出现次数: 319,618,129 + 中位数出现次数: 262,463 + 标准差: 48,675,567.69 + +-------------------------------------------------------------------------------- +列 2: bar +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 513 + Token值范围: [0, 515] + 平均出现次数: 3201434.64 + + Top 20 最常见的token: + 1. Token 14: 21,063,949 次 ( 1.28%) + 2. Token 18: 21,012,358 次 ( 1.28%) + 3. Token 16: 20,875,994 次 ( 1.27%) + 4. Token 10: 20,822,505 次 ( 1.27%) + 5. Token 17: 20,777,973 次 ( 1.27%) + 6. Token 12: 20,453,874 次 ( 1.25%) + 7. Token 13: 20,446,479 次 ( 1.24%) + 8. Token 15: 20,412,240 次 ( 1.24%) + 9. Token 9: 20,134,896 次 ( 1.23%) + 10. Token 8: 19,930,196 次 ( 1.21%) + 11. Token 19: 19,894,438 次 ( 1.21%) + 12. Token 11: 19,790,645 次 ( 1.21%) + 13. Token 20: 19,697,630 次 ( 1.20%) + 14. Token 22: 19,537,676 次 ( 1.19%) + 15. Token 21: 19,362,656 次 ( 1.18%) + 16. Token 24: 19,110,405 次 ( 1.16%) + 17. Token 26: 19,104,068 次 ( 1.16%) + 18. Token 25: 19,050,818 次 ( 1.16%) + 19. Token 23: 18,986,155 次 ( 1.16%) + 20. Token 6: 18,893,196 次 ( 1.15%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 514: 93,234 次 ( 0.01%) + 2. Token 511: 93,300 次 ( 0.01%) + 3. Token 515: 93,534 次 ( 0.01%) + 4. Token 512: 93,659 次 ( 0.01%) + 5. Token 513: 93,949 次 ( 0.01%) + 6. Token 510: 94,061 次 ( 0.01%) + 7. Token 508: 94,625 次 ( 0.01%) + 8. Token 507: 95,445 次 ( 0.01%) + 9. Token 509: 95,657 次 ( 0.01%) + 10. Token 506: 97,698 次 ( 0.01%) + 11. Token 503: 98,269 次 ( 0.01%) + 12. Token 504: 99,121 次 ( 0.01%) + 13. Token 505: 99,135 次 ( 0.01%) + 14. Token 501: 99,161 次 ( 0.01%) + 15. Token 500: 100,561 次 ( 0.01%) + 16. Token 498: 100,815 次 ( 0.01%) + 17. Token 502: 100,998 次 ( 0.01%) + 18. Token 499: 101,363 次 ( 0.01%) + 19. Token 495: 102,637 次 ( 0.01%) + 20. Token 496: 102,957 次 ( 0.01%) + + 分布统计: + 最小出现次数: 93,234 + 最大出现次数: 21,063,949 + 中位数出现次数: 570,459 + 标准差: 5,386,254.49 + +-------------------------------------------------------------------------------- +列 3: velocity +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 33 + Token值范围: [0, 35] + 平均出现次数: 49767756.61 + + Top 20 最常见的token: + 1. Token 23: 804,640,042 次 (48.99%) + 2. Token 27: 240,655,729 次 (14.65%) + 3. Token 19: 121,520,524 次 ( 7.40%) + 4. Token 31: 108,584,666 次 ( 6.61%) + 5. Token 16: 106,044,879 次 ( 6.46%) + 6. Token 35: 66,731,277 次 ( 4.06%) + 7. Token 12: 28,070,596 次 ( 1.71%) + 8. Token 28: 19,109,874 次 ( 1.16%) + 9. Token 22: 17,907,084 次 ( 1.09%) + 10. Token 25: 13,708,838 次 ( 0.83%) + 11. Token 26: 13,413,483 次 ( 0.82%) + 12. Token 21: 12,580,299 次 ( 0.77%) + 13. Token 32: 11,600,968 次 ( 0.71%) + 14. Token 18: 10,781,458 次 ( 0.66%) + 15. Token 24: 8,397,082 次 ( 0.51%) + 16. Token 20: 8,010,499 次 ( 0.49%) + 17. Token 29: 7,454,319 次 ( 0.45%) + 18. Token 33: 5,792,684 次 ( 0.35%) + 19. Token 30: 5,637,544 次 ( 0.34%) + 20. Token 17: 4,936,957 次 ( 0.30%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 5: 144,791 次 ( 0.01%) + 2. Token 6: 632,467 次 ( 0.04%) + 3. Token 9: 1,219,397 次 ( 0.07%) + 4. Token 8: 1,221,063 次 ( 0.07%) + 5. Token 4: 1,330,633 次 ( 0.08%) + 6. Token 10: 1,446,777 次 ( 0.09%) + 7. Token 0: 1,565,807 次 ( 0.10%) + 8. Token 11: 2,203,636 次 ( 0.13%) + 9. Token 34: 2,423,305 次 ( 0.15%) + 10. Token 14: 2,805,850 次 ( 0.17%) + 11. Token 13: 3,500,622 次 ( 0.21%) + 12. Token 15: 3,620,571 次 ( 0.22%) + 13. Token 7: 4,642,247 次 ( 0.28%) + 14. 
Token 17: 4,936,957 次 ( 0.30%) + 15. Token 30: 5,637,544 次 ( 0.34%) + 16. Token 33: 5,792,684 次 ( 0.35%) + 17. Token 29: 7,454,319 次 ( 0.45%) + 18. Token 20: 8,010,499 次 ( 0.49%) + 19. Token 24: 8,397,082 次 ( 0.51%) + 20. Token 18: 10,781,458 次 ( 0.66%) + + 分布统计: + 最小出现次数: 144,791 + 最大出现次数: 804,640,042 + 中位数出现次数: 7,454,319 + 标准差: 142,327,810.09 + +-------------------------------------------------------------------------------- +列 4: duration +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 65 + Token值范围: [0, 67] + 平均出现次数: 25266707.20 + + Top 20 最常见的token: + 1. Token 7: 538,838,927 次 (32.81%) + 2. Token 11: 349,624,100 次 (21.29%) + 3. Token 5: 299,694,070 次 (18.25%) + 4. Token 4: 119,984,912 次 ( 7.31%) + 5. Token 19: 66,064,899 次 ( 4.02%) + 6. Token 6: 52,116,530 次 ( 3.17%) + 7. Token 9: 32,147,354 次 ( 1.96%) + 8. Token 15: 29,051,651 次 ( 1.77%) + 9. Token 18: 24,332,165 次 ( 1.48%) + 10. Token 35: 23,195,696 次 ( 1.41%) + 11. Token 27: 22,562,754 次 ( 1.37%) + 12. Token 8: 19,136,000 次 ( 1.17%) + 13. Token 14: 12,318,781 次 ( 0.75%) + 14. Token 33: 8,501,228 次 ( 0.52%) + 15. Token 26: 8,296,587 次 ( 0.51%) + 16. Token 43: 4,633,121 次 ( 0.28%) + 17. Token 23: 3,676,616 次 ( 0.22%) + 18. Token 51: 2,926,756 次 ( 0.18%) + 19. Token 22: 1,957,457 次 ( 0.12%) + 20. Token 67: 1,782,901 次 ( 0.11%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 60: 2,049 次 ( 0.00%) + 2. Token 56: 2,641 次 ( 0.00%) + 3. Token 62: 3,911 次 ( 0.00%) + 4. Token 64: 4,279 次 ( 0.00%) + 5. Token 58: 7,235 次 ( 0.00%) + 6. Token 54: 11,724 次 ( 0.00%) + 7. Token 66: 18,810 次 ( 0.00%) + 8. Token 52: 28,115 次 ( 0.00%) + 9. Token 61: 33,938 次 ( 0.00%) + 10. Token 48: 40,286 次 ( 0.00%) + 11. Token 57: 54,054 次 ( 0.00%) + 12. Token 63: 70,582 次 ( 0.00%) + 13. Token 44: 73,755 次 ( 0.00%) + 14. Token 40: 83,688 次 ( 0.01%) + 15. Token 28: 108,302 次 ( 0.01%) + 16. Token 49: 124,390 次 ( 0.01%) + 17. Token 32: 126,822 次 ( 0.01%) + 18. Token 59: 170,006 次 ( 0.01%) + 19. Token 53: 173,115 次 ( 0.01%) + 20. Token 29: 176,753 次 ( 0.01%) + + 分布统计: + 最小出现次数: 2,049 + 最大出现次数: 538,838,927 + 中位数出现次数: 669,089 + 标准差: 86,525,334.33 + +-------------------------------------------------------------------------------- +列 5: program +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 130 + Token值范围: [0, 132] + 平均出现次数: 12633353.60 + + Top 20 最常见的token: + 1. Token 4: 590,202,528 次 (35.94%) + 2. Token 132: 156,060,306 次 ( 9.50%) + 3. Token 44: 58,164,622 次 ( 3.54%) + 4. Token 56: 53,946,714 次 ( 3.28%) + 5. Token 75: 53,920,981 次 ( 3.28%) + 6. Token 77: 53,361,941 次 ( 3.25%) + 7. Token 60: 51,895,506 次 ( 3.16%) + 8. Token 52: 46,809,671 次 ( 2.85%) + 9. Token 61: 42,179,722 次 ( 2.57%) + 10. Token 69: 32,843,971 次 ( 2.00%) + 11. Token 64: 32,513,674 次 ( 1.98%) + 12. Token 46: 32,141,823 次 ( 1.96%) + 13. Token 28: 31,972,258 次 ( 1.95%) + 14. Token 62: 31,563,208 次 ( 1.92%) + 15. Token 70: 22,681,555 次 ( 1.38%) + 16. Token 45: 21,634,143 次 ( 1.32%) + 17. Token 72: 20,317,595 次 ( 1.24%) + 18. Token 16: 19,214,717 次 ( 1.17%) + 19. Token 29: 16,869,726 次 ( 1.03%) + 20. Token 74: 15,757,879 次 ( 0.96%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 129: 10,526 次 ( 0.00%) + 2. Token 128: 10,862 次 ( 0.00%) + 3. Token 124: 11,419 次 ( 0.00%) + 4. Token 127: 12,119 次 ( 0.00%) + 5. Token 130: 12,120 次 ( 0.00%) + 6. Token 123: 33,251 次 ( 0.00%) + 7. Token 117: 33,869 次 ( 0.00%) + 8. Token 131: 38,496 次 ( 0.00%) + 9. Token 90: 40,354 次 ( 0.00%) + 10. 
Token 105: 42,011 次 ( 0.00%) + 11. Token 115: 54,477 次 ( 0.00%) + 12. Token 125: 57,752 次 ( 0.00%) + 13. Token 126: 73,126 次 ( 0.00%) + 14. Token 101: 76,255 次 ( 0.00%) + 15. Token 122: 79,728 次 ( 0.00%) + 16. Token 80: 93,518 次 ( 0.01%) + 17. Token 119: 93,731 次 ( 0.01%) + 18. Token 81: 112,112 次 ( 0.01%) + 19. Token 99: 122,909 次 ( 0.01%) + 20. Token 87: 127,634 次 ( 0.01%) + + 分布统计: + 最小出现次数: 10,526 + 最大出现次数: 590,202,528 + 中位数出现次数: 1,132,272 + 标准差: 54,071,389.60 + +-------------------------------------------------------------------------------- +列 6: tempo +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 33 + Token值范围: [0, 35] + 平均出现次数: 49767756.61 + + Top 20 最常见的token: + 1. Token 16: 442,514,188 次 (26.94%) + 2. Token 10: 91,290,939 次 ( 5.56%) + 3. Token 13: 88,177,737 次 ( 5.37%) + 4. Token 11: 86,042,535 次 ( 5.24%) + 5. Token 17: 75,142,988 次 ( 4.58%) + 6. Token 19: 70,235,288 次 ( 4.28%) + 7. Token 14: 67,397,229 次 ( 4.10%) + 8. Token 22: 65,178,994 次 ( 3.97%) + 9. Token 18: 58,517,046 次 ( 3.56%) + 10. Token 9: 56,688,242 次 ( 3.45%) + 11. Token 20: 55,803,768 次 ( 3.40%) + 12. Token 12: 55,500,781 次 ( 3.38%) + 13. Token 8: 50,803,286 次 ( 3.09%) + 14. Token 25: 50,653,417 次 ( 3.08%) + 15. Token 15: 50,223,556 次 ( 3.06%) + 16. Token 7: 40,255,445 次 ( 2.45%) + 17. Token 23: 37,394,478 次 ( 2.28%) + 18. Token 21: 29,373,616 次 ( 1.79%) + 19. Token 28: 25,511,675 次 ( 1.55%) + 20. Token 26: 21,472,940 次 ( 1.31%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 33: 993,921 次 ( 0.06%) + 2. Token 0: 1,565,807 次 ( 0.10%) + 3. Token 32: 3,784,501 次 ( 0.23%) + 4. Token 30: 5,037,444 次 ( 0.31%) + 5. Token 5: 6,210,341 次 ( 0.38%) + 6. Token 27: 7,808,256 次 ( 0.48%) + 7. Token 29: 9,549,287 次 ( 0.58%) + 8. Token 31: 9,823,593 次 ( 0.60%) + 9. Token 34: 10,705,011 次 ( 0.65%) + 10. Token 4: 13,060,108 次 ( 0.80%) + 11. Token 24: 16,882,039 次 ( 1.03%) + 12. Token 35: 19,365,351 次 ( 1.18%) + 13. Token 6: 19,372,161 次 ( 1.18%) + 14. Token 26: 21,472,940 次 ( 1.31%) + 15. Token 28: 25,511,675 次 ( 1.55%) + 16. Token 21: 29,373,616 次 ( 1.79%) + 17. Token 23: 37,394,478 次 ( 2.28%) + 18. Token 7: 40,255,445 次 ( 2.45%) + 19. Token 15: 50,223,556 次 ( 3.06%) + 20. Token 25: 50,653,417 次 ( 3.08%) + + 分布统计: + 最小出现次数: 993,921 + 最大出现次数: 442,514,188 + 中位数出现次数: 37,394,478 + 标准差: 74,693,721.41 + +-------------------------------------------------------------------------------- +列 7: timesig +-------------------------------------------------------------------------------- + 总token数: 1,642,335,968 + 唯一token数: 10 + Token值范围: [0, 12] + 平均出现次数: 164233596.80 + + Top 20 最常见的token: + 1. Token 12: 1,218,108,996 次 (74.17%) + 2. Token 9: 161,413,472 次 ( 9.83%) + 3. Token 10: 92,647,775 次 ( 5.64%) + 4. Token 6: 84,169,807 次 ( 5.13%) + 5. Token 5: 33,703,265 次 ( 2.05%) + 6. Token 8: 16,074,543 次 ( 0.98%) + 7. Token 7: 15,695,389 次 ( 0.96%) + 8. Token 4: 12,280,458 次 ( 0.75%) + 9. Token 11: 6,676,456 次 ( 0.41%) + 10. Token 0: 1,565,807 次 ( 0.10%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 0: 1,565,807 次 ( 0.10%) + 2. Token 11: 6,676,456 次 ( 0.41%) + 3. Token 4: 12,280,458 次 ( 0.75%) + 4. Token 7: 15,695,389 次 ( 0.96%) + 5. Token 8: 16,074,543 次 ( 0.98%) + 6. Token 5: 33,703,265 次 ( 2.05%) + 7. Token 6: 84,169,807 次 ( 5.13%) + 8. Token 10: 92,647,775 次 ( 5.64%) + 9. Token 9: 161,413,472 次 ( 9.83%) + 10. 
Token 12: 1,218,108,996 次 (74.17%) + + 分布统计: + 最小出现次数: 1,565,807 + 最大出现次数: 1,218,108,996 + 中位数出现次数: 24,888,904 + 标准差: 354,629,911.49 + +================================================================================ +跨列分析 +================================================================================ + + pitch : 1,642,335,968 tokens (12.50%) + position : 1,642,335,968 tokens (12.50%) + bar : 1,642,335,968 tokens (12.50%) + velocity : 1,642,335,968 tokens (12.50%) + duration : 1,642,335,968 tokens (12.50%) + program : 1,642,335,968 tokens (12.50%) + tempo : 1,642,335,968 tokens (12.50%) + timesig : 1,642,335,968 tokens (12.50%) + +================================================================================ +报告生成完成 +================================================================================ \ No newline at end of file diff --git a/octuple_token_analysis_report_part.json b/octuple_token_analysis_report_part.json new file mode 100644 index 0000000..587e494 --- /dev/null +++ b/octuple_token_analysis_report_part.json @@ -0,0 +1,1421 @@ +{ + "summary": { + "total_tokens": 1321324184, + "num_columns": 8 + }, + "columns": { + "pitch": { + "total_tokens": 165165523, + "unique_tokens": 152, + "token_counts": { + "1": 167245, + "35": 3032955, + "36": 3291100, + "37": 1941303, + "38": 4698214, + "39": 2206807, + "40": 4524339, + "41": 3389405, + "42": 3443881, + "43": 5553458, + "44": 2537783, + "45": 6319724, + "46": 3339383, + "47": 5378723, + "48": 5238197, + "49": 3305354, + "50": 6756218, + "51": 2720803, + "52": 6124177, + "53": 3608951, + "54": 4114700, + "55": 5193681, + "56": 2265515, + "57": 5063741, + "58": 2089985, + "59": 3596239, + "60": 2723215, + "61": 1801260, + "62": 2922633, + "63": 1101245, + "64": 2067008, + "65": 1107153, + "66": 1087622, + "67": 1185439, + "68": 511912, + "69": 883348, + "70": 422124, + "71": 534337, + "76": 256530, + "15": 266969, + "16": 555604, + "17": 530023, + "18": 508214, + "19": 1125764, + "20": 508255, + "21": 1358053, + "22": 873423, + "23": 1258481, + "24": 1638390, + "25": 901114, + "26": 2381704, + "27": 1153903, + "28": 2374631, + "29": 1898440, + "31": 3182789, + "32": 1370753, + "33": 3557109, + "34": 1986973, + "14": 468000, + "30": 1742566, + "74": 390983, + "11": 205589, + "13": 173517, + "72": 432072, + "73": 239736, + "101": 1300461, + "102": 1554275, + "104": 2448208, + "108": 2300492, + "109": 121388, + "111": 187910, + "113": 125812, + "114": 259813, + "115": 255911, + "116": 530597, + "117": 589677, + "118": 111202, + "119": 103805, + "121": 60218, + "122": 83779, + "123": 259663, + "125": 135500, + "8": 37468, + "10": 78173, + "75": 180317, + "77": 193784, + "103": 462755, + "120": 409291, + "79": 188398, + "81": 137367, + "82": 85620, + "84": 77566, + "87": 26589, + "89": 25596, + "91": 17313, + "96": 12211, + "138": 153130, + "12": 253935, + "78": 152179, + "80": 85746, + "141": 105564, + "135": 250244, + "142": 170681, + "7": 63207, + "112": 370757, + "110": 378695, + "86": 58567, + "83": 90440, + "106": 448054, + "107": 242033, + "129": 183942, + "130": 179627, + "124": 7025, + "9": 98899, + "131": 29462, + "132": 15082, + "128": 152550, + "133": 45284, + "140": 23151, + "143": 86240, + "85": 43396, + "4": 22345, + "5": 24682, + "6": 25099, + "88": 26727, + "90": 14891, + "145": 16302, + "126": 380569, + "134": 47726, + "98": 3947, + "136": 194229, + "139": 55426, + "137": 16361, + "144": 6163, + "147": 100882, + "149": 38339, + "105": 148313, + "127": 143145, + "146": 37485, + "94": 10733, + "93": 38276, + "152": 6992, + 
"92": 3928, + "148": 80931, + "150": 15272, + "97": 8181, + "95": 8772, + "153": 5693, + "151": 33174, + "99": 7149, + "100": 3087, + "154": 2893 + }, + "top_20": { + "50": 6756218, + "45": 6319724, + "52": 6124177, + "43": 5553458, + "47": 5378723, + "48": 5238197, + "55": 5193681, + "57": 5063741, + "38": 4698214, + "40": 4524339, + "54": 4114700, + "53": 3608951, + "59": 3596239, + "33": 3557109, + "42": 3443881, + "41": 3389405, + "46": 3339383, + "49": 3305354, + "36": 3291100, + "31": 3182789 + }, + "bottom_20": { + "4": 22345, + "91": 17313, + "137": 16361, + "145": 16302, + "150": 15272, + "132": 15082, + "90": 14891, + "96": 12211, + "94": 10733, + "95": 8772, + "97": 8181, + "99": 7149, + "124": 7025, + "152": 6992, + "144": 6163, + "153": 5693, + "98": 3947, + "92": 3928, + "100": 3087, + "154": 2893 + } + }, + "position": { + "total_tokens": 165165523, + "unique_tokens": 97, + "token_counts": { + "0": 167245, + "4": 32574271, + "6": 2216672, + "7": 681728, + "8": 10194883, + "9": 1108378, + "10": 3268085, + "11": 421638, + "12": 18937356, + "13": 494427, + "14": 2308328, + "15": 769775, + "16": 12662506, + "17": 1159945, + "18": 2956264, + "19": 415805, + "20": 21155511, + "21": 490924, + "22": 2020825, + "23": 578675, + "24": 9383102, + "25": 1044601, + "26": 2464244, + "27": 356155, + "28": 16053301, + "29": 391908, + "30": 1808841, + "31": 657855, + "32": 9043069, + "33": 926683, + "34": 2047098, + "35": 313577, + "36": 1430246, + "40": 345204, + "44": 1473726, + "48": 234895, + "5": 662403, + "37": 17165, + "39": 22019, + "41": 29467, + "43": 14422, + "46": 41352, + "47": 17725, + "49": 17138, + "51": 4963, + "52": 372619, + "60": 135808, + "68": 228686, + "76": 294865, + "84": 185034, + "92": 222041, + "38": 51253, + "42": 50555, + "50": 29809, + "54": 3420, + "56": 18957, + "58": 2901, + "62": 3057, + "64": 28125, + "66": 2110, + "70": 3020, + "72": 21044, + "74": 2106, + "78": 4284, + "80": 18469, + "82": 2844, + "86": 2442, + "88": 26349, + "90": 2595, + "94": 4803, + "96": 22347, + "98": 2526, + "45": 11246, + "73": 1749, + "77": 1200, + "79": 2226, + "81": 1211, + "83": 1092, + "93": 1203, + "95": 1969, + "97": 1067, + "99": 333, + "55": 1276, + "57": 1288, + "59": 591, + "61": 732, + "63": 1529, + "65": 1075, + "67": 795, + "71": 1346, + "69": 1083, + "75": 798, + "85": 1152, + "87": 1206, + "89": 1411, + "91": 672, + "53": 804 + }, + "top_20": { + "4": 32574271, + "20": 21155511, + "12": 18937356, + "28": 16053301, + "16": 12662506, + "8": 10194883, + "24": 9383102, + "32": 9043069, + "10": 3268085, + "18": 2956264, + "26": 2464244, + "14": 2308328, + "6": 2216672, + "34": 2047098, + "22": 2020825, + "30": 1808841, + "44": 1473726, + "36": 1430246, + "17": 1159945, + "9": 1108378 + }, + "bottom_20": { + "89": 1411, + "71": 1346, + "57": 1288, + "55": 1276, + "81": 1211, + "87": 1206, + "93": 1203, + "77": 1200, + "85": 1152, + "83": 1092, + "69": 1083, + "65": 1075, + "97": 1067, + "53": 804, + "75": 798, + "67": 795, + "61": 732, + "91": 672, + "59": 591, + "99": 333 + } + }, + "bar": { + "total_tokens": 165165523, + "unique_tokens": 513, + "token_counts": { + "0": 167245, + "4": 1584239, + "5": 1815229, + "6": 1938737, + "7": 1894412, + "8": 2039302, + "9": 2055179, + "10": 2127003, + "11": 2017531, + "12": 2077660, + "13": 2077279, + "14": 2130472, + "15": 2066264, + "16": 2112689, + "17": 2101670, + "18": 2125751, + "19": 2013280, + "20": 1985896, + "21": 1952546, + "22": 1967164, + "23": 1910880, + "24": 1928889, + "25": 1922850, + "26": 1925621, + "27": 
1840108, + "28": 1842884, + "29": 1819115, + "30": 1816488, + "31": 1769088, + "32": 1778705, + "33": 1758335, + "34": 1756560, + "35": 1672601, + "36": 1634620, + "37": 1606197, + "38": 1598258, + "39": 1550562, + "40": 1542778, + "41": 1525716, + "42": 1514875, + "43": 1462196, + "44": 1448465, + "45": 1421335, + "46": 1419295, + "47": 1384559, + "48": 1375042, + "49": 1359384, + "50": 1350370, + "51": 1310789, + "52": 1289717, + "53": 1275347, + "54": 1259000, + "55": 1228319, + "56": 1221387, + "57": 1203722, + "58": 1193094, + "59": 1162590, + "60": 1143424, + "61": 1127876, + "62": 1118524, + "63": 1093179, + "64": 1090135, + "65": 1074894, + "66": 1054527, + "67": 1022532, + "68": 1003133, + "69": 982043, + "70": 962875, + "71": 936460, + "72": 932475, + "73": 915952, + "74": 898570, + "75": 885162, + "76": 868578, + "77": 854409, + "78": 844041, + "79": 824095, + "80": 817547, + "81": 800124, + "82": 790089, + "83": 760977, + "84": 747102, + "85": 726306, + "86": 717237, + "87": 695672, + "88": 688280, + "89": 672211, + "90": 665152, + "91": 648369, + "92": 638540, + "93": 625938, + "94": 620793, + "95": 608328, + "96": 594086, + "97": 580961, + "98": 573204, + "99": 557837, + "100": 549468, + "101": 539457, + "102": 533705, + "103": 521055, + "104": 511239, + "106": 492838, + "105": 499746, + "107": 478297, + "108": 472814, + "109": 461696, + "110": 455041, + "111": 441779, + "112": 436004, + "113": 427736, + "114": 424689, + "115": 414176, + "116": 407941, + "117": 398758, + "118": 390110, + "119": 381493, + "120": 379074, + "121": 368285, + "122": 364555, + "123": 354800, + "124": 348819, + "125": 341038, + "126": 335178, + "127": 327360, + "128": 321341, + "129": 315446, + "130": 311731, + "131": 303579, + "132": 296764, + "133": 291270, + "134": 289651, + "135": 286034, + "136": 281108, + "137": 276620, + "138": 270006, + "139": 264370, + "140": 260931, + "141": 251570, + "142": 250052, + "143": 244298, + "144": 241269, + "145": 240796, + "146": 238612, + "147": 233575, + "148": 226664, + "149": 223332, + "150": 221966, + "151": 215626, + "152": 211399, + "153": 211423, + "154": 210084, + "155": 205253, + "156": 200939, + "157": 196635, + "158": 196916, + "159": 195045, + "160": 193578, + "161": 186158, + "162": 183646, + "163": 179485, + "164": 174131, + "165": 171288, + "166": 167780, + "167": 165902, + "168": 162178, + "169": 161338, + "170": 157680, + "171": 156358, + "172": 154193, + "173": 152870, + "174": 151729, + "175": 149160, + "176": 146768, + "177": 145443, + "178": 143183, + "179": 142011, + "180": 139910, + "181": 137801, + "182": 136801, + "183": 134343, + "184": 130433, + "185": 128327, + "186": 126652, + "187": 124581, + "188": 124793, + "189": 125188, + "190": 123865, + "191": 121113, + "192": 117962, + "193": 116795, + "194": 115081, + "195": 112130, + "196": 111655, + "197": 111152, + "198": 109583, + "199": 106581, + "200": 106582, + "201": 104917, + "202": 104034, + "203": 103027, + "204": 102371, + "205": 100529, + "206": 98799, + "207": 97053, + "208": 98139, + "209": 95963, + "210": 95651, + "211": 94429, + "212": 94108, + "213": 92742, + "214": 90897, + "215": 90981, + "216": 88979, + "217": 89186, + "218": 89135, + "219": 87596, + "220": 86250, + "221": 84795, + "222": 84720, + "223": 83632, + "224": 82192, + "225": 80597, + "226": 80012, + "227": 79300, + "228": 78462, + "229": 77013, + "230": 73110, + "231": 72846, + "232": 74555, + "233": 74089, + "234": 74558, + "235": 74348, + "236": 72473, + "237": 71519, + "238": 69831, + "239": 70213, + 
"240": 69999, + "241": 67352, + "242": 67685, + "243": 66763, + "244": 65477, + "245": 64992, + "246": 64458, + "247": 64109, + "248": 64433, + "249": 63778, + "250": 63157, + "251": 63334, + "252": 60698, + "253": 60649, + "254": 61311, + "255": 61117, + "256": 59763, + "257": 60449, + "258": 60022, + "259": 57188, + "260": 58000, + "261": 57125, + "262": 56372, + "263": 55567, + "264": 55837, + "265": 55488, + "266": 54004, + "267": 54539, + "268": 54181, + "269": 55261, + "270": 52953, + "271": 52714, + "272": 53690, + "273": 52655, + "274": 49305, + "275": 49029, + "276": 48186, + "277": 47753, + "278": 47099, + "279": 47070, + "280": 47117, + "281": 46699, + "282": 46858, + "283": 45831, + "284": 45742, + "285": 45806, + "286": 43658, + "287": 43115, + "288": 43864, + "289": 42764, + "290": 42430, + "291": 41780, + "292": 41644, + "293": 41289, + "294": 40664, + "295": 40497, + "296": 40184, + "297": 40065, + "298": 39531, + "299": 38589, + "300": 38211, + "301": 38911, + "302": 37867, + "303": 37095, + "304": 36763, + "305": 36739, + "306": 36455, + "307": 36622, + "308": 36107, + "309": 36337, + "310": 36608, + "311": 35849, + "312": 35367, + "313": 35456, + "314": 35165, + "315": 34413, + "316": 34099, + "317": 35022, + "318": 34165, + "319": 34241, + "320": 33780, + "321": 33477, + "322": 34077, + "323": 33387, + "324": 33893, + "325": 33101, + "326": 32375, + "327": 31695, + "328": 31804, + "329": 31659, + "330": 31203, + "331": 30532, + "332": 30349, + "333": 30338, + "334": 30513, + "335": 30194, + "336": 30099, + "337": 29582, + "338": 28991, + "339": 29504, + "340": 28948, + "341": 29396, + "342": 29377, + "343": 28521, + "344": 27886, + "345": 28307, + "346": 27847, + "347": 26466, + "348": 26114, + "349": 26742, + "350": 26566, + "351": 26116, + "352": 26279, + "353": 26549, + "354": 25689, + "355": 24894, + "356": 25164, + "357": 25996, + "358": 25890, + "359": 25348, + "360": 25555, + "361": 25677, + "362": 25637, + "363": 24254, + "364": 24541, + "365": 23724, + "366": 23846, + "367": 23588, + "368": 23657, + "369": 23482, + "370": 23016, + "371": 22783, + "372": 22685, + "373": 22666, + "374": 22314, + "375": 21464, + "376": 21359, + "377": 21765, + "378": 21882, + "379": 21531, + "380": 21366, + "381": 21082, + "382": 21203, + "383": 21040, + "384": 22029, + "385": 21037, + "392": 20076, + "393": 19600, + "394": 20006, + "395": 19964, + "396": 19725, + "397": 20119, + "398": 19542, + "399": 19443, + "400": 19025, + "401": 19122, + "402": 18426, + "403": 18811, + "404": 18760, + "405": 18935, + "406": 18898, + "407": 19302, + "408": 18483, + "409": 18692, + "410": 18302, + "411": 17806, + "412": 17180, + "413": 17326, + "414": 17487, + "415": 17596, + "416": 17290, + "417": 17808, + "418": 17344, + "419": 16566, + "420": 16868, + "421": 16861, + "425": 14903, + "426": 15527, + "427": 14842, + "428": 15202, + "429": 14871, + "430": 14625, + "431": 14458, + "444": 13666, + "445": 13864, + "446": 13528, + "447": 13052, + "448": 13011, + "449": 13019, + "450": 12675, + "451": 12767, + "452": 12212, + "469": 12271, + "479": 10919, + "480": 11445, + "481": 11087, + "482": 10676, + "483": 10523, + "486": 10707, + "487": 10321, + "488": 10615, + "490": 10456, + "491": 10445, + "492": 10748, + "496": 10723, + "497": 10264, + "498": 9959, + "500": 9869, + "501": 9974, + "505": 9859, + "506": 9651, + "510": 9162, + "511": 9129, + "512": 9599, + "513": 9589, + "514": 9841, + "515": 9795, + "386": 20293, + "387": 20152, + "388": 19989, + "389": 19925, + "390": 19734, + "391": 20363, 
+ "422": 16247, + "423": 15644, + "424": 15357, + "432": 14792, + "434": 14375, + "435": 14087, + "436": 13747, + "437": 13818, + "438": 13705, + "439": 13747, + "440": 13306, + "441": 13316, + "442": 13338, + "443": 13208, + "453": 12632, + "454": 12860, + "455": 12957, + "456": 13181, + "457": 13030, + "458": 12453, + "459": 12047, + "460": 11876, + "461": 11828, + "462": 11461, + "463": 11340, + "464": 11525, + "465": 12360, + "466": 12712, + "467": 12222, + "468": 12003, + "470": 12126, + "471": 11633, + "472": 11071, + "473": 11096, + "474": 11205, + "475": 11043, + "476": 10979, + "477": 11070, + "478": 11309, + "484": 10646, + "485": 10739, + "489": 10351, + "493": 10945, + "494": 11016, + "495": 10297, + "499": 9813, + "502": 10061, + "503": 9818, + "504": 10076, + "507": 9211, + "508": 9407, + "509": 9374, + "433": 14743 + }, + "top_20": { + "14": 2130472, + "10": 2127003, + "18": 2125751, + "16": 2112689, + "17": 2101670, + "12": 2077660, + "13": 2077279, + "15": 2066264, + "9": 2055179, + "8": 2039302, + "11": 2017531, + "19": 2013280, + "20": 1985896, + "22": 1967164, + "21": 1952546, + "6": 1938737, + "24": 1928889, + "26": 1925621, + "25": 1922850, + "23": 1910880 + }, + "bottom_20": { + "495": 10297, + "497": 10264, + "504": 10076, + "502": 10061, + "501": 9974, + "498": 9959, + "500": 9869, + "505": 9859, + "514": 9841, + "503": 9818, + "499": 9813, + "515": 9795, + "506": 9651, + "512": 9599, + "513": 9589, + "508": 9407, + "509": 9374, + "507": 9211, + "510": 9162, + "511": 9129 + } + }, + "velocity": { + "total_tokens": 165165523, + "unique_tokens": 33, + "token_counts": { + "0": 167245, + "23": 81316692, + "27": 24139232, + "19": 12080765, + "16": 10648704, + "17": 464538, + "18": 1114141, + "20": 788683, + "21": 1224845, + "22": 1770942, + "24": 811770, + "25": 1320532, + "31": 10823520, + "35": 7024913, + "28": 1979649, + "30": 584551, + "32": 1130883, + "12": 2746364, + "13": 342243, + "14": 275035, + "26": 1357064, + "29": 707198, + "33": 589593, + "34": 233771, + "7": 458081, + "5": 14527, + "6": 55253, + "8": 114433, + "9": 97242, + "10": 118225, + "11": 189100, + "15": 348278, + "4": 127511 + }, + "top_20": { + "23": 81316692, + "27": 24139232, + "19": 12080765, + "31": 10823520, + "16": 10648704, + "35": 7024913, + "12": 2746364, + "28": 1979649, + "22": 1770942, + "26": 1357064, + "25": 1320532, + "21": 1224845, + "32": 1130883, + "18": 1114141, + "24": 811770, + "20": 788683, + "29": 707198, + "33": 589593, + "30": 584551, + "17": 464538 + }, + "bottom_20": { + "18": 1114141, + "24": 811770, + "20": 788683, + "29": 707198, + "33": 589593, + "30": 584551, + "17": 464538, + "7": 458081, + "15": 348278, + "13": 342243, + "14": 275035, + "34": 233771, + "11": 189100, + "0": 167245, + "4": 127511, + "10": 118225, + "8": 114433, + "9": 97242, + "6": 55253, + "5": 14527 + } + }, + "duration": { + "total_tokens": 165165523, + "unique_tokens": 65, + "token_counts": { + "0": 167245, + "4": 12488095, + "5": 29539137, + "6": 5468321, + "7": 53760996, + "9": 3217846, + "11": 34836051, + "13": 178465, + "14": 1237364, + "15": 2890926, + "19": 6828131, + "21": 58958, + "23": 368145, + "27": 2218271, + "31": 150849, + "35": 2528635, + "37": 95593, + "8": 1974671, + "16": 123613, + "22": 200891, + "24": 24825, + "30": 69089, + "51": 313819, + "39": 157450, + "29": 17636, + "36": 45731, + "43": 470949, + "18": 2575465, + "26": 825830, + "33": 1055556, + "42": 165884, + "50": 94399, + "25": 29815, + "38": 53095, + "12": 90256, + "47": 59608, + "46": 19939, + "61": 2857, + "65": 
26656, + "66": 1852, + "49": 13944, + "10": 158375, + "20": 25053, + "67": 187607, + "59": 17534, + "17": 98534, + "34": 74286, + "55": 34994, + "53": 17193, + "44": 7677, + "41": 22064, + "40": 8736, + "45": 19258, + "28": 11415, + "32": 13414, + "57": 5302, + "63": 6889, + "48": 4437, + "52": 2792, + "58": 723, + "54": 1120, + "62": 308, + "56": 250, + "64": 408, + "60": 296 + }, + "top_20": { + "7": 53760996, + "11": 34836051, + "5": 29539137, + "4": 12488095, + "19": 6828131, + "6": 5468321, + "9": 3217846, + "15": 2890926, + "18": 2575465, + "35": 2528635, + "27": 2218271, + "8": 1974671, + "14": 1237364, + "33": 1055556, + "26": 825830, + "43": 470949, + "23": 368145, + "51": 313819, + "22": 200891, + "67": 187607 + }, + "bottom_20": { + "29": 17636, + "59": 17534, + "53": 17193, + "49": 13944, + "32": 13414, + "28": 11415, + "40": 8736, + "44": 7677, + "63": 6889, + "57": 5302, + "48": 4437, + "61": 2857, + "52": 2792, + "66": 1852, + "54": 1120, + "58": 723, + "64": 408, + "62": 308, + "60": 296, + "56": 250 + } + }, + "program": { + "total_tokens": 165165523, + "unique_tokens": 130, + "token_counts": { + "0": 167245, + "28": 3301108, + "4": 59258093, + "5": 769141, + "62": 3233651, + "60": 5225132, + "75": 5461270, + "72": 2064103, + "26": 292461, + "37": 1396020, + "44": 6045327, + "132": 15564529, + "38": 927510, + "46": 3275859, + "29": 1745633, + "56": 5504350, + "64": 3238555, + "6": 206060, + "25": 435268, + "8": 544276, + "13": 700995, + "15": 1156550, + "16": 1907690, + "61": 4263879, + "49": 642853, + "51": 1159836, + "77": 5424393, + "99": 11962, + "45": 2209485, + "47": 1001286, + "52": 4570089, + "58": 119589, + "89": 89515, + "69": 3287914, + "70": 2257240, + "71": 1453269, + "74": 1606358, + "93": 37570, + "10": 625836, + "30": 360414, + "57": 1020133, + "23": 803689, + "31": 1380327, + "33": 407852, + "109": 139525, + "32": 161635, + "41": 32395, + "118": 158954, + "126": 5253, + "22": 94842, + "34": 471228, + "50": 886071, + "76": 925496, + "78": 601005, + "36": 736109, + "65": 163119, + "7": 158309, + "68": 578403, + "12": 95747, + "18": 186648, + "17": 772783, + "39": 103682, + "73": 203411, + "84": 281778, + "35": 17405, + "40": 54679, + "91": 84676, + "53": 149810, + "48": 140946, + "120": 20625, + "123": 2449, + "95": 74092, + "96": 15149, + "42": 154486, + "66": 100824, + "79": 93456, + "21": 51388, + "9": 94397, + "54": 133412, + "92": 66872, + "63": 65206, + "43": 167023, + "87": 9526, + "55": 62992, + "97": 28656, + "116": 69288, + "121": 57895, + "85": 266753, + "86": 36630, + "14": 66854, + "83": 71230, + "24": 102710, + "102": 17706, + "20": 99019, + "106": 16725, + "19": 58439, + "67": 25017, + "114": 54261, + "113": 43169, + "11": 51594, + "103": 46231, + "111": 117330, + "88": 27214, + "108": 25586, + "104": 48201, + "112": 26273, + "27": 39711, + "94": 38185, + "98": 17559, + "125": 4107, + "129": 379, + "122": 11770, + "100": 70196, + "107": 17827, + "115": 6629, + "82": 38784, + "110": 20594, + "59": 17149, + "80": 6675, + "81": 12382, + "101": 7964, + "124": 650, + "127": 703, + "128": 1004, + "131": 4577, + "119": 8118, + "105": 5155, + "90": 4509, + "130": 1374, + "117": 2620 + }, + "top_20": { + "4": 59258093, + "132": 15564529, + "44": 6045327, + "56": 5504350, + "75": 5461270, + "77": 5424393, + "60": 5225132, + "52": 4570089, + "61": 4263879, + "28": 3301108, + "69": 3287914, + "46": 3275859, + "64": 3238555, + "62": 3233651, + "70": 2257240, + "45": 2209485, + "72": 2064103, + "16": 1907690, + "29": 1745633, + "74": 1606358 + }, + 
"bottom_20": { + "81": 12382, + "99": 11962, + "122": 11770, + "87": 9526, + "119": 8118, + "101": 7964, + "80": 6675, + "115": 6629, + "126": 5253, + "105": 5155, + "131": 4577, + "90": 4509, + "125": 4107, + "117": 2620, + "123": 2449, + "130": 1374, + "128": 1004, + "127": 703, + "124": 650, + "129": 379 + } + }, + "tempo": { + "total_tokens": 165165523, + "unique_tokens": 33, + "token_counts": { + "0": 167245, + "9": 5790437, + "19": 6942952, + "13": 8972232, + "23": 3797354, + "16": 45114808, + "8": 4962026, + "11": 8326936, + "14": 6860103, + "25": 4957658, + "22": 6648928, + "35": 1923315, + "10": 8859942, + "15": 5131246, + "4": 1457184, + "17": 7387116, + "5": 609770, + "6": 2047425, + "7": 4177507, + "34": 1214421, + "20": 5542314, + "26": 2222934, + "21": 2939811, + "12": 5732199, + "18": 5464663, + "28": 2637904, + "29": 920327, + "30": 471799, + "27": 776090, + "24": 1696832, + "31": 941730, + "32": 356430, + "33": 113885 + }, + "top_20": { + "16": 45114808, + "13": 8972232, + "10": 8859942, + "11": 8326936, + "17": 7387116, + "19": 6942952, + "14": 6860103, + "22": 6648928, + "9": 5790437, + "12": 5732199, + "20": 5542314, + "18": 5464663, + "15": 5131246, + "8": 4962026, + "25": 4957658, + "7": 4177507, + "23": 3797354, + "21": 2939811, + "28": 2637904, + "26": 2222934 + }, + "bottom_20": { + "8": 4962026, + "25": 4957658, + "7": 4177507, + "23": 3797354, + "21": 2939811, + "28": 2637904, + "26": 2222934, + "6": 2047425, + "35": 1923315, + "24": 1696832, + "4": 1457184, + "34": 1214421, + "31": 941730, + "29": 920327, + "27": 776090, + "5": 609770, + "30": 471799, + "32": 356430, + "0": 167245, + "33": 113885 + } + }, + "timesig": { + "total_tokens": 165165523, + "unique_tokens": 10, + "token_counts": { + "0": 167245, + "10": 9427532, + "12": 122569379, + "11": 747128, + "9": 16198618, + "6": 8281749, + "4": 1254214, + "8": 1649967, + "5": 3323190, + "7": 1546501 + }, + "top_20": { + "12": 122569379, + "9": 16198618, + "10": 9427532, + "6": 8281749, + "5": 3323190, + "8": 1649967, + "7": 1546501, + "4": 1254214, + "11": 747128, + "0": 167245 + }, + "bottom_20": { + "12": 122569379, + "9": 16198618, + "10": 9427532, + "6": 8281749, + "5": 3323190, + "8": 1649967, + "7": 1546501, + "4": 1254214, + "11": 747128, + "0": 167245 + } + } + } +} \ No newline at end of file diff --git a/octuple_token_analysis_report_part.txt b/octuple_token_analysis_report_part.txt new file mode 100644 index 0000000..6d3a84d --- /dev/null +++ b/octuple_token_analysis_report_part.txt @@ -0,0 +1,467 @@ +================================================================================ +OCTUPLE分词结果统计分析报告 +================================================================================ + +总token数: 1,321,324,184 +分析的列数: 8 + +-------------------------------------------------------------------------------- +列 0: pitch +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 152 + Token值范围: [1, 154] + 平均出现次数: 1086615.28 + + Top 20 最常见的token: + 1. Token 50: 6,756,218 次 ( 4.09%) + 2. Token 45: 6,319,724 次 ( 3.83%) + 3. Token 52: 6,124,177 次 ( 3.71%) + 4. Token 43: 5,553,458 次 ( 3.36%) + 5. Token 47: 5,378,723 次 ( 3.26%) + 6. Token 48: 5,238,197 次 ( 3.17%) + 7. Token 55: 5,193,681 次 ( 3.14%) + 8. Token 57: 5,063,741 次 ( 3.07%) + 9. Token 38: 4,698,214 次 ( 2.84%) + 10. Token 40: 4,524,339 次 ( 2.74%) + 11. Token 54: 4,114,700 次 ( 2.49%) + 12. Token 53: 3,608,951 次 ( 2.19%) + 13. Token 59: 3,596,239 次 ( 2.18%) + 14. Token 33: 3,557,109 次 ( 2.15%) + 15. 
Token 42: 3,443,881 次 ( 2.09%) + 16. Token 41: 3,389,405 次 ( 2.05%) + 17. Token 46: 3,339,383 次 ( 2.02%) + 18. Token 49: 3,305,354 次 ( 2.00%) + 19. Token 36: 3,291,100 次 ( 1.99%) + 20. Token 31: 3,182,789 次 ( 1.93%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 154: 2,893 次 ( 0.00%) + 2. Token 100: 3,087 次 ( 0.00%) + 3. Token 92: 3,928 次 ( 0.00%) + 4. Token 98: 3,947 次 ( 0.00%) + 5. Token 153: 5,693 次 ( 0.00%) + 6. Token 144: 6,163 次 ( 0.00%) + 7. Token 152: 6,992 次 ( 0.00%) + 8. Token 124: 7,025 次 ( 0.00%) + 9. Token 99: 7,149 次 ( 0.00%) + 10. Token 97: 8,181 次 ( 0.00%) + 11. Token 95: 8,772 次 ( 0.01%) + 12. Token 94: 10,733 次 ( 0.01%) + 13. Token 96: 12,211 次 ( 0.01%) + 14. Token 90: 14,891 次 ( 0.01%) + 15. Token 132: 15,082 次 ( 0.01%) + 16. Token 150: 15,272 次 ( 0.01%) + 17. Token 145: 16,302 次 ( 0.01%) + 18. Token 137: 16,361 次 ( 0.01%) + 19. Token 91: 17,313 次 ( 0.01%) + 20. Token 4: 22,345 次 ( 0.01%) + + 分布统计: + 最小出现次数: 2,893 + 最大出现次数: 6,756,218 + 中位数出现次数: 254,923 + 标准差: 1,566,812.91 + +-------------------------------------------------------------------------------- +列 1: position +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 97 + Token值范围: [0, 99] + 平均出现次数: 1702737.35 + + Top 20 最常见的token: + 1. Token 4: 32,574,271 次 (19.72%) + 2. Token 20: 21,155,511 次 (12.81%) + 3. Token 12: 18,937,356 次 (11.47%) + 4. Token 28: 16,053,301 次 ( 9.72%) + 5. Token 16: 12,662,506 次 ( 7.67%) + 6. Token 8: 10,194,883 次 ( 6.17%) + 7. Token 24: 9,383,102 次 ( 5.68%) + 8. Token 32: 9,043,069 次 ( 5.48%) + 9. Token 10: 3,268,085 次 ( 1.98%) + 10. Token 18: 2,956,264 次 ( 1.79%) + 11. Token 26: 2,464,244 次 ( 1.49%) + 12. Token 14: 2,308,328 次 ( 1.40%) + 13. Token 6: 2,216,672 次 ( 1.34%) + 14. Token 34: 2,047,098 次 ( 1.24%) + 15. Token 22: 2,020,825 次 ( 1.22%) + 16. Token 30: 1,808,841 次 ( 1.10%) + 17. Token 44: 1,473,726 次 ( 0.89%) + 18. Token 36: 1,430,246 次 ( 0.87%) + 19. Token 17: 1,159,945 次 ( 0.70%) + 20. Token 9: 1,108,378 次 ( 0.67%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 99: 333 次 ( 0.00%) + 2. Token 59: 591 次 ( 0.00%) + 3. Token 91: 672 次 ( 0.00%) + 4. Token 61: 732 次 ( 0.00%) + 5. Token 67: 795 次 ( 0.00%) + 6. Token 75: 798 次 ( 0.00%) + 7. Token 53: 804 次 ( 0.00%) + 8. Token 97: 1,067 次 ( 0.00%) + 9. Token 65: 1,075 次 ( 0.00%) + 10. Token 69: 1,083 次 ( 0.00%) + 11. Token 83: 1,092 次 ( 0.00%) + 12. Token 85: 1,152 次 ( 0.00%) + 13. Token 77: 1,200 次 ( 0.00%) + 14. Token 93: 1,203 次 ( 0.00%) + 15. Token 87: 1,206 次 ( 0.00%) + 16. Token 81: 1,211 次 ( 0.00%) + 17. Token 55: 1,276 次 ( 0.00%) + 18. Token 57: 1,288 次 ( 0.00%) + 19. Token 71: 1,346 次 ( 0.00%) + 20. Token 89: 1,411 次 ( 0.00%) + + 分布统计: + 最小出现次数: 333 + 最大出现次数: 32,574,271 + 中位数出现次数: 28,125 + 标准差: 4,909,414.79 + +-------------------------------------------------------------------------------- +列 2: bar +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 513 + Token值范围: [0, 515] + 平均出现次数: 321960.08 + + Top 20 最常见的token: + 1. Token 14: 2,130,472 次 ( 1.29%) + 2. Token 10: 2,127,003 次 ( 1.29%) + 3. Token 18: 2,125,751 次 ( 1.29%) + 4. Token 16: 2,112,689 次 ( 1.28%) + 5. Token 17: 2,101,670 次 ( 1.27%) + 6. Token 12: 2,077,660 次 ( 1.26%) + 7. Token 13: 2,077,279 次 ( 1.26%) + 8. Token 15: 2,066,264 次 ( 1.25%) + 9. Token 9: 2,055,179 次 ( 1.24%) + 10. Token 8: 2,039,302 次 ( 1.23%) + 11. Token 11: 2,017,531 次 ( 1.22%) + 12. Token 19: 2,013,280 次 ( 1.22%) + 13. Token 20: 1,985,896 次 ( 1.20%) + 14. 
Token 22: 1,967,164 次 ( 1.19%) + 15. Token 21: 1,952,546 次 ( 1.18%) + 16. Token 6: 1,938,737 次 ( 1.17%) + 17. Token 24: 1,928,889 次 ( 1.17%) + 18. Token 26: 1,925,621 次 ( 1.17%) + 19. Token 25: 1,922,850 次 ( 1.16%) + 20. Token 23: 1,910,880 次 ( 1.16%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 511: 9,129 次 ( 0.01%) + 2. Token 510: 9,162 次 ( 0.01%) + 3. Token 507: 9,211 次 ( 0.01%) + 4. Token 509: 9,374 次 ( 0.01%) + 5. Token 508: 9,407 次 ( 0.01%) + 6. Token 513: 9,589 次 ( 0.01%) + 7. Token 512: 9,599 次 ( 0.01%) + 8. Token 506: 9,651 次 ( 0.01%) + 9. Token 515: 9,795 次 ( 0.01%) + 10. Token 499: 9,813 次 ( 0.01%) + 11. Token 503: 9,818 次 ( 0.01%) + 12. Token 514: 9,841 次 ( 0.01%) + 13. Token 505: 9,859 次 ( 0.01%) + 14. Token 500: 9,869 次 ( 0.01%) + 15. Token 498: 9,959 次 ( 0.01%) + 16. Token 501: 9,974 次 ( 0.01%) + 17. Token 502: 10,061 次 ( 0.01%) + 18. Token 504: 10,076 次 ( 0.01%) + 19. Token 497: 10,264 次 ( 0.01%) + 20. Token 495: 10,297 次 ( 0.01%) + + 分布统计: + 最小出现次数: 9,129 + 最大出现次数: 2,130,472 + 中位数出现次数: 58,000 + 标准差: 542,952.62 + +-------------------------------------------------------------------------------- +列 3: velocity +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 33 + Token值范围: [0, 35] + 平均出现次数: 5005015.85 + + Top 20 最常见的token: + 1. Token 23: 81,316,692 次 (49.23%) + 2. Token 27: 24,139,232 次 (14.62%) + 3. Token 19: 12,080,765 次 ( 7.31%) + 4. Token 31: 10,823,520 次 ( 6.55%) + 5. Token 16: 10,648,704 次 ( 6.45%) + 6. Token 35: 7,024,913 次 ( 4.25%) + 7. Token 12: 2,746,364 次 ( 1.66%) + 8. Token 28: 1,979,649 次 ( 1.20%) + 9. Token 22: 1,770,942 次 ( 1.07%) + 10. Token 26: 1,357,064 次 ( 0.82%) + 11. Token 25: 1,320,532 次 ( 0.80%) + 12. Token 21: 1,224,845 次 ( 0.74%) + 13. Token 32: 1,130,883 次 ( 0.68%) + 14. Token 18: 1,114,141 次 ( 0.67%) + 15. Token 24: 811,770 次 ( 0.49%) + 16. Token 20: 788,683 次 ( 0.48%) + 17. Token 29: 707,198 次 ( 0.43%) + 18. Token 33: 589,593 次 ( 0.36%) + 19. Token 30: 584,551 次 ( 0.35%) + 20. Token 17: 464,538 次 ( 0.28%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 5: 14,527 次 ( 0.01%) + 2. Token 6: 55,253 次 ( 0.03%) + 3. Token 9: 97,242 次 ( 0.06%) + 4. Token 8: 114,433 次 ( 0.07%) + 5. Token 10: 118,225 次 ( 0.07%) + 6. Token 4: 127,511 次 ( 0.08%) + 7. Token 0: 167,245 次 ( 0.10%) + 8. Token 11: 189,100 次 ( 0.11%) + 9. Token 34: 233,771 次 ( 0.14%) + 10. Token 14: 275,035 次 ( 0.17%) + 11. Token 13: 342,243 次 ( 0.21%) + 12. Token 15: 348,278 次 ( 0.21%) + 13. Token 7: 458,081 次 ( 0.28%) + 14. Token 17: 464,538 次 ( 0.28%) + 15. Token 30: 584,551 次 ( 0.35%) + 16. Token 33: 589,593 次 ( 0.36%) + 17. Token 29: 707,198 次 ( 0.43%) + 18. Token 20: 788,683 次 ( 0.48%) + 19. Token 24: 811,770 次 ( 0.49%) + 20. Token 18: 1,114,141 次 ( 0.67%) + + 分布统计: + 最小出现次数: 14,527 + 最大出现次数: 81,316,692 + 中位数出现次数: 707,198 + 标准差: 14,375,775.97 + +-------------------------------------------------------------------------------- +列 4: duration +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 65 + Token值范围: [0, 67] + 平均出现次数: 2541008.05 + + Top 20 最常见的token: + 1. Token 7: 53,760,996 次 (32.55%) + 2. Token 11: 34,836,051 次 (21.09%) + 3. Token 5: 29,539,137 次 (17.88%) + 4. Token 4: 12,488,095 次 ( 7.56%) + 5. Token 19: 6,828,131 次 ( 4.13%) + 6. Token 6: 5,468,321 次 ( 3.31%) + 7. Token 9: 3,217,846 次 ( 1.95%) + 8. Token 15: 2,890,926 次 ( 1.75%) + 9. Token 18: 2,575,465 次 ( 1.56%) + 10. Token 35: 2,528,635 次 ( 1.53%) + 11. Token 27: 2,218,271 次 ( 1.34%) + 12. 
Token 8: 1,974,671 次 ( 1.20%) + 13. Token 14: 1,237,364 次 ( 0.75%) + 14. Token 33: 1,055,556 次 ( 0.64%) + 15. Token 26: 825,830 次 ( 0.50%) + 16. Token 43: 470,949 次 ( 0.29%) + 17. Token 23: 368,145 次 ( 0.22%) + 18. Token 51: 313,819 次 ( 0.19%) + 19. Token 22: 200,891 次 ( 0.12%) + 20. Token 67: 187,607 次 ( 0.11%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 56: 250 次 ( 0.00%) + 2. Token 60: 296 次 ( 0.00%) + 3. Token 62: 308 次 ( 0.00%) + 4. Token 64: 408 次 ( 0.00%) + 5. Token 58: 723 次 ( 0.00%) + 6. Token 54: 1,120 次 ( 0.00%) + 7. Token 66: 1,852 次 ( 0.00%) + 8. Token 52: 2,792 次 ( 0.00%) + 9. Token 61: 2,857 次 ( 0.00%) + 10. Token 48: 4,437 次 ( 0.00%) + 11. Token 57: 5,302 次 ( 0.00%) + 12. Token 63: 6,889 次 ( 0.00%) + 13. Token 44: 7,677 次 ( 0.00%) + 14. Token 40: 8,736 次 ( 0.01%) + 15. Token 28: 11,415 次 ( 0.01%) + 16. Token 32: 13,414 次 ( 0.01%) + 17. Token 49: 13,944 次 ( 0.01%) + 18. Token 53: 17,193 次 ( 0.01%) + 19. Token 59: 17,534 次 ( 0.01%) + 20. Token 29: 17,636 次 ( 0.01%) + + 分布统计: + 最小出现次数: 250 + 最大出现次数: 53,760,996 + 中位数出现次数: 69,089 + 标准差: 8,623,586.00 + +-------------------------------------------------------------------------------- +列 5: program +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 130 + Token值范围: [0, 132] + 平均出现次数: 1270504.02 + + Top 20 最常见的token: + 1. Token 4: 59,258,093 次 (35.88%) + 2. Token 132: 15,564,529 次 ( 9.42%) + 3. Token 44: 6,045,327 次 ( 3.66%) + 4. Token 56: 5,504,350 次 ( 3.33%) + 5. Token 75: 5,461,270 次 ( 3.31%) + 6. Token 77: 5,424,393 次 ( 3.28%) + 7. Token 60: 5,225,132 次 ( 3.16%) + 8. Token 52: 4,570,089 次 ( 2.77%) + 9. Token 61: 4,263,879 次 ( 2.58%) + 10. Token 28: 3,301,108 次 ( 2.00%) + 11. Token 69: 3,287,914 次 ( 1.99%) + 12. Token 46: 3,275,859 次 ( 1.98%) + 13. Token 64: 3,238,555 次 ( 1.96%) + 14. Token 62: 3,233,651 次 ( 1.96%) + 15. Token 70: 2,257,240 次 ( 1.37%) + 16. Token 45: 2,209,485 次 ( 1.34%) + 17. Token 72: 2,064,103 次 ( 1.25%) + 18. Token 16: 1,907,690 次 ( 1.16%) + 19. Token 29: 1,745,633 次 ( 1.06%) + 20. Token 74: 1,606,358 次 ( 0.97%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 129: 379 次 ( 0.00%) + 2. Token 124: 650 次 ( 0.00%) + 3. Token 127: 703 次 ( 0.00%) + 4. Token 128: 1,004 次 ( 0.00%) + 5. Token 130: 1,374 次 ( 0.00%) + 6. Token 123: 2,449 次 ( 0.00%) + 7. Token 117: 2,620 次 ( 0.00%) + 8. Token 125: 4,107 次 ( 0.00%) + 9. Token 90: 4,509 次 ( 0.00%) + 10. Token 131: 4,577 次 ( 0.00%) + 11. Token 105: 5,155 次 ( 0.00%) + 12. Token 126: 5,253 次 ( 0.00%) + 13. Token 115: 6,629 次 ( 0.00%) + 14. Token 80: 6,675 次 ( 0.00%) + 15. Token 101: 7,964 次 ( 0.00%) + 16. Token 119: 8,118 次 ( 0.00%) + 17. Token 87: 9,526 次 ( 0.01%) + 18. Token 122: 11,770 次 ( 0.01%) + 19. Token 99: 11,962 次 ( 0.01%) + 20. Token 81: 12,382 次 ( 0.01%) + + 分布统计: + 最小出现次数: 379 + 最大出现次数: 59,258,093 + 中位数出现次数: 101,767 + 标准差: 5,429,735.91 + +-------------------------------------------------------------------------------- +列 6: tempo +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 33 + Token值范围: [0, 35] + 平均出现次数: 5005015.85 + + Top 20 最常见的token: + 1. Token 16: 45,114,808 次 (27.31%) + 2. Token 13: 8,972,232 次 ( 5.43%) + 3. Token 10: 8,859,942 次 ( 5.36%) + 4. Token 11: 8,326,936 次 ( 5.04%) + 5. Token 17: 7,387,116 次 ( 4.47%) + 6. Token 19: 6,942,952 次 ( 4.20%) + 7. Token 14: 6,860,103 次 ( 4.15%) + 8. Token 22: 6,648,928 次 ( 4.03%) + 9. Token 9: 5,790,437 次 ( 3.51%) + 10. Token 12: 5,732,199 次 ( 3.47%) + 11. Token 20: 5,542,314 次 ( 3.36%) + 12. 
Token 18: 5,464,663 次 ( 3.31%) + 13. Token 15: 5,131,246 次 ( 3.11%) + 14. Token 8: 4,962,026 次 ( 3.00%) + 15. Token 25: 4,957,658 次 ( 3.00%) + 16. Token 7: 4,177,507 次 ( 2.53%) + 17. Token 23: 3,797,354 次 ( 2.30%) + 18. Token 21: 2,939,811 次 ( 1.78%) + 19. Token 28: 2,637,904 次 ( 1.60%) + 20. Token 26: 2,222,934 次 ( 1.35%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 33: 113,885 次 ( 0.07%) + 2. Token 0: 167,245 次 ( 0.10%) + 3. Token 32: 356,430 次 ( 0.22%) + 4. Token 30: 471,799 次 ( 0.29%) + 5. Token 5: 609,770 次 ( 0.37%) + 6. Token 27: 776,090 次 ( 0.47%) + 7. Token 29: 920,327 次 ( 0.56%) + 8. Token 31: 941,730 次 ( 0.57%) + 9. Token 34: 1,214,421 次 ( 0.74%) + 10. Token 4: 1,457,184 次 ( 0.88%) + 11. Token 24: 1,696,832 次 ( 1.03%) + 12. Token 35: 1,923,315 次 ( 1.16%) + 13. Token 6: 2,047,425 次 ( 1.24%) + 14. Token 26: 2,222,934 次 ( 1.35%) + 15. Token 28: 2,637,904 次 ( 1.60%) + 16. Token 21: 2,939,811 次 ( 1.78%) + 17. Token 23: 3,797,354 次 ( 2.30%) + 18. Token 7: 4,177,507 次 ( 2.53%) + 19. Token 25: 4,957,658 次 ( 3.00%) + 20. Token 8: 4,962,026 次 ( 3.00%) + + 分布统计: + 最小出现次数: 113,885 + 最大出现次数: 45,114,808 + 中位数出现次数: 3,797,354 + 标准差: 7,594,751.74 + +-------------------------------------------------------------------------------- +列 7: timesig +-------------------------------------------------------------------------------- + 总token数: 165,165,523 + 唯一token数: 10 + Token值范围: [0, 12] + 平均出现次数: 16516552.30 + + Top 20 最常见的token: + 1. Token 12: 122,569,379 次 (74.21%) + 2. Token 9: 16,198,618 次 ( 9.81%) + 3. Token 10: 9,427,532 次 ( 5.71%) + 4. Token 6: 8,281,749 次 ( 5.01%) + 5. Token 5: 3,323,190 次 ( 2.01%) + 6. Token 8: 1,649,967 次 ( 1.00%) + 7. Token 7: 1,546,501 次 ( 0.94%) + 8. Token 4: 1,254,214 次 ( 0.76%) + 9. Token 11: 747,128 次 ( 0.45%) + 10. Token 0: 167,245 次 ( 0.10%) + + Top 20 最不常见的token (出现次数>0): + 1. Token 0: 167,245 次 ( 0.10%) + 2. Token 11: 747,128 次 ( 0.45%) + 3. Token 4: 1,254,214 次 ( 0.76%) + 4. Token 7: 1,546,501 次 ( 0.94%) + 5. Token 8: 1,649,967 次 ( 1.00%) + 6. Token 5: 3,323,190 次 ( 2.01%) + 7. Token 6: 8,281,749 次 ( 5.01%) + 8. Token 10: 9,427,532 次 ( 5.71%) + 9. Token 9: 16,198,618 次 ( 9.81%) + 10. Token 12: 122,569,379 次 (74.21%) + + 分布统计: + 最小出现次数: 167,245 + 最大出现次数: 122,569,379 + 中位数出现次数: 2,486,578 + 标准差: 35,683,981.69 + +================================================================================ +跨列分析 +================================================================================ + + pitch : 165,165,523 tokens (12.50%) + position : 165,165,523 tokens (12.50%) + bar : 165,165,523 tokens (12.50%) + velocity : 165,165,523 tokens (12.50%) + duration : 165,165,523 tokens (12.50%) + program : 165,165,523 tokens (12.50%) + tempo : 165,165,523 tokens (12.50%) + timesig : 165,165,523 tokens (12.50%) + +================================================================================ +报告生成完成 +================================================================================ \ No newline at end of file
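
Note on the figures above: the per-column numbers in the report (total tokens, unique tokens, top/bottom 20, median, standard deviation) amount to one counting pass over the tokenized corpus, with each of the 8 Octuple fields treated as its own vocabulary. The sketch below is illustrative only and is not the script that produced octuple_token_analysis_report_part.txt; the (N, 8) per-piece array format, the column order, and the column_stats helper are assumptions.

# Illustrative sketch (assumed input format): per-column token statistics
# for Octuple-tokenized pieces, each piece given as an (N, 8) integer array.
from collections import Counter

import numpy as np

COLUMNS = ["pitch", "position", "bar", "velocity",
           "duration", "program", "tempo", "timesig"]

def column_stats(pieces, col):
    """Frequency statistics for one Octuple column across all pieces."""
    counts = Counter()
    for arr in pieces:
        counts.update(arr[:, col].tolist())

    freqs = np.array(sorted(counts.values()))
    ranked = counts.most_common()            # (token, count), most frequent first
    return {
        "total_tokens": int(freqs.sum()),
        "unique_tokens": len(counts),
        "value_range": (min(counts), max(counts)),
        "mean_count": float(freqs.mean()),
        "top_20": ranked[:20],
        "bottom_20": ranked[-20:][::-1],     # least frequent first; all counts > 0
        "median_count": float(np.median(freqs)),
        "std_count": float(freqs.std()),
    }

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy stand-in for the corpus: two random "pieces" of Octuple rows.
    pieces = [rng.integers(0, 40, size=(1000, 8)) for _ in range(2)]
    for i, name in enumerate(COLUMNS):
        s = column_stats(pieces, i)
        print(f"column {i} ({name}): {s['total_tokens']} tokens, "
              f"{s['unique_tokens']} unique, most frequent {s['top_20'][0]}")

The 12.50% rows in the cross-column section then follow directly: every Octuple row contributes exactly one token to each of the 8 columns, so each column's total is 165,165,523 = 1,321,324,184 / 8 of the grand total.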