1127 update to latest
@@ -1,8 +1,18 @@
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import functional as F

from torch import Tensor
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad

class MLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout):
@@ -49,6 +59,64 @@ class extendedMLP(nn.Module):
            x = layer(x)
        return x

class extendedMLP(nn.Module):
    def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
        super().__init__()
        self.input_size = in_size

        self.layers = nn.ModuleList()
        if num_layers == 1:
            # Only one layer
            self.layers.append(nn.Linear(in_size, out_size))
            return
        elif num_layers > 1:
            # First layer
            self.layers.append(nn.Linear(in_size, hidden_size))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
            # Intermediate layers
            if num_layers > 2:
                for _ in range(num_layers - 2):  # -2 because we're manually adding the first and last layers
                    self.layers.append(nn.Linear(hidden_size, hidden_size))
                    self.layers.append(nn.Dropout(dropout))
                    self.layers.append(nn.ReLU())
            # Last layer
            self.layers.append(nn.Linear(hidden_size, out_size))
        else:
            raise ValueError("num_layers should be a positive integer")

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class SwiGLUFFN(nn.Module):
    def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
        super().__init__()
        self.input_size = in_size

        if num_layers == 1:
            # Single-layer case: a plain linear projection
            self.ffn = nn.Linear(in_size, out_size)
        elif num_layers == 2:
            # Two layers: use SwiGLU
            self.w1 = nn.Linear(in_size, 2 * hidden_size)  # first half is the main branch, second half is the gating branch
            self.w2 = nn.Linear(hidden_size, out_size)
            self.dropout = nn.Dropout(dropout)
        else:
            raise ValueError("SwiGLU FFN only supports num_layers=1 or 2")

    def forward(self, x):
        if hasattr(self, "ffn"):
            return self.ffn(x)
        else:
            x_proj = self.w1(x)
            x_main, x_gate = x_proj.chunk(2, dim=-1)  # split into two halves
            x = F.silu(x_main) * x_gate  # SwiGLU: silu(a) * b
            x = self.dropout(x)
            x = self.w2(x)
            return x
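# A minimal sanity check for the SwiGLU block above: the two-layer path applies
# silu(main) * gate before the down-projection. The sizes here are illustrative only,
# and this helper is an editorial sketch rather than part of the original module.
def _swiglu_sanity_check():
    ffn = SwiGLUFFN(in_size=768, out_size=768, num_layers=2, hidden_size=2048, dropout=0.0)
    x = torch.randn(2, 16, 768)  # (batch, seq_len, in_size)
    y = ffn(x)
    assert y.shape == (2, 16, 768)
    # with dropout=0.0 the output is exactly silu(a) * b followed by w2
    a, b = ffn.w1(x).chunk(2, dim=-1)
    assert torch.allclose(y, ffn.w2(F.silu(a) * b), atol=1e-6)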

class multiMLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
        super().__init__()
@@ -209,6 +277,28 @@ class TransformerLayer(nn.Module):
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict

class TransformerLayerV2(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.self_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.cross_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4*dim, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # self attention
        attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')

        input_dict['input_seq'] = attn_output
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict

class FeatureEnricher(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
@@ -226,3 +316,482 @@ class FeatureEnricher(nn.Module):
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
        return output_dict
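# ResidualLayerNormModule is referenced above but not shown in this diff. A minimal
# pre-norm residual wrapper that is consistent with how it is called (forward_attention
# for the MultiheadAttention blocks, forward_mlp for the SwiGLU block) could look like
# the sketch below; this is an assumption for illustration, not the repository's code.
class _ResidualLayerNormSketch(nn.Module):
    def __init__(self, module, dim=None):
        super().__init__()
        self.module = module
        # infer the feature dimension from the wrapped module when not given
        dim = dim if dim is not None else getattr(module, 'embed_dim', getattr(module, 'input_size', None))
        self.norm = nn.LayerNorm(dim)

    def forward_attention(self, q, k, v, attn_mask=None, type='self'):
        # 'type' mirrors the call sites above but is unused in this sketch
        out, _ = self.module(self.norm(q), k, v, attn_mask=attn_mask, need_weights=False)
        return q + out

    def forward_mlp(self, x):
        return x + self.module(self.norm(x))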
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
r"""Allows the model to jointly attend to information from different representation subspaces.
|
||||
|
||||
.. note::
|
||||
See `this tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
|
||||
for an in depth discussion of the performant building blocks PyTorch offers for building your own
|
||||
transformer layers.
|
||||
|
||||
Method described in the paper:
|
||||
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Multi-Head Attention is defined as:
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O
|
||||
|
||||
where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
||||
|
||||
``nn.MultiheadAttention`` will use the optimized implementations of
|
||||
``scaled_dot_product_attention()`` when possible.
|
||||
|
||||
In addition to support for the new ``scaled_dot_product_attention()``
|
||||
function, for speeding up Inference, MHA will use
|
||||
fastpath inference with support for Nested Tensors, iff:
|
||||
|
||||
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
|
||||
- inputs are batched (3D) with ``batch_first==True``
|
||||
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
||||
- training is disabled (using ``.eval()``)
|
||||
- ``add_bias_kv`` is ``False``
|
||||
- ``add_zero_attn`` is ``False``
|
||||
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
||||
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
||||
nor ``attn_mask`` is passed
|
||||
- autocast is disabled
|
||||
|
||||
If the optimized inference fastpath implementation is in use, a
|
||||
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
||||
``query``/``key``/``value`` to represent padding more efficiently than using a
|
||||
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
||||
will be returned, and an additional speedup proportional to the fraction of the input
|
||||
that is padding can be expected.
|
||||
|
||||
Args:
|
||||
embed_dim: Total dimension of the model.
|
||||
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
||||
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
||||
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
||||
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
||||
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
||||
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
||||
Default: ``False``.
|
||||
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
||||
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
||||
https://arxiv.org/abs/2205.14135
|
||||
|
||||
"""
|
||||
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
if embed_dim <= 0 or num_heads <= 0:
|
||||
raise ValueError(
|
||||
f"embed_dim and num_heads must be greater than 0,"
|
||||
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
|
||||
)
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
if self._qkv_same_embed_dim:
|
||||
xavier_uniform_(self.in_proj_weight)
|
||||
else:
|
||||
xavier_uniform_(self.q_proj_weight)
|
||||
xavier_uniform_(self.k_proj_weight)
|
||||
xavier_uniform_(self.v_proj_weight)
|
||||
|
||||
if self.in_proj_bias is not None:
|
||||
constant_(self.in_proj_bias, 0.0)
|
||||
constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
xavier_normal_(self.bias_v)
|
||||
|
||||
def __setstate__(self, state):
|
||||
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||
if "_qkv_same_embed_dim" not in state:
|
||||
state["_qkv_same_embed_dim"] = True
|
||||
|
||||
super().__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
) -> tuple[Tensor, Optional[Tensor]]:
|
||||
r"""Compute attention outputs using query, key, and value embeddings.
|
||||
|
||||
Supports optional parameters for padding, masks and attention weights.
|
||||
|
||||
Args:
|
||||
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
||||
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
||||
Queries are compared against key-value pairs to produce the output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
||||
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
||||
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
||||
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
||||
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
||||
Binary and float masks are supported.
|
||||
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
||||
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
||||
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
||||
Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
|
||||
and achieve the best performance for MHA.
|
||||
Default: ``True``.
|
||||
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
||||
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
||||
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
||||
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
||||
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
||||
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
||||
the attention weight.
|
||||
If both attn_mask and key_padding_mask are supplied, their types should match.
|
||||
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
||||
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
||||
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
||||
is_causal: If specified, applies a causal mask as attention mask.
|
||||
Default: ``False``.
|
||||
Warning:
|
||||
``is_causal`` provides a hint that ``attn_mask`` is the
|
||||
causal mask. Providing incorrect hints can result in
|
||||
incorrect execution, including forward and backward
|
||||
compatibility.
|
||||
|
||||
Outputs:
|
||||
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
||||
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
||||
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
||||
embedding dimension ``embed_dim``.
|
||||
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
||||
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
||||
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
||||
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
||||
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
||||
|
||||
.. note::
|
||||
`batch_first` argument is ignored for unbatched inputs.
|
||||
""" # noqa: B950
|
||||
why_not_fast_path = ""
|
||||
if (
|
||||
(attn_mask is not None and torch.is_floating_point(attn_mask))
|
||||
or (key_padding_mask is not None)
|
||||
and torch.is_floating_point(key_padding_mask)
|
||||
):
|
||||
why_not_fast_path = "floating-point masks are not supported for fast path."
|
||||
|
||||
is_batched = query.dim() == 3
|
||||
|
||||
key_padding_mask = F._canonical_mask(
|
||||
mask=key_padding_mask,
|
||||
mask_name="key_padding_mask",
|
||||
other_type=F._none_or_dtype(attn_mask),
|
||||
other_name="attn_mask",
|
||||
target_type=query.dtype,
|
||||
)
|
||||
|
||||
attn_mask = F._canonical_mask(
|
||||
mask=attn_mask,
|
||||
mask_name="attn_mask",
|
||||
other_type=None,
|
||||
other_name="",
|
||||
target_type=query.dtype,
|
||||
check_other=False,
|
||||
)
|
||||
|
||||
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
||||
|
||||
if not is_fastpath_enabled:
|
||||
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
||||
elif not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif self.in_proj_weight is None:
|
||||
why_not_fast_path = "in_proj_weight was None"
|
||||
elif query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif (self.num_heads % 2) != 0:
|
||||
why_not_fast_path = "self.num_heads is not even"
|
||||
elif not self.batch_first:
|
||||
why_not_fast_path = "batch_first was not True"
|
||||
elif self.bias_k is not None:
|
||||
why_not_fast_path = "self.bias_k was not None"
|
||||
elif self.bias_v is not None:
|
||||
why_not_fast_path = "self.bias_v was not None"
|
||||
elif self.add_zero_attn:
|
||||
why_not_fast_path = "add_zero_attn was enabled"
|
||||
elif not self._qkv_same_embed_dim:
|
||||
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
||||
elif query.is_nested and (
|
||||
key_padding_mask is not None or attn_mask is not None
|
||||
):
|
||||
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
|
||||
is not supported with NestedTensor input"
|
||||
elif torch.is_autocast_enabled():
|
||||
why_not_fast_path = "autocast is enabled"
|
||||
|
||||
if not why_not_fast_path:
|
||||
tensor_args = (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
)
|
||||
# We have to use list comprehensions below because TorchScript does not support
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif _is_make_fx_tracing():
|
||||
why_not_fast_path = "we are running make_fx tracing"
|
||||
elif not all(_check_arg_device(x) for x in tensor_args):
|
||||
why_not_fast_path = (
|
||||
"some Tensor argument's device is neither one of "
|
||||
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any(
|
||||
_arg_requires_grad(x) for x in tensor_args
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
if not why_not_fast_path:
|
||||
merged_mask, mask_type = self.merge_masks(
|
||||
attn_mask, key_padding_mask, query
|
||||
)
|
||||
|
||||
if self.in_proj_bias is not None and self.in_proj_weight is not None:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
merged_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
mask_type,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
assert not any_nested, (
|
||||
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
||||
+ f"The fast path was not hit because {why_not_fast_path}"
|
||||
)
|
||||
|
||||
if self.batch_first and is_batched:
|
||||
# make sure that the transpose op does not affect the "is" property
|
||||
if key is value:
|
||||
if query is key:
|
||||
query = key = value = query.transpose(1, 0)
|
||||
else:
|
||||
query, key = (x.transpose(1, 0) for x in (query, key))
|
||||
value = key
|
||||
else:
|
||||
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
else:
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
def merge_masks(
|
||||
self,
|
||||
attn_mask: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor],
|
||||
query: Tensor,
|
||||
) -> tuple[Optional[Tensor], Optional[int]]:
|
||||
r"""Determine mask type and combine masks if necessary.
|
||||
|
||||
If only one mask is provided, that mask
|
||||
and the corresponding mask type will be returned. If both masks are provided, they will be both
|
||||
expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
|
||||
and mask type 2 will be returned
|
||||
Args:
|
||||
attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
|
||||
key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
|
||||
query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
|
||||
Returns:
|
||||
merged_mask: merged mask
|
||||
mask_type: merged mask type (0, 1, or 2)
|
||||
"""
|
||||
mask_type: Optional[int] = None
|
||||
merged_mask: Optional[Tensor] = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
mask_type = 1
|
||||
merged_mask = key_padding_mask
|
||||
|
||||
if attn_mask is not None:
|
||||
# In this branch query can't be a nested tensor, so it has a shape
|
||||
batch_size, seq_len, _ = query.shape
|
||||
mask_type = 2
|
||||
|
||||
# Always expands attn_mask to 4D
|
||||
if attn_mask.dim() == 3:
|
||||
attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
|
||||
else: # attn_mask.dim() == 2:
|
||||
attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
|
||||
batch_size, self.num_heads, -1, -1
|
||||
)
|
||||
merged_mask = attn_mask_expanded
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask_expanded = key_padding_mask.view(
|
||||
batch_size, 1, 1, seq_len
|
||||
).expand(-1, self.num_heads, -1, -1)
|
||||
merged_mask = attn_mask_expanded + key_padding_mask_expanded
|
||||
|
||||
# no attn_mask and no key_padding_mask, returns None, None
|
||||
return merged_mask, mask_type
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass):
        causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
        self.register_buffer('causal_mask', causal_mask)

        # self.transformer_decoder = Decoder(
        #     dim = dim,
        #     depth = sub_decoder_depth,
        #     heads = heads,
        #     attn_dropout = dropout,
        #     ff_dropout = dropout,
        #     attn_flash = True)
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = sub_decoder_depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True)
            attn_flash = True,
            use_rmsnorm=True,
            ff_swish = True, # set this to True
            ff_glu = True, # set to true to use for all feedforwards
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
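# _apply_xavier_init is called above but its body is outside this hunk. For reference,
# torch.nn.Transformer applies xavier_uniform_ to every parameter with more than one
# dimension; a helper in that spirit (an assumption, not the repository's code) is:
def _xavier_init_sketch(module: nn.Module):
    for p in module.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)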
@@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass):
                 dropout:float,
                 sub_decoder_enricher_use:bool,
                 MASK_IDX:int = 126336,
                 denoising_steps:int = 6,
                 denoising_steps:int = 8,
                 eps:float = 1e-3,
                 method:str = 'low-confidence', # or random or auto-regressive
                 ):
@@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass):
            logits_dict[feature] = logit
        return logits_dict, (masked_indices, p_mask)

class DiffusionDecoder(SubDecoderClass):
class DiffusionDecoderV2(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
@@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass):
        dropout:float,
        sub_decoder_enricher_use:bool,
        MASK_IDX:int = 126336,
        denoising_steps:int = 6,
        denoising_steps:int = 8,
        eps:float = 1e-3,
        method:str = 'low-confidence', # or random or auto-regressive
        ):
@@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass):

        self.input_norm = nn.LayerNorm(dim)

        self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
        self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))

        if sub_decoder_enricher_use:
            self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
@@ -1138,15 +1149,22 @@ class DiffusionDecoder(SubDecoderClass):
        self.register_buffer('causal_mask', causal_mask)
        self.register_buffer('causal_ca_mask', causal_ca_mask)

        # get depth of the sub-decoder
        if sub_decoder_depth > 1:
            self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
            self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
        else:
            self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
            self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
        if sub_decoder_enricher_use:
            self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))


        self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
                                            vocab=vocab,
                                            sub_decoder_depth=1,
                                            dim=dim,
                                            heads=heads,
                                            dropout=dropout,
                                            sub_decoder_enricher_use=False)

    # simplified version of the forward process in diffusion model
    def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
        reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
@@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass):
        # print("sampled_token_dict", sampled_token_dict)
        return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec'] # B x T x d_model
        copy_input_dict = input_dict.copy()

        target = input_dict['target'] # B x T x d_model
        bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder

@@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass):
        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
        # ---- Generate (Inference) ---- #
        if target is None:
            if aux_ar: # inference with auxiliary auto-regressive decoder
                aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
                # print("aux_ar_logits", aux_ar_logits)
                return aux_ar_logits, sampled_token_dict
            sampled_token_dict = {}
            b, t, d = hidden_vec.shape # B x T x d_model
            l = len(self.prediction_order) # num_sub_tokens
@@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass):
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict, (masked_indices, p_mask)
        # get aux ar decoder logits
        aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
        return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
        # return logits_dict, (masked_indices, p_mask)
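# _forward_process above is only shown up to its first line in this hunk. In LLaDA-style
# masked diffusion the usual recipe is: sample a masking level t per row, set
# p_mask = (1 - eps) * t + eps, and replace each token with MASK_IDX with probability
# p_mask. The sketch below follows that recipe for the reshaped (B*T, num_sub_tokens)
# ids; it illustrates the idea and is not the repository's exact implementation.
def _forward_process_sketch(input_ids, mask_idx, eps=1e-3):
    flat = input_ids.reshape(-1, input_ids.shape[-1])            # B*T x num_sub_tokens
    bt, n = flat.shape
    t = torch.rand(bt, device=flat.device)                       # masking level per row
    p_mask = ((1 - eps) * t + eps).unsqueeze(-1).expand(bt, n)   # B*T x num_sub_tokens
    masked_indices = torch.rand(bt, n, device=flat.device) < p_mask
    noisy = torch.where(masked_indices, torch.full_like(flat, mask_idx), flat)
    return noisy, masked_indices, p_mask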
@@ -3,6 +3,7 @@ import random
from pathlib import Path
from collections import OrderedDict
from typing import Union, List, Tuple, Dict
import torch

import numpy as np
import matplotlib.pyplot as plt
@@ -157,6 +158,8 @@ class IterTuneCompiler(IterableDataset):
        segment = self.augmentor(segment)
        # use input_ids to replace tune_name
        tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
        # print(segment.shape, mask.shape, tune_name.shape)
        # segment = segment[torch.randperm(segment.size(0))]
        yield segment, mask, tune_name, encoded_caption

    def __len__(self):

@@ -1,7 +1,9 @@
defaults:
  # - nn_params: nb8_embSum_NMT
  # - nn_params: remi8
  - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
  # - nn_params: oct8_embSum_diff_t2m_300M_pretrainingv3
  # - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
  - nn_params: oct8_embSum_har_t2m_600M_pretrainingv3
  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
  # - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
  # - nn_params: nb8_embSum_subPararell
@@ -15,7 +17,7 @@ defaults:
  # - nn_params: remi8_main12_head_16_dim512
  # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

dataset: Melody # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json

# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -23,28 +25,28 @@ captions_path: dataset/midicaps/train_set.json

use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True: use diffusion in the sub-decoder
use_diff: False # True: use diffusion in the sub-decoder
diff_steps: 8 # number of diffusion steps
use_dispLoss: True
use_dispLoss: False
lambda_weight: 0.5
tau: 0.5

train_params:
  device: cuda
  batch_size: 10
  batch_size: 9
  grad_clip: 1.0
  num_iter: 300000 # total number of iterations
  num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
  num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
  iterations_per_training_cycle: 10 # number of iterations for logging training loss
  iterations_per_validation_cycle: 3000 # number of iterations for validation process
  input_length: 3072 # input sequence length
  input_length: 2048 # input sequence length
  # you can use focal loss; if it's not used, set focal_gamma to 0
  focal_alpha: 1
  focal_gamma: 0
  # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
  scheduler : cosinelr
  initial_lr: 0.00001
  initial_lr: 0.0003
  decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
  num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
  warmup_steps: 2000 # number of warmup steps

@@ -5,7 +5,7 @@ model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8

@@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 16
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
@@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1080
  num_layer: 13
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
@@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1080
  num_layer: 13
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
@@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: SelfAttention
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1080
  num_layer: 13
  num_head: 12
sub_decoder:
  decout_window_size: 3 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
@@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: SelfAttention
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1272
  num_layer: 20
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
Amadeus/toy_train.py (new file, 454 lines)
@@ -0,0 +1,454 @@
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False): # Force float32
            x = x.float()
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            x = self.weight * x
        return x.to(dtype)


class FeedForward(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 intermediate_size: Optional[int] = None,
                 dropout: float = 0.,
                 hidden_act: Literal['silu', 'gelu'] = 'silu'):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = get_activation(hidden_act)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        gate, hidden = self.up_proj(x).chunk(2, dim=-1)
        gate = self.act_fn(gate)

        return self.down_proj(self.dropout(gate * hidden))

class MultiHeadAttention(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 bias: bool = False,
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()

        self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
        self.head_dim = head_dim
        self.scaling = head_dim ** -0.5
        self.attention_dropout = attention_dropout

        self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
                                bias=bias)
        self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
                                bias=bias)
        self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
                                bias=bias)
        self.o_proj = nn.Linear(
            self.num_attention_heads * self.head_dim, hidden_size,
            bias=bias)

        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

        self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

        # To be compatible with huggingface
        self.config = Qwen2Config()
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        key_states: Optional[Tensor] = None,
        value_states: Optional[Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]

        q = self.q_proj(hidden_states)

        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)

        if (key_states is None or value_states is None):
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

            key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
            value_states = v.view(hidden_shape).transpose(1, 2)

        attn_output, attn_weights = self.attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights


class DecoderLayerWithCrossAttention(nn.Module):
    def __init__(self,
                 hidden_size: int,

                 attn_head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 attn_bias: bool = False,

                 mlp_intermediate_size: Optional[int] = None,
                 mlp_dropout: float = 0.,
                 mlp_act: Literal['gelu', 'silu'] = 'silu',

                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()
        self.hidden_size = hidden_size

        self.self_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)

        self.cross_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)

        self.mlp = FeedForward(
            hidden_size,
            mlp_intermediate_size,
            mlp_dropout,
            mlp_act)

        self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_dict,
    ):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        hidden_states = input_dict['input_seq']
        cross_attn_states = input_dict['memory']
        cross_attn_mask = input_dict['memory_mask']
        attention_mask = input_dict['memory_mask']
        output_attentions = False
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn_layernorm(hidden_states)
        hidden_states, dec_attn_weight = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = residual + hidden_states

        # Cross Attention
        if (cross_attn_states is not None):
            residual = hidden_states
            hidden_states = self.cross_attn_layernorm(hidden_states)
            cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
            key_states = self.cross_attn.k_proj(cross_attn_states)
            value_states = self.cross_attn.v_proj(cross_attn_states)

            key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
            value_states = value_states.view(cross_attn_shape).transpose(1, 2)
            hidden_states, cross_attn_weight = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=cross_attn_mask,
                key_states=key_states,
                value_states=value_states
            )
            hidden_states = residual + hidden_states

        # Feed Forward
        residual = hidden_states
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
        if (output_attentions):
            if (cross_attn_states is not None):
                # return (hidden_states, (dec_attn_weight, cross_attn_weight))
                output_dict['dec_attn_weight'] = dec_attn_weight
                output_dict['cross_attn_weight'] = cross_attn_weight
                return output_dict
            else:
                # return (hidden_states, dec_attn_weight)
                output_dict['dec_attn_weight'] = dec_attn_weight
                return output_dict
        else:
            return output_dict

def sinusoidal_positional_embedding(
        x: Tensor,
        n: float = 10000.0) -> Tensor:
    device, dtype = x.device, x.dtype
    T = x.shape[-2]
    D = x.shape[-1]

    positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
    embeddings = torch.zeros(T, D, device=device, dtype=dtype)

    denominators = torch.pow(n, 2*torch.arange(0, D//2, device=device, dtype=dtype)/D) # 10000^(2i/d_model), i is the index of embedding
    embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model))

    return x + embeddings

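# Quick shape/behaviour check for sinusoidal_positional_embedding above. The sizes are
# illustrative; note the function adds the encoding to its input rather than returning
# the table itself.
def _positional_embedding_check():
    x = torch.zeros(4, 32, 256)               # (batch, seq_len, dim); zeros isolate the encoding
    out = sinusoidal_positional_embedding(x)
    assert out.shape == x.shape
    # position 0 encodes as sin(0)=0 on even channels and cos(0)=1 on odd channels
    assert torch.allclose(out[0, 0, 0::2], torch.zeros(128), atol=1e-6)
    assert torch.allclose(out[0, 0, 1::2], torch.ones(128), atol=1e-6)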
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

# A simple dataset to test convergence
class SimpleSeq2SeqDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Simulated encoder output (memory)
        memory = torch.randn(self.seq_len, self.hidden_size)

        # Simulated decoder input and target
        decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
        target = torch.randint(0, self.vocab_size, (self.seq_len,))

        # A simple attention mask
        memory_mask = torch.ones(self.seq_len, self.seq_len)

        return {
            'memory': memory,
            'decoder_input': decoder_input,
            'target': target,
            'memory_mask': memory_mask
        }

# A simple model
class SimpleDecoderModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=3):
        super(SimpleDecoderModel, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # Positional encoding
        self.pos_encoding = sinusoidal_positional_embedding

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayerWithCrossAttention(
                hidden_size=hidden_size,
                attn_head_dim=64,
                num_attention_heads=2,
                attention_dropout=0.1,
                attn_bias=False,
                mlp_intermediate_size=512,
                mlp_dropout=0.1,
                mlp_act='silu',
                rms_norm_eps=1e-6,
                attn_implementation="sdpa"
            ) for _ in range(num_layers)
        ])

        # Output layers
        self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, decoder_input, memory, memory_mask):
        # Embedding + positional encoding
        x = self.embedding(decoder_input)
        x = self.pos_encoding(x)

        # Pass through the stack of decoder layers
        input_dict = {
            'input_seq': x,
            'memory': memory,
            'memory_mask': memory_mask
        }

        for layer in self.layers:
            output_dict = layer(input_dict)
            input_dict['input_seq'] = output_dict['input_seq']

        # Final output
        hidden_states = self.output_norm(output_dict['input_seq'])
        logits = self.output_proj(hidden_states)

        return logits

# Training function
def train_decoder_model():
    # Hyperparameters
    vocab_size = 100
    hidden_size = 256
    seq_len = 10
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Build the dataset and dataloader
    dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
                                   vocab_size=vocab_size, hidden_size=hidden_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize the model
    model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
    model.to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training record
    train_losses = []

    print("Starting training...")
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            # Move to device
            memory = batch['memory'].to(device)
            decoder_input = batch['decoder_input'].long().to(device)
            target = batch['target'].long().to(device)
            memory_mask = batch['memory_mask'].to(device)

            # Forward pass
            optimizer.zero_grad()
            logits = model(decoder_input, memory, memory_mask)

            # Compute loss - reshape logits and target to (batch_size * seq_len, vocab_size) and (batch_size * seq_len)
            loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Early-stopping check: if the loss drops clearly, the model is converging
        if epoch >= 2 and avg_loss < 4.0:  # a reasonable threshold for cross-entropy loss
            print("✅ The model is converging normally!")
            return True

    # Check the final loss
    final_loss = train_losses[-1]
    if final_loss < 5.0:
        print("✅ Training finished; the model converged normally!")
        return True
    else:
        print("❌ The loss did not drop noticeably; consider adjusting the model or hyperparameters")
        return False

# Run the training
if __name__ == "__main__":
    # First test the forward pass of a single decoder layer
    print("Testing the DecoderLayerWithCrossAttention forward pass...")

    # Create test data
    batch_size, seq_len, hidden_size = 2, 64, 256
    decoder_input = torch.randn(seq_len, hidden_size)
    memory = torch.randn(seq_len, hidden_size)
    memory_mask = torch.ones(seq_len, seq_len)

    # Test a single layer
    decoder_layer = DecoderLayerWithCrossAttention(
        hidden_size=hidden_size,
        attn_head_dim=64,
        num_attention_heads=4
    )

    input_dict = {
        'input_seq': decoder_input,
        'memory': memory,
        'memory_mask': memory_mask
    }

    try:
        output = decoder_layer(input_dict)
        print("✅ DecoderLayer forward pass test passed")
        print(f"Input shape: {decoder_input.shape}")
        print(f"Output shape: {output['input_seq'].shape}")

        # Run the full training
        success = train_decoder_model()
        if success:
            print("\n🎉 Test complete! DecoderLayerWithCrossAttention trains and converges normally")
        else:
            print("\n⚠️ There may be an issue during training; check the model structure or hyperparameters")

    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")
@ -229,18 +229,38 @@ class DiffusionLoss4CompoundToken():
|
||||
|
||||
return loss
|
||||
|
||||
def get_aux_ar_nll_loss(self, logits, target, mask):
|
||||
probs = logits.softmax(dim=-1)
|
||||
if probs.ndim == 3:
|
||||
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
|
||||
if target.ndim == 2:
|
||||
target = target.flatten(0, 1) # [batch_size*seq_len]
|
||||
# clamp min value to 1e-7 to avoid log(0)
|
||||
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
|
||||
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
|
||||
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
|
||||
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
|
||||
return loss
|
||||
|
||||
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
|
||||
train_loss_list = []
|
||||
log_loss_dict_normal = {}
|
||||
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
disp_loss = None
|
||||
aux_ar_logits = None
|
||||
# print(len(logits_dict))
|
||||
if len(logits_dict) == 2: # has aux ar loss
|
||||
logits_dict, aux_ar_logits = logits_dict
|
||||
if input_dict is not None:
|
||||
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
|
||||
feat = hidden_vec.mean(dim=1) #bs,dim
|
||||
disp_loss = dispersive_loss(feat, tau=tau) # scalar
|
||||
for idx, key in enumerate(self.feature_list):
|
||||
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
|
||||
if aux_ar_logits is not None:
|
||||
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
|
||||
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
|
||||
train_loss_list.append(training_loss)
|
||||
if valid:
|
||||
if key == 'type' or key == 'timesig':
|
||||
|
||||
@ -74,8 +74,8 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
# model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
self.model = model
|
||||
@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
|
||||
segment, mask, caption, encoded_caption = batch
|
||||
input_seq, target = segment[:, :-1], segment[:, 1:]
|
||||
total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
|
||||
try:
|
||||
aux_ar_logits, logits_dict = logits_dict
|
||||
except (TypeError, ValueError):
|
||||
pass  # logits_dict 中没有 aux AR logits,保持不变
|
||||
probs_dict = {key:torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
|
||||
num_nonmask_tokens = torch.sum(mask)
|
||||
input_seq = input_seq.to(self.device)
|
||||
|
||||
@ -729,11 +729,6 @@ class XtransformerNewPretrainingDecoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self._make_decoder_layer(dim, depth, heads, dropout)
|
||||
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
|
||||
# frozen text encoder
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _make_decoder_layer(self, dim, depth, heads, dropout):
|
||||
self.transformer_decoder = Decoder(
|
||||
dim = dim,
|
||||
|
||||
@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
from itertools import groupby, chain
|
||||
from random import shuffle, seed
|
||||
|
||||
|
||||
lock = RLock()
|
||||
|
||||
def shuffled(seq):
|
||||
shuffle(seq)
|
||||
return seq
|
||||
|
||||
def permute_inside_and_across_tracks(seq):
|
||||
seq_sorted = sorted(seq, key=lambda x: x[5])
|
||||
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
|
||||
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
|
||||
|
||||
|
||||
def convert_event_dicts(dict_list):
|
||||
"""
|
||||
将 event 词表列表按顺序转换为结构化输出
|
||||
@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
|
||||
|
||||
return result
|
||||
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
|
||||
try:
|
||||
score = Score(midi_path, ttype="tick")
|
||||
|
||||
@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
# 分词
|
||||
tok_seq = tokenizer(score)
|
||||
token_ids = tok_seq.ids
|
||||
if whether_shuffle:
|
||||
token_ids = permute_inside_and_across_tracks(token_ids)
|
||||
# add sos token at the beginning
|
||||
vocab = tokenizer.vocab
|
||||
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
|
||||
@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
print(f" × 处理文件时出错:{midi_path} -> {e}")
|
||||
|
||||
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
|
||||
# === 1. 初始化分词器并保存词表 ===
|
||||
print("初始化分词器 Octuple...")
|
||||
config = TokenizerConfig(
|
||||
@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# === 3. 收集 MIDI 文件 ===
|
||||
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
midi_paths = list(midi_paths)
|
||||
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
# midi_paths = list(midi_paths)
|
||||
midi_paths = []
|
||||
for root, _, files in os.walk(midi_dir):
|
||||
for file in files:
|
||||
if file.lower().endswith(('.mid', '.midi')):
|
||||
midi_paths.append(os.path.join(root, file))
|
||||
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
|
||||
|
||||
# === 4. 并行处理 ===
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
res = future.result()
|
||||
@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
|
||||
|
||||
parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
|
||||
parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录")
|
||||
parser.add_argument("--shuffle", action="store_true", help="是否在处理前打乱文件顺序")
|
||||
|
||||
dataset_name = midi_directory.split("/")[-1]
|
||||
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
|
||||
|
||||
|
||||
output_dir = tuneidx_prefix
|
||||
preprocess_midi_directory(midi_directory, output_dir)
|
||||
args = parser.parse_args()
|
||||
preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)
|
||||
39
data_representation/permute.py
Normal file
@ -0,0 +1,39 @@
|
||||
from itertools import groupby, chain
|
||||
from random import shuffle, seed
|
||||
|
||||
def shuffled(seq):
|
||||
shuffle(seq)
|
||||
return seq
|
||||
|
||||
def permute_inside_and_across_tracks(seq):
|
||||
seq_sorted = sorted(seq, key=lambda x: x[5])
|
||||
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
|
||||
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
|
||||
|
||||
# (PitchDrum, Position, Bar, Velocity, Duration, Program, Tempo, TimeSignature)
|
||||
seq = [
|
||||
# Program 0
|
||||
(60, 0, 0, 90, 96, 0, 120, 16),
|
||||
(64, 48, 0, 88, 96, 0, 120, 16),
|
||||
(67, 96, 0, 92, 96, 0, 120, 16),
|
||||
|
||||
# Program 32
|
||||
(40, 0, 0, 80, 192, 32, 120, 16),
|
||||
(43, 0, 1, 78, 192, 32, 120, 16),
|
||||
|
||||
# Program 40
|
||||
(72, 24, 0, 85, 72, 40, 120, 16),
|
||||
(74, 72, 0, 83, 72, 40, 120, 16),
|
||||
(76, 24, 1, 86, 72, 40, 120, 16),
|
||||
]
|
||||
|
||||
# seed(42)
|
||||
|
||||
inside_track_permuted_and_track_permuted = permute_inside_and_across_tracks(seq)
|
||||
|
||||
print("原始 seq:")
|
||||
for e in seq:
|
||||
print(e)
|
||||
print("\n打乱后的 seq:")
|
||||
for e in inside_track_permuted_and_track_permuted:
|
||||
print(e)
|
||||
634
data_representation/resample.py
Normal file
@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于token分布距离的重采样脚本
|
||||
读取octuple_token_analysis_report.json,计算每个数据与整体分布的距离,
|
||||
按照距离加权采样,距离越远的越容易被采样
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from scipy.stats import entropy, wasserstein_distance
|
||||
from scipy.spatial.distance import jensenshannon
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
|
||||
def load_distribution_from_json(json_path):
|
||||
"""
|
||||
从JSON文件中加载整体token分布
|
||||
|
||||
Args:
|
||||
json_path: JSON文件路径
|
||||
|
||||
Returns:
|
||||
dict: {column_name: {token: probability}}
|
||||
"""
|
||||
print(f"读取分布文件: {json_path}")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
report = json.load(f)
|
||||
|
||||
distributions = {}
|
||||
columns = report.get('columns', {})
|
||||
|
||||
for col_name in COLUMN_NAMES:
|
||||
if col_name not in columns:
|
||||
print(f"警告: 列 {col_name} 不在报告中")
|
||||
distributions[col_name] = {}
|
||||
continue
|
||||
|
||||
col_data = columns[col_name]
|
||||
token_counts = col_data.get('token_counts', {})
|
||||
total_tokens = col_data.get('total_tokens', 1)
|
||||
|
||||
# 转换为概率分布
|
||||
distribution = {}
|
||||
for token_str, count in token_counts.items():
|
||||
token = int(token_str)
|
||||
distribution[token] = count / total_tokens
|
||||
|
||||
distributions[col_name] = distribution
|
||||
print(f" 列 {col_name}: {len(distribution)} 个唯一token, 总token数: {total_tokens:,}")
|
||||
|
||||
return distributions
|
||||
|
||||
|
||||
def compute_data_distribution(data, col_idx):
|
||||
"""
|
||||
计算单个数据在指定列的token分布
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns)
|
||||
col_idx: 列索引
|
||||
|
||||
Returns:
|
||||
dict: {token: probability}
|
||||
"""
|
||||
if data.size == 0:
|
||||
return {}
|
||||
|
||||
tokens = data[:, col_idx]
|
||||
unique, counts = np.unique(tokens, return_counts=True)
|
||||
total = len(tokens)
|
||||
|
||||
distribution = {}
|
||||
for token, count in zip(unique, counts):
|
||||
distribution[int(token)] = count / total
|
||||
|
||||
return distribution
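# 示意性小例子(非原脚本内容,仅作说明):第 0 列取值为 60, 60, 62 时,
# 经验分布为 {60: 2/3, 62: 1/3}。
def _demo_compute_data_distribution():
    toy = np.array([[60, 0], [60, 48], [62, 0]])
    return compute_data_distribution(toy, 0)  # {60: 0.666..., 62: 0.333...}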
|
||||
|
||||
|
||||
def compute_emd_distance(dist1, dist2):
|
||||
"""
|
||||
使用推土机距离(Earth Mover's Distance / Wasserstein距离)计算两个分布之间的距离
|
||||
|
||||
Args:
|
||||
dist1: 分布1,dict {token: probability},已归一化
|
||||
dist2: 分布2,dict {token: probability},已归一化
|
||||
|
||||
Returns:
|
||||
float: EMD距离
|
||||
"""
|
||||
# 获取所有token的并集,并排序
|
||||
all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
|
||||
|
||||
if not all_tokens:
|
||||
return 0.0
|
||||
|
||||
# 构建概率向量和token值向量
|
||||
p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
|
||||
q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
|
||||
token_values = np.array(all_tokens, dtype=float)
|
||||
|
||||
# 归一化(处理数值误差)
|
||||
p_sum = p_weights.sum()
|
||||
q_sum = q_weights.sum()
|
||||
|
||||
if p_sum < 1e-10 or q_sum < 1e-10:
|
||||
return 0.0
|
||||
|
||||
p_weights = p_weights / p_sum
|
||||
q_weights = q_weights / q_sum
|
||||
|
||||
# 使用Wasserstein距离(1-Wasserstein距离,即推土机距离)
|
||||
# wasserstein_distance需要两个分布的样本值位置和权重
|
||||
# 对于离散分布,我们使用token值作为位置
|
||||
emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
|
||||
|
||||
return emd
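# 示意性小例子(非原脚本内容,仅作说明):dist_a 的质量集中在 token 0,dist_b 集中在 token 4,
# 需要把 0.6 的概率质量从位置 0 搬到位置 4,因此 EMD ≈ 0.6 * 4 = 2.4。
def _demo_emd_distance():
    dist_a = {0: 0.8, 4: 0.2}
    dist_b = {0: 0.2, 4: 0.8}
    return compute_emd_distance(dist_a, dist_b)  # ≈ 2.4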
|
||||
|
||||
|
||||
def compute_distribution_distance(dist1, dist2, method='emd'):
|
||||
"""
|
||||
计算两个分布之间的距离
|
||||
|
||||
Args:
|
||||
dist1: 分布1,dict {token: probability}
|
||||
dist2: 分布2,dict {token: probability}
|
||||
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
|
||||
|
||||
Returns:
|
||||
float: 分布距离
|
||||
"""
|
||||
if method == 'emd':
|
||||
return compute_emd_distance(dist1, dist2)
|
||||
|
||||
# 获取所有token的并集
|
||||
all_tokens = set(dist1.keys()) | set(dist2.keys())
|
||||
|
||||
if not all_tokens:
|
||||
return 0.0
|
||||
|
||||
# 构建概率向量
|
||||
p = np.array([dist1.get(token, 0.0) for token in all_tokens])
|
||||
q = np.array([dist2.get(token, 0.0) for token in all_tokens])
|
||||
|
||||
# 归一化(处理数值误差)
|
||||
p = p / (p.sum() + 1e-10)
|
||||
q = q / (q.sum() + 1e-10)
|
||||
|
||||
if method == 'js':
|
||||
# Jensen-Shannon散度(对称,范围[0, 1])
|
||||
return jensenshannon(p, q)
|
||||
elif method == 'kl':
|
||||
# KL散度(非对称,需要处理零值)
|
||||
# 添加小的平滑项避免log(0)
|
||||
p = p + 1e-10
|
||||
q = q + 1e-10
|
||||
p = p / p.sum()
|
||||
q = q / q.sum()
|
||||
return entropy(p, q)
|
||||
else:
|
||||
raise ValueError(f"未知的距离方法: {method}")
|
||||
|
||||
|
||||
def extract_subdistribution(global_dist, data_tokens):
|
||||
"""
|
||||
从全局分布中提取只包含数据中出现的token的子分布,并归一化
|
||||
|
||||
Args:
|
||||
global_dist: 全局分布,dict {token: probability}
|
||||
data_tokens: 数据中出现的token集合,set或list
|
||||
|
||||
Returns:
|
||||
dict: 子分布,dict {token: probability},已归一化
|
||||
"""
|
||||
if not data_tokens or not global_dist:
|
||||
return {}
|
||||
|
||||
# 提取子分布
|
||||
sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
|
||||
|
||||
# 归一化
|
||||
total = sum(sub_dist.values())
|
||||
if total < 1e-10:
|
||||
return {}
|
||||
|
||||
normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
|
||||
|
||||
return normalized_sub_dist
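# 示意性小例子(非原脚本内容,仅作说明):全局分布为 {0: 0.5, 1: 0.3, 2: 0.2},
# 某首歌只出现 token {0, 2},取子分布 {0: 0.5, 2: 0.2} 后归一化得到 {0: 5/7, 2: 2/7}。
def _demo_extract_subdistribution():
    global_dist = {0: 0.5, 1: 0.3, 2: 0.2}
    return extract_subdistribution(global_dist, {0, 2})  # {0: 0.714..., 2: 0.285...}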
|
||||
|
||||
|
||||
def compute_data_distance(data, global_distributions, method='emd'):
|
||||
"""
|
||||
计算单个数据与整体分布的距离
|
||||
对每首歌,从数据集分布中找出和这首歌的分布包含的数据相同的子分布,
|
||||
都进行归一化然后计算推土机距离
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径(如果是延迟加载)
|
||||
global_distributions: 整体分布,dict {column_name: {token: probability}}
|
||||
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
|
||||
|
||||
Returns:
|
||||
float: 平均距离(跨所有列)
|
||||
"""
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
try:
|
||||
data = np.load(data)['arr_0']
|
||||
except Exception as e:
|
||||
# 不打印错误,让调用者处理
|
||||
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
|
||||
|
||||
distances = []
|
||||
|
||||
for col_idx, col_name in enumerate(COLUMN_NAMES):
|
||||
# 计算该数据在该列的分布
|
||||
data_dist = compute_data_distribution(data, col_idx)
|
||||
|
||||
# 获取整体分布
|
||||
global_dist = global_distributions.get(col_name, {})
|
||||
|
||||
if not data_dist or not global_dist:
|
||||
continue
|
||||
|
||||
# 从全局分布中提取只包含数据中出现的token的子分布
|
||||
data_tokens = set(data_dist.keys())
|
||||
sub_global_dist = extract_subdistribution(global_dist, data_tokens)
|
||||
|
||||
if not sub_global_dist:
|
||||
continue
|
||||
|
||||
# 归一化数据分布
|
||||
data_dist_sum = sum(data_dist.values())
|
||||
if data_dist_sum < 1e-10:
|
||||
continue
|
||||
normalized_data_dist = {token: prob / data_dist_sum
|
||||
for token, prob in data_dist.items()}
|
||||
|
||||
# 计算距离(两个分布都已归一化)
|
||||
dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
|
||||
distances.append(dist)
|
||||
|
||||
# 返回平均距离
|
||||
return np.mean(distances) if distances else 0.0
|
||||
|
||||
|
||||
def _load_single_file(npz_file):
|
||||
"""
|
||||
加载单个npz文件的辅助函数(用于多线程)
|
||||
|
||||
Args:
|
||||
npz_file: npz文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (data, file_path) 或 None(如果加载失败)
|
||||
"""
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
if data.ndim == 2:
|
||||
return (data, npz_file)
|
||||
elif data.ndim == 1:
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_data_file_paths(data_dir):
|
||||
"""
|
||||
获取所有数据文件路径(不加载数据)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
|
||||
Returns:
|
||||
list: 文件路径列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件")
|
||||
return npz_files
|
||||
|
||||
|
||||
def load_data_with_paths(data_dir, num_threads=None, lazy=False):
|
||||
"""
|
||||
加载所有数据并返回数据路径列表(多线程版本)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
lazy: 如果为True,只返回文件路径,不加载数据
|
||||
|
||||
Returns:
|
||||
tuple: (data_list, file_paths_list) 或 (None, file_paths_list) 如果lazy=True
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return [], []
|
||||
|
||||
if lazy:
|
||||
print(f"找到 {len(npz_files)} 个.npz文件(延迟加载模式)")
|
||||
return None, npz_files
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件,开始加载...")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(npz_files))
|
||||
|
||||
all_data = []
|
||||
file_paths = []
|
||||
|
||||
# 使用多线程加载文件
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(_load_single_file, npz_file): npz_file
|
||||
for npz_file in npz_files}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="加载数据"):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
data, file_path = result
|
||||
all_data.append(data)
|
||||
file_paths.append(file_path)
|
||||
|
||||
# 保持原始顺序
|
||||
if file_paths:
|
||||
sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
|
||||
file_paths, all_data = zip(*sorted_pairs)
|
||||
file_paths = list(file_paths)
|
||||
all_data = list(all_data)
|
||||
|
||||
return all_data, file_paths
|
||||
|
||||
|
||||
def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
|
||||
"""
|
||||
根据距离进行加权重采样
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
distances: 距离列表
|
||||
sample_ratio: 采样比例
|
||||
method: 距离计算方法(用于确定权重方向)
|
||||
lazy: 如果为True,返回文件路径而不是数据
|
||||
|
||||
Returns:
|
||||
tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
|
||||
"""
|
||||
n_samples = int(len(file_paths) * sample_ratio)
|
||||
print(f"\n准备采样 {n_samples} 个数据 (占总数的 {sample_ratio*100:.1f}%)")
|
||||
|
||||
# 将距离转换为权重
|
||||
# 距离越远,权重越大
|
||||
distances = np.array(distances)
|
||||
|
||||
# 处理零距离或异常值
|
||||
min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
|
||||
distances = np.maximum(distances, min_dist * 0.1)
|
||||
|
||||
# 归一化距离到[0, 1],然后转换为权重
|
||||
# 使用指数函数增强距离差异
|
||||
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
|
||||
weights = np.exp(normalized_distances * 3) # 指数放大,使距离远的更容易被采样
|
||||
|
||||
# 归一化权重
|
||||
weights = weights / weights.sum()
|
||||
|
||||
# 加权随机采样
|
||||
indices = np.arange(len(file_paths))
|
||||
sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
|
||||
|
||||
sampled_paths = [file_paths[i] for i in sampled_indices]
|
||||
|
||||
# 如果lazy=True,返回路径;否则加载数据
|
||||
if lazy:
|
||||
sampled_data = sampled_paths # 返回路径,延迟加载
|
||||
else:
|
||||
# 加载采样后的数据
|
||||
sampled_data = []
|
||||
for path in tqdm(sampled_paths, desc="加载采样数据"):
|
||||
try:
|
||||
data = np.load(path)['arr_0']
|
||||
sampled_data.append(data)
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {path} 时出错: {e}")
|
||||
sampled_data.append(None)
|
||||
|
||||
print(f"采样完成:")
|
||||
print(f" 采样数据数量: {len(sampled_paths)}")
|
||||
print(f" 平均距离: {distances[sampled_indices].mean():.6f}")
|
||||
print(f" 最小距离: {distances[sampled_indices].min():.6f}")
|
||||
print(f" 最大距离: {distances[sampled_indices].max():.6f}")
|
||||
|
||||
return sampled_data, sampled_paths, sampled_indices
|
||||
|
||||
|
||||
def _save_single_file(args_tuple):
|
||||
"""
|
||||
保存单个文件的辅助函数(用于多线程)
|
||||
支持延迟加载:如果data是路径,则从文件加载
|
||||
|
||||
Args:
|
||||
args_tuple: (data, original_path, output_dir)
|
||||
|
||||
Returns:
|
||||
tuple: (success, original_path) 或 (False, original_path, error_msg)
|
||||
"""
|
||||
data, original_path, output_dir = args_tuple
|
||||
try:
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
data = np.load(data)['arr_0']
|
||||
|
||||
# 保持相对路径结构
|
||||
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
|
||||
output_path = output_dir / relative_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
np.savez_compressed(output_path, data)
|
||||
return (True, original_path)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
|
||||
return (False, original_path, error_msg)
|
||||
|
||||
|
||||
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
|
||||
"""
|
||||
保存采样后的数据(多线程版本)
|
||||
|
||||
Args:
|
||||
sampled_data: 采样后的数据列表
|
||||
sampled_paths: 采样后的文件路径列表
|
||||
output_dir: 输出目录
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n保存采样数据到: {output_dir}")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(sampled_data))
|
||||
|
||||
# 准备参数
|
||||
save_args = [(data, original_path, output_dir)
|
||||
for data, original_path in zip(sampled_data, sampled_paths)]
|
||||
|
||||
# 使用多线程保存文件
|
||||
success_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_save_single_file, args)
|
||||
for args in save_args]
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
|
||||
try:
|
||||
result = future.result(timeout=300) # 设置超时避免卡死
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
success = result[0]
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"错误: 获取保存结果时出错: {e}")
|
||||
|
||||
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="基于token分布距离的重采样")
|
||||
parser.add_argument("--json_path", type=str,
|
||||
default="octuple_token_analysis_report.json",
|
||||
help="token分析报告JSON文件路径")
|
||||
parser.add_argument("--data_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
|
||||
help="数据目录路径")
|
||||
parser.add_argument("--output_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
|
||||
help="输出目录路径")
|
||||
parser.add_argument("--sample_ratio", type=float, default=0.3,
|
||||
help="采样比例 (默认: 0.3)")
|
||||
parser.add_argument("--distance_method", type=str, default="emd",
|
||||
choices=["emd", "js", "kl"],
|
||||
help="距离计算方法: 'emd' (推土机距离/EMD), 'js' (Jensen-Shannon) 或 'kl' (KL散度)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子")
|
||||
parser.add_argument("--num_threads", type=int, default=1,
|
||||
help="线程数,None表示使用CPU核心数 (默认: None)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 设置随机种子
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# 1. 加载整体分布
|
||||
global_distributions = load_distribution_from_json(args.json_path)
|
||||
|
||||
# 2. 获取所有数据文件路径(延迟加载模式,避免一次性加载所有数据)
|
||||
_, file_paths = load_data_with_paths(args.data_dir, lazy=True)
|
||||
|
||||
if not file_paths:
|
||||
print("错误: 未找到任何数据文件")
|
||||
return
|
||||
|
||||
print(f"\n共找到 {len(file_paths)} 个数据文件")
|
||||
|
||||
# 3. 计算每个数据与整体分布的距离(多线程版本,延迟加载)
|
||||
print("\n计算每个数据与整体分布的距离(延迟加载模式)...")
|
||||
|
||||
def _compute_distance_wrapper(args_tuple):
|
||||
"""计算距离的包装函数(用于多线程,支持延迟加载)"""
|
||||
idx, file_path, global_dists, method = args_tuple
|
||||
try:
|
||||
distance = compute_data_distance(file_path, global_dists, method=method)
|
||||
return (idx, distance, None)
|
||||
except Exception as e:
|
||||
return (idx, 0.0, str(e))
|
||||
|
||||
if args.num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(file_paths))
|
||||
else:
|
||||
num_threads = args.num_threads
|
||||
|
||||
# 准备参数(使用文件路径而不是数据,包含索引)
|
||||
distance_args = [(i, file_path, global_distributions, args.distance_method)
|
||||
for i, file_path in enumerate(file_paths)]
|
||||
|
||||
# 使用多线程计算距离(按需加载数据)
|
||||
# 初始化结果列表,保持顺序
|
||||
distances = [0.0] * len(file_paths)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_compute_distance_wrapper, args)
|
||||
for args in distance_args]
|
||||
|
||||
# 收集结果,使用 tqdm 显示进度
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="计算距离"):
|
||||
try:
|
||||
idx, distance, error = future.result(timeout=300) # 设置超时避免卡死
|
||||
distances[idx] = distance
|
||||
if error:
|
||||
print(f"警告: 计算距离时出错 (索引 {idx}): {error}")
|
||||
except Exception as e:
|
||||
print(f"错误: 获取结果时出错: {e}")
|
||||
# 如果无法获取结果,保持默认值 0.0
|
||||
|
||||
distances = np.array(distances)
|
||||
print(f"\n距离统计:")
|
||||
print(f" 平均距离: {distances.mean():.6f}")
|
||||
print(f" 最小距离: {distances.min():.6f}")
|
||||
print(f" 最大距离: {distances.max():.6f}")
|
||||
print(f" 标准差: {distances.std():.6f}")
|
||||
|
||||
# 4. 根据距离进行加权采样(延迟加载模式)
|
||||
sampled_data, sampled_paths, sampled_indices = weighted_resample(
|
||||
file_paths, distances,
|
||||
sample_ratio=args.sample_ratio,
|
||||
method=args.distance_method,
|
||||
lazy=True # 使用延迟加载,避免重复加载数据
|
||||
)
|
||||
|
||||
# 5. 保存采样结果(多线程,延迟加载)
|
||||
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
|
||||
|
||||
# 6. 保存采样索引(可选,用于后续分析)
|
||||
indices_file = Path(args.output_dir) / "sampled_indices.npy"
|
||||
np.save(indices_file, sampled_indices)
|
||||
print(f"\n采样索引已保存到: {indices_file}")
|
||||
|
||||
# 保存采样信息
|
||||
info = {
|
||||
"total_samples": len(file_paths),
|
||||
"sampled_samples": len(sampled_data),
|
||||
"sample_ratio": args.sample_ratio,
|
||||
"distance_method": args.distance_method,
|
||||
"distance_stats": {
|
||||
"mean": float(distances.mean()),
|
||||
"min": float(distances.min()),
|
||||
"max": float(distances.max()),
|
||||
"std": float(distances.std())
|
||||
},
|
||||
"sampled_distance_stats": {
|
||||
"mean": float(distances[sampled_indices].mean()),
|
||||
"min": float(distances[sampled_indices].min()),
|
||||
"max": float(distances[sampled_indices].max()),
|
||||
"std": float(distances[sampled_indices].std())
|
||||
}
|
||||
}
|
||||
|
||||
info_file = Path(args.output_dir) / "resample_info.json"
|
||||
with open(info_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(info, f, indent=2, ensure_ascii=False)
|
||||
print(f"采样信息已保存到: {info_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
472
data_representation/resampleV2.py
Normal file
@ -0,0 +1,472 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于position和duration token的重采样脚本V2
|
||||
对于每首歌:
|
||||
1. 如果包含的position和duration不在总数据集前3个,则必定采样
|
||||
2. 对于包含的,以某个固定的百分比采样
|
||||
3. 两个条件满足一个即可
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
|
||||
def load_top_tokens_from_json(json_path, column_name, top_k=3):
|
||||
"""
|
||||
从JSON文件中加载指定列的前top_k个最常见的token
|
||||
|
||||
Args:
|
||||
json_path: JSON文件路径
|
||||
column_name: 列名(如'position'或'duration')
|
||||
top_k: 返回前k个最常见的token
|
||||
|
||||
Returns:
|
||||
set: 前top_k个最常见的token集合
|
||||
"""
|
||||
print(f"读取分布文件: {json_path}")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
report = json.load(f)
|
||||
|
||||
columns = report.get('columns', {})
|
||||
|
||||
if column_name not in columns:
|
||||
print(f"警告: 列 {column_name} 不在报告中")
|
||||
return set()
|
||||
|
||||
col_data = columns[column_name]
|
||||
token_counts = col_data.get('token_counts', {})
|
||||
|
||||
# 按出现次数排序,获取前top_k个
|
||||
sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True)
|
||||
top_tokens = {int(token_str) for token_str, _ in sorted_tokens[:top_k]}
|
||||
|
||||
print(f" 列 {column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}")
|
||||
|
||||
return top_tokens
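# 示意性小例子(非原脚本内容,仅作说明):token_counts 形如 {"0": 1000, "24": 500, "48": 300, "96": 10} 时,
# 按出现次数排序取前 3 个,得到 {0, 24, 48}(与上面函数内部的排序逻辑一致)。
def _demo_top_k_tokens():
    token_counts = {"0": 1000, "24": 500, "48": 300, "96": 10}
    sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True)
    return {int(token_str) for token_str, _ in sorted_tokens[:3]}  # {0, 24, 48}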
|
||||
|
||||
|
||||
def get_data_tokens(data, col_idx):
|
||||
"""
|
||||
获取单个数据在指定列的所有唯一token
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径
|
||||
col_idx: 列索引
|
||||
|
||||
Returns:
|
||||
set: 唯一token集合
|
||||
"""
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
try:
|
||||
data = np.load(data)['arr_0']
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
|
||||
|
||||
if data.size == 0:
|
||||
return set()
|
||||
|
||||
tokens = data[:, col_idx]
|
||||
unique_tokens = set(int(token) for token in np.unique(tokens))
|
||||
|
||||
return unique_tokens
|
||||
|
||||
|
||||
def should_sample_song(data, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
|
||||
"""
|
||||
判断一首歌是否应该被采样
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径
|
||||
top_position_tokens: position列的前3个最常见token集合
|
||||
top_duration_tokens: duration列的前3个最常见token集合
|
||||
contain_sample_ratio: 对于包含前3个token的歌曲,采样比例
|
||||
not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率)
|
||||
rng: 随机数生成器,如果为None则使用全局的np.random
|
||||
|
||||
Returns:
|
||||
tuple: (是否应该采样, 是否在前3个) - 在前3个指position和duration都在前3个
|
||||
"""
|
||||
# 获取position和duration列的唯一token
|
||||
position_idx = COLUMN_NAMES.index("position")
|
||||
duration_idx = COLUMN_NAMES.index("duration")
|
||||
|
||||
position_tokens = get_data_tokens(data, position_idx)
|
||||
duration_tokens = get_data_tokens(data, duration_idx)
|
||||
|
||||
# 判断是否在前3个
|
||||
position_in_top3 = bool(position_tokens & top_position_tokens)
|
||||
duration_in_top3 = bool(duration_tokens & top_duration_tokens)
|
||||
in_top3 = position_in_top3 and duration_in_top3
|
||||
|
||||
if rng is None:
|
||||
rng = np.random
|
||||
|
||||
# 条件1: 如果position或duration不包含前3个token,以更高概率采样
|
||||
if not position_in_top3 or not duration_in_top3:
|
||||
should_sample = rng.random() < not_contain_sample_ratio
|
||||
return should_sample, False
|
||||
|
||||
# 条件2: 如果包含前3个token,则以固定百分比采样
|
||||
should_sample = rng.random() < contain_sample_ratio
|
||||
return should_sample, True
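# 示意性小例子(非原脚本内容,仅作说明):构造一首只用 position=0、duration=96 的小样本;
# 若这两个 token 都落在各自的 top-3 集合中,则按 contain_sample_ratio 的概率决定是否采样。
# 其中的 top-3 集合与超参数值均为假设的演示数值。
def _demo_should_sample_song():
    toy = np.zeros((4, len(COLUMN_NAMES)), dtype=np.int64)
    toy[:, COLUMN_NAMES.index("position")] = 0
    toy[:, COLUMN_NAMES.index("duration")] = 96
    rng = np.random.RandomState(0)
    return should_sample_song(toy, {0, 24, 48}, {96, 48, 192},
                              contain_sample_ratio=0.1,
                              not_contain_sample_ratio=0.9, rng=rng)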
|
||||
|
||||
|
||||
def _load_single_file(npz_file):
|
||||
"""
|
||||
加载单个npz文件的辅助函数(用于多线程)
|
||||
|
||||
Args:
|
||||
npz_file: npz文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (data, file_path) 或 None(如果加载失败)
|
||||
"""
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
if data.ndim == 2:
|
||||
return (data, npz_file)
|
||||
elif data.ndim == 1:
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_data_file_paths(data_dir):
|
||||
"""
|
||||
获取所有数据文件路径(不加载数据)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
|
||||
Returns:
|
||||
list: 文件路径列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件")
|
||||
return npz_files
|
||||
|
||||
|
||||
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
|
||||
num_threads=None, seed=42):
|
||||
"""
|
||||
根据新逻辑进行重采样
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
top_position_tokens: position列的前3个最常见token集合
|
||||
top_duration_tokens: duration列的前3个最常见token集合
|
||||
contain_sample_ratio: 对于包含前3个token的歌曲,采样比例
|
||||
not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率)
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
seed: 随机种子
|
||||
|
||||
Returns:
|
||||
tuple: (sampled_paths, sampled_indices, stats)
|
||||
"""
|
||||
import threading
|
||||
|
||||
# 为每个线程创建独立的随机数生成器
|
||||
thread_local = threading.local()
|
||||
|
||||
def get_thread_rng():
|
||||
"""获取当前线程的随机数生成器"""
|
||||
if not hasattr(thread_local, 'rng'):
|
||||
# 使用线程ID和种子创建独立的随机数生成器
|
||||
thread_id = threading.current_thread().ident
|
||||
thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
|
||||
return thread_local.rng
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(file_paths))
|
||||
|
||||
print(f"\n开始重采样,使用 {num_threads} 个线程...")
|
||||
print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%")
|
||||
print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%")
|
||||
|
||||
def _should_sample_wrapper(args_tuple):
|
||||
"""判断是否采样的包装函数(用于多线程)"""
|
||||
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
|
||||
try:
|
||||
# 使用线程本地的随机数生成器
|
||||
thread_rng = get_thread_rng()
|
||||
should_sample, in_top3 = should_sample_song(
|
||||
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
|
||||
)
|
||||
return (file_path, should_sample, in_top3, None)
|
||||
except Exception as e:
|
||||
return (file_path, False, False, str(e))
|
||||
|
||||
# 准备参数
|
||||
sample_args = [(file_path, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio, not_contain_sample_ratio)
|
||||
for file_path in file_paths]
|
||||
|
||||
# 使用多线程判断每首歌是否应该采样
|
||||
sampled_paths = []
|
||||
sampled_indices = []
|
||||
stats = {
|
||||
'not_in_top3_count': 0, # 不在前3个的歌曲数量
|
||||
'not_in_top3_sampled': 0, # 不在前3个且被采样的歌曲数量
|
||||
'in_top3_count': 0, # 在前3个的歌曲数量
|
||||
'in_top3_sampled': 0 # 在前3个且被采样的歌曲数量
|
||||
}
|
||||
|
||||
# 限制并发任务数量,避免一次性提交过多任务
|
||||
batch_size = min(1000, len(file_paths))
|
||||
results = {}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 分批提交任务
|
||||
for batch_start in range(0, len(sample_args), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(sample_args))
|
||||
batch_args = sample_args[batch_start:batch_end]
|
||||
|
||||
futures = {executor.submit(_should_sample_wrapper, args): args[0]
|
||||
for args in batch_args}
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures),
|
||||
desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]",
|
||||
leave=False):
|
||||
try:
|
||||
file_path, should_sample, in_top3, error = future.result(timeout=60)
|
||||
results[file_path] = (should_sample, in_top3, error)
|
||||
if error:
|
||||
print(f"警告: 处理 {file_path} 时出错: {error}")
|
||||
except Exception as e:
|
||||
print(f"错误: 获取结果时出错: {e}")
|
||||
|
||||
# 按原始顺序处理结果,并统计
|
||||
for idx, file_path in enumerate(file_paths):
|
||||
if file_path not in results:
|
||||
continue
|
||||
|
||||
should_sample, in_top3, error = results[file_path]
|
||||
if error:
|
||||
continue
|
||||
|
||||
# 统计信息
|
||||
if in_top3:
|
||||
stats['in_top3_count'] += 1
|
||||
if should_sample:
|
||||
stats['in_top3_sampled'] += 1
|
||||
else:
|
||||
stats['not_in_top3_count'] += 1
|
||||
if should_sample:
|
||||
stats['not_in_top3_sampled'] += 1
|
||||
|
||||
if should_sample:
|
||||
sampled_paths.append(file_path)
|
||||
sampled_indices.append(idx)
|
||||
|
||||
print(f"\n采样完成:")
|
||||
print(f" 总歌曲数: {len(file_paths)}")
|
||||
print(f" 采样歌曲数: {len(sampled_paths)}")
|
||||
print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%")
|
||||
print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}")
|
||||
print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}")
|
||||
if stats['not_in_top3_count'] > 0:
|
||||
print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
|
||||
print(f" 在前3个的歌曲数: {stats['in_top3_count']}")
|
||||
print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}")
|
||||
if stats['in_top3_count'] > 0:
|
||||
print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")
|
||||
|
||||
return sampled_paths, sampled_indices, stats
|
||||
|
||||
|
||||
def _save_single_file(args_tuple):
|
||||
"""
|
||||
保存单个文件的辅助函数(用于多线程)
|
||||
支持延迟加载:如果data是路径,则从文件加载
|
||||
|
||||
Args:
|
||||
args_tuple: (data, original_path, output_dir)
|
||||
|
||||
Returns:
|
||||
tuple: (success, original_path) 或 (False, original_path, error_msg)
|
||||
"""
|
||||
data, original_path, output_dir = args_tuple
|
||||
try:
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
data = np.load(data)['arr_0']
|
||||
|
||||
# 保持相对路径结构
|
||||
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
|
||||
output_path = output_dir / relative_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
np.savez_compressed(output_path, data)
|
||||
return (True, original_path)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
|
||||
return (False, original_path, error_msg)
|
||||
|
||||
|
||||
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
|
||||
"""
|
||||
保存采样后的数据(多线程版本)
|
||||
|
||||
Args:
|
||||
sampled_data: 采样后的数据列表
|
||||
sampled_paths: 采样后的文件路径列表
|
||||
output_dir: 输出目录
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n保存采样数据到: {output_dir}")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(sampled_data))
|
||||
|
||||
# 准备参数
|
||||
save_args = [(data, original_path, output_dir)
|
||||
for data, original_path in zip(sampled_data, sampled_paths)]
|
||||
|
||||
# 使用多线程保存文件
|
||||
success_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_save_single_file, args)
|
||||
for args in save_args]
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
|
||||
try:
|
||||
result = future.result(timeout=300) # 设置超时避免卡死
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
success = result[0]
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"错误: 获取保存结果时出错: {e}")
|
||||
|
||||
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2")
|
||||
parser.add_argument("--json_path", type=str,
|
||||
default="octuple_token_analysis_report.json",
|
||||
help="token分析报告JSON文件路径")
|
||||
parser.add_argument("--data_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
|
||||
help="数据目录路径")
|
||||
parser.add_argument("--output_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
|
||||
help="输出目录路径")
|
||||
parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
|
||||
help="对于包含前3个token的歌曲,采样比例 (默认: 0.1)")
|
||||
parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
|
||||
help="对于不包含前3个token的歌曲,采样比例 (默认: 0.9)")
|
||||
parser.add_argument("--top_k", type=int, default=3,
|
||||
help="使用前k个最常见的token (默认: 3)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子")
|
||||
parser.add_argument("--num_threads", type=int, default=None,
|
||||
help="线程数,None表示使用CPU核心数 (默认: None)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. 加载position和duration的前top_k个最常见token
|
||||
top_position_tokens = load_top_tokens_from_json(
|
||||
args.json_path, "position", top_k=args.top_k
|
||||
)
|
||||
top_duration_tokens = load_top_tokens_from_json(
|
||||
args.json_path, "duration", top_k=args.top_k
|
||||
)
|
||||
|
||||
if not top_position_tokens or not top_duration_tokens:
|
||||
print("错误: 无法加载前top_k个token")
|
||||
return
|
||||
|
||||
# 2. 获取所有数据文件路径
|
||||
file_paths = get_data_file_paths(args.data_dir)
|
||||
|
||||
if not file_paths:
|
||||
print("错误: 未找到任何数据文件")
|
||||
return
|
||||
|
||||
print(f"\n共找到 {len(file_paths)} 个数据文件")
|
||||
|
||||
# 3. 根据新逻辑进行重采样
|
||||
sampled_paths, sampled_indices, stats = resample_songs(
|
||||
file_paths,
|
||||
top_position_tokens,
|
||||
top_duration_tokens,
|
||||
contain_sample_ratio=args.contain_sample_ratio,
|
||||
not_contain_sample_ratio=args.not_contain_sample_ratio,
|
||||
num_threads=args.num_threads,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
# 4. 保存采样结果(延迟加载)
|
||||
sampled_data = sampled_paths # 使用路径,延迟加载
|
||||
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
|
||||
|
||||
# 5. 保存采样索引(可选,用于后续分析)
|
||||
indices_file = Path(args.output_dir) / "sampled_indices.npy"
|
||||
np.save(indices_file, np.array(sampled_indices))
|
||||
print(f"\n采样索引已保存到: {indices_file}")
|
||||
|
||||
# 6. 保存采样信息
|
||||
info = {
|
||||
"total_samples": len(file_paths),
|
||||
"sampled_samples": len(sampled_paths),
|
||||
"contain_sample_ratio": args.contain_sample_ratio,
|
||||
"not_contain_sample_ratio": args.not_contain_sample_ratio,
|
||||
"top_k": args.top_k,
|
||||
"top_position_tokens": sorted(list(top_position_tokens)),
|
||||
"top_duration_tokens": sorted(list(top_duration_tokens)),
|
||||
"stats": stats
|
||||
}
|
||||
|
||||
info_file = Path(args.output_dir) / "resample_info.json"
|
||||
with open(info_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(info, f, indent=2, ensure_ascii=False)
|
||||
print(f"采样信息已保存到: {info_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,14 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统计octuple分词结果中每一列每个token的出现次数,并生成分析报告
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import defaultdict, Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
|
||||
# 读取 npz 文件
|
||||
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
# 查看保存的键
|
||||
print(data.files)
|
||||
# 输出:['filename', 'sequence']
|
||||
|
||||
# 访问数据
|
||||
sequence = data["arr_0"]
|
||||
def load_octuple_data(data_dir):
|
||||
"""
|
||||
加载所有octuple分词后的.npz文件
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径,可以是单个目录或包含多个子目录的根目录
|
||||
|
||||
Returns:
|
||||
list: 所有加载的numpy数组列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
# 如果目录存在,查找所有.npz文件
|
||||
if data_dir.exists():
|
||||
npz_files = list(data_dir.rglob("*.npz"))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件,开始加载...")
|
||||
|
||||
all_data = []
|
||||
for npz_file in tqdm(npz_files, desc="加载数据"):
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
# 确保数据是二维数组 (num_tokens, num_columns)
|
||||
if data.ndim == 2:
|
||||
all_data.append(data)
|
||||
elif data.ndim == 1:
|
||||
# 如果是一维,可能需要reshape,但octuple应该是二维的
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
continue
|
||||
|
||||
return all_data
|
||||
|
||||
|
||||
def count_tokens_by_column(all_data):
|
||||
"""
|
||||
统计每一列每个token的出现次数
|
||||
|
||||
Args:
|
||||
all_data: 所有数据的列表,每个元素是一个numpy数组 (num_tokens, num_columns)
|
||||
|
||||
Returns:
|
||||
dict: {column_index: Counter({token_value: count})}
|
||||
"""
|
||||
column_counters = defaultdict(Counter)
|
||||
|
||||
print("统计token出现次数...")
|
||||
for data in tqdm(all_data, desc="处理数据"):
|
||||
if data.size == 0:
|
||||
continue
|
||||
|
||||
num_columns = data.shape[1] if data.ndim == 2 else 1
|
||||
|
||||
for col_idx in range(num_columns):
|
||||
if data.ndim == 2:
|
||||
tokens = data[:, col_idx]
|
||||
else:
|
||||
tokens = data
|
||||
|
||||
# 统计该列中每个token的出现次数
|
||||
unique, counts = np.unique(tokens, return_counts=True)
|
||||
for token, count in zip(unique, counts):
|
||||
column_counters[col_idx][int(token)] += int(count)
|
||||
|
||||
return dict(column_counters)
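# 示意性小例子(非原脚本内容,仅作说明):两段 2 列的小数据,统计每列各 token 的出现次数。
def _demo_count_tokens_by_column():
    toy = [np.array([[60, 0], [60, 48]]), np.array([[62, 0]])]
    return count_tokens_by_column(toy)
    # 期望形如 {0: Counter({60: 2, 62: 1}), 1: Counter({0: 2, 48: 1})}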
|
||||
|
||||
|
||||
def generate_report(column_counters, output_file=None):
|
||||
"""
|
||||
生成分析报告
|
||||
|
||||
Args:
|
||||
column_counters: 统计结果字典
|
||||
output_file: 输出文件路径(可选)
|
||||
"""
|
||||
report_lines = []
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("OCTUPLE分词结果统计分析报告")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("")
|
||||
|
||||
# 总体统计
|
||||
total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
|
||||
report_lines.append(f"总token数: {total_tokens:,}")
|
||||
report_lines.append(f"分析的列数: {len(column_counters)}")
|
||||
report_lines.append("")
|
||||
|
||||
# 每一列的详细统计
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
|
||||
report_lines.append("-" * 80)
|
||||
report_lines.append(f"列 {col_idx}: {col_name}")
|
||||
report_lines.append("-" * 80)
|
||||
|
||||
total_in_column = sum(counter.values())
|
||||
unique_tokens = len(counter)
|
||||
min_token = min(counter.keys()) if counter else 0
|
||||
max_token = max(counter.keys()) if counter else 0
|
||||
|
||||
report_lines.append(f" 总token数: {total_in_column:,}")
|
||||
report_lines.append(f" 唯一token数: {unique_tokens:,}")
|
||||
report_lines.append(f" Token值范围: [{min_token}, {max_token}]")
|
||||
report_lines.append(f" 平均出现次数: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else " 平均出现次数: N/A")
|
||||
report_lines.append("")
|
||||
|
||||
# Top 20 最常见的token
|
||||
report_lines.append(f" Top 20 最常见的token:")
|
||||
top_tokens = counter.most_common(20)
|
||||
for rank, (token, count) in enumerate(top_tokens, 1):
|
||||
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
|
||||
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
|
||||
report_lines.append("")
|
||||
|
||||
# Top 20 最不常见的token(出现次数>0的)
|
||||
report_lines.append(f" Top 20 最不常见的token (出现次数>0):")
|
||||
bottom_tokens = counter.most_common()[-20:]
|
||||
bottom_tokens.reverse()
|
||||
for rank, (token, count) in enumerate(bottom_tokens, 1):
|
||||
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
|
||||
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
|
||||
report_lines.append("")
|
||||
|
||||
# 分布统计
|
||||
counts_list = list(counter.values())
|
||||
if counts_list:
|
||||
report_lines.append(f" 分布统计:")
|
||||
report_lines.append(f" 最小出现次数: {min(counts_list):,}")
|
||||
report_lines.append(f" 最大出现次数: {max(counts_list):,}")
|
||||
report_lines.append(f" 中位数出现次数: {np.median(counts_list):,.0f}")
|
||||
report_lines.append(f" 标准差: {np.std(counts_list):,.2f}")
|
||||
report_lines.append("")
|
||||
|
||||
# 跨列分析
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("跨列分析")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("")
|
||||
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
total_in_column = sum(counter.values())
|
||||
percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
|
||||
report_lines.append(f" {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
|
||||
|
||||
report_lines.append("")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("报告生成完成")
|
||||
report_lines.append("=" * 80)
|
||||
|
||||
# 输出报告
|
||||
report_text = "\n".join(report_lines)
|
||||
print("\n" + report_text)
|
||||
|
||||
# 保存到文件
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report_text)
|
||||
print(f"\n报告已保存到: {output_path}")
|
||||
|
||||
# 同时保存JSON格式的详细数据
|
||||
if output_file:
|
||||
json_output = output_path.with_suffix('.json')
|
||||
json_data = {
|
||||
'summary': {
|
||||
'total_tokens': total_tokens,
|
||||
'num_columns': len(column_counters)
|
||||
},
|
||||
'columns': {}
|
||||
}
|
||||
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
json_data['columns'][col_name] = {
|
||||
'total_tokens': sum(counter.values()),
|
||||
'unique_tokens': len(counter),
|
||||
'token_counts': dict(counter),
|
||||
'top_20': dict(counter.most_common(20)),
|
||||
'bottom_20': dict(counter.most_common()[-20:])
|
||||
}
|
||||
|
||||
with open(json_output, 'w', encoding='utf-8') as f:
|
||||
json.dump(json_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"详细数据已保存到: {json_output}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 默认数据目录 - 可以根据需要修改
|
||||
default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"
|
||||
|
||||
# 可以指定具体的数据目录,例如:
|
||||
data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
|
||||
# 或者使用默认目录扫描所有oct8目录
|
||||
|
||||
# import sys
|
||||
# if len(sys.argv) > 1:
|
||||
# data_dir = sys.argv[1]
|
||||
# else:
|
||||
# # 自动查找所有oct8目录
|
||||
# base_dir = Path(default_data_dir)
|
||||
# oct8_dirs = list(base_dir.rglob("oct8"))
|
||||
# if oct8_dirs:
|
||||
# print(f"找到以下oct8目录:")
|
||||
# for i, d in enumerate(oct8_dirs, 1):
|
||||
# print(f" {i}. {d}")
|
||||
# if len(oct8_dirs) == 1:
|
||||
# data_dir = str(oct8_dirs[0])
|
||||
# print(f"\n使用目录: {data_dir}")
|
||||
# else:
|
||||
# # 使用第一个找到的目录,或者合并所有目录
|
||||
# print(f"\n使用第一个目录: {oct8_dirs[0]}")
|
||||
# print("如需分析其他目录,请指定路径作为参数")
|
||||
# data_dir = str(oct8_dirs[0])
|
||||
# else:
|
||||
# data_dir = default_data_dir
|
||||
# print(f"未找到oct8目录,使用默认目录: {data_dir}")
|
||||
|
||||
# 加载数据
|
||||
all_data = load_octuple_data(data_dir)
|
||||
|
||||
if not all_data:
|
||||
print("错误: 未加载到任何数据")
|
||||
return
|
||||
|
||||
# 检查数据格式
|
||||
if all_data:
|
||||
sample = all_data[0]
|
||||
print(f"\n数据格式检查:")
|
||||
print(f" 样本形状: {sample.shape}")
|
||||
print(f" 样本数据类型: {sample.dtype}")
|
||||
print(f" 列数: {sample.shape[1] if sample.ndim == 2 else 1}")
|
||||
print()
|
||||
|
||||
# 统计token出现次数
|
||||
column_counters = count_tokens_by_column(all_data)
|
||||
|
||||
# 生成报告
|
||||
output_file = "octuple_token_analysis_report_part.txt"
|
||||
generate_report(column_counters, output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
print("token 序列长度:", len(sequence))
|
||||
print("前 20 个 token:", sequence[:20])
|
||||
139
dllm/.gitignore
vendored
Normal file
@ -0,0 +1,139 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.idea
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
applications/DeepSpeed-Chat/data
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Others
|
||||
/.vscode/
|
||||
/tmp/
|
||||
/data/
|
||||
/wandb/
|
||||
/logs/
|
||||
/models*/
|
||||
4
dllm/.gitmodules
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
[submodule "lm-evaluation-harness"]
|
||||
path = lm-evaluation-harness
|
||||
url = https://github.com/ZHZisZZ/lm-evaluation-harness
|
||||
branch = dllm
|
||||
21
dllm/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Zhanhui Zhou
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
283
dllm/README.md
Normal file
@ -0,0 +1,283 @@
|
||||
<h1 align="center">dLLM</h1>
|
||||
|
||||
<p align="center">
|
||||
Simple Diffusion Language Modeling
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="assets/logo.gif"
|
||||
alt="dLLM logo">
|
||||
</p>
|
||||
|
||||
|
||||
## Overview
|
||||
**dLLM** is a library that unifies the training and evaluation of **diffusion language models**, bringing transparency and reproducibility to the entire development pipeline:
|
||||
|
||||
<!-- and [RND1](https://www.radicalnumerics.ai/assets/rnd1_report.pdf) -->
|
||||
|
||||
- dLLM provides scalable training pipelines (inspired by the [`transformers`](https://github.com/huggingface/transformers/blob/main/src/transformers) [Trainer](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)), with support for [LoRA](https://github.com/huggingface/peft), [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), [FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/), and beyond.
|
||||
|
||||
- dLLM provides unified evaluation pipelines (inspired by [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness)) that abstract away inference details and make customization simple.
|
||||
|
||||
- Built on these components, dLLM provides minimal **pretraining / finetuning / evaluation** recipes for open-weight models (e.g., [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)), and implementations of training algorithms (e.g., [Edit Flows](https://arxiv.org/abs/2506.09018)).
|
||||
|
||||
<!-- > [!NOTE]
|
||||
> This repository is primarily for educational purposes and does not aim for 100% exact reproduction of official models (which is impossible). We hope it serves as a helpful reference for the community — contributions and improvements are always welcome! -->
|
||||
|
||||
|
||||
## News
|
||||
**[2025/11]** We released a collection of BERTs finetuned for instruction-following: [`ModernBERT-{large,base}-chat-v0`](https://huggingface.co/collections/dllm-collection/bert-chat). This proof-of-concept shows that BERT’s internal knowledge can be leveraged for generative tasks via masked instruction tuning. See [BERT Chat Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg) for detailed recipes, experimental results, and lessons learned; see [`examples/bert`](/examples/bert) for training / inference / evaluation instructions.
|
||||
|
||||
|
||||
## Table of Contents
|
||||
- [Features](#features)
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Inference](#inference)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Citation](#citation)
|
||||
|
||||
|
||||
## Features
|
||||
<!-- - [`examples/rnd`](/examples/rnd): (WIP) Finetuning open-weight RND1 [RND1-Base](https://www.radicalnumerics.ai/assets/rnd1_report.pdf). -->
|
||||
- [`examples/llada`](/examples/llada): Pretraining, finetuning, and evaluating [LLaDA](https://arxiv.org/abs/2502.09992) / [LLaDA-MoE](https://arxiv.org/abs/2509.24389).
|
||||
- [`examples/dream`](/examples/dream): Pretraining, finetuning, and evaluating [Dream](https://arxiv.org/abs/2508.15487).
|
||||
- [`examples/bert`](/examples/bert): Finetuning any [BERT](https://arxiv.org/abs/1810.04805) into a lightweight chatbot.
|
||||
<details>
|
||||
<summary>🎬 Click to show BERT Chat Demo</summary>
|
||||
|
||||
<p align="center">
|
||||
<img src="/examples/bert/assets/chat.gif" alt="chat" width="80%">
|
||||
</p>
|
||||
<p align="center">
|
||||
<em>
|
||||
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
|
||||
</em>
|
||||
</p>
|
||||
</details>
|
||||
- [`examples/editflow`](/examples/editflow): Educational reference for training [EditFlow](https://arxiv.org/abs/2506.09018) models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT Chat) with *edit operations*—insertion, deletion, and substitution—and how to pretrain or finetune EditFlow models from scratch on public data.
|
||||
|
||||
<details>
|
||||
<summary>🎬 Click to show EditFlow Demo</summary>
|
||||
|
||||
<p align="center">
|
||||
<img src="/examples/editflow/assets/all.gif" alt="EditFlow demo" width="100%">
|
||||
</p>
|
||||
<p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p>
|
||||
|
||||
</details>
|
||||
- More upcoming.
|
||||
|
||||
|
||||
## Setup
|
||||
### Installation
|
||||
```bash
|
||||
# create and activate conda environment
|
||||
conda create -n dllm python=3.10 -y
|
||||
conda activate dllm
|
||||
|
||||
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
|
||||
conda install cuda=12.4 -c nvidia
|
||||
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
|
||||
--index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# install dllm package
|
||||
pip install -e .
|
||||
```
|
||||
### (optional) Evaluation setup
|
||||
|
||||
```bash
|
||||
# initialize `lm-evaluation-harness` submodule
|
||||
git submodule update --init --recursive
|
||||
|
||||
# install submodule in editable mode with IFEval & Math dependencies
|
||||
pip install -e "lm-evaluation-harness[ifeval,math]"
|
||||
```
|
||||
|
||||
### (optional) Slurm setup
|
||||
For [Slurm](https://slurm.schedmd.com/) users, update [`scripts/train.slurm.sh`](/scripts/train.slurm.sh) for your cluster:
|
||||
```diff
|
||||
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
|
||||
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
|
||||
+ #SBATCH --partition=YOUR_PARTITION
|
||||
+ #SBATCH --quotatype=YOUR_QUOTATYPE
|
||||
```
|
||||
Next, create a directory for your job logs:
|
||||
```shell
|
||||
mkdir logs
|
||||
```
|
||||
This folder will store the log files generated by your sbatch jobs.
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# modules for training / sampling
|
||||
dllm
|
||||
├── core # Core reusable modules shared across `dllm/pipelines`
|
||||
│ ├── generation
|
||||
│ ├── schedulers
|
||||
│ └── trainers
|
||||
├── data
|
||||
├── pipelines # Application-specific training & inference pipelines
|
||||
| ├── bert
|
||||
│ ├── dream
|
||||
│ ├── editflow
|
||||
│ └── llada
|
||||
│ ├── models # Model architecture and configs
|
||||
│ ├── generator.py # Generation utilities
|
||||
│ ├── trainer.py # Core training logic
|
||||
│ └── eval.py # Evaluation entry point
|
||||
├── tools
|
||||
└── utils
|
||||
|
||||
# entry points for training / sampling
|
||||
examples
|
||||
├── bert
|
||||
├── dream
|
||||
├── editflow
|
||||
└── llada
|
||||
├── chat.py # Interactive inference example
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
├── sft.py # Supervised finetuning example
|
||||
└── eval.sh # Evaluation script
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
A typical training entry script (for example, [`examples/llada/sft.py`](/examples/llada/sft.py)) looks like this:
|
||||
```python
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
|
||||
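# Note: this snippet is abridged; `parser` is constructed earlier in the actual script,
# typically via transformers.HfArgumentParser. The dataclass names below are placeholders,
# not necessarily the names used in the repository:
# parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))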
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
dataset = "..."
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset["test"],
|
||||
args=training_args,
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
You can launch the training job locally with `accelerate`, or submit it to a [Slurm](https://slurm.schedmd.com/) cluster using `sbatch`.
|
||||
```shell
|
||||
# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA)
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/zero2.yaml \
|
||||
examples/llada/sft.py \
|
||||
--num_train_epochs 4 \
|
||||
--load_in_4bit True --lora True
|
||||
```
|
||||
```shell
|
||||
# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py" \
|
||||
--num_train_epochs 4
|
||||
|
||||
# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/llada/sft.py" \
|
||||
--num_train_epochs 4
|
||||
```
|
||||
See [Features](#features) for specific training recipes.
|
||||
|
||||
|
||||
> Here are some useful tips for training (these flags can be combined; see the example after this list):
|
||||
> 1. Use a subset of data:
|
||||
> `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"`
|
||||
> 2. Concatenate datasets:
|
||||
> `--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"`
|
||||
> 3. Train with LoRA and 4bit quantization:
|
||||
> `--load_in_4bit True --lora True`
|
||||
> 4. Train with different distributed training methods:
|
||||
> `--accelerate_config "ddp,zero-{1,2,3},fsdp"`
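
These options compose. For example, the following local launch trains on a dataset subset with LoRA and 4bit quantization; it is a sketch that simply merges the flags already shown above, so adjust the script path and flag values to your setup:

```shell
accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/sft.py \
    --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
    --load_in_4bit True --lora True \
    --num_train_epochs 4
```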
|
||||
|
||||
## Inference
|
||||
|
||||
We provide unified [generators](/dllm/core/generation/generator.py) that abstract away inference details.
|
||||
A typical inference entry script (for example, [`examples/llada/generate.py`](/examples/llada/generate.py)) looks like this:
|
||||
```python
|
||||
import dllm
|
||||
from dllm import llada
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
# for other models, change your generator and keep others unchanged
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
outputs = generator.generate(inputs, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
```
|
||||
|
||||
You can also try interactive chat script (for example, [`examples/llada/chat.py`](/examples/llada/chat.py)) for visualized multi-turn dialogue:
|
||||
|
||||
<p align="center">
|
||||
<img src="/assets/chat.gif" alt="chat" width="80%">
|
||||
</p>
|
||||
<!-- <p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p> -->
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [`MMLU_Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro), run:
|
||||
```shell
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/llada/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "llada" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
|
||||
```
|
||||
|
||||
We also provide scripts to automatically evaluate [LLaDA](https://arxiv.org/abs/2502.09992), [Dream](https://arxiv.org/abs/2508.15487), and [BERT-Chat](https://huggingface.co/collections/dllm-collection/bert-chat) on all benchmarks.
|
||||
For example, you can launch [`examples/llada/eval.sh`](/examples/llada/eval.sh) directly using the following commands:
|
||||
```shell
|
||||
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
|
||||
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False
|
||||
```
|
||||
|
||||
|
||||
## Citation
|
||||
```
|
||||
@misc{dllm,
|
||||
author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
|
||||
title = {dLLM: Simple Diffusion Language Modeling},
|
||||
year = {2025},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
|
||||
}
|
||||
```
|
||||
BIN
dllm/assets/JetBrainsMono-VariableFont_wght.ttf
Normal file
Binary file not shown.
BIN
dllm/assets/chat.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.5 MiB |
BIN
dllm/assets/logo.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 956 KiB |
BIN
dllm/assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.0 KiB |
119
dllm/assets/logo.py
Normal file
@ -0,0 +1,119 @@
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os
|
||||
|
||||
# ---------- Configuration (smaller size) ----------
|
||||
W, H = 480, 210 # lower resolution
|
||||
TOTAL_DURATION = 3.0
|
||||
FPS = 15 # lower fps
|
||||
TEXT = "dLLM"
|
||||
# TEXT_COLOR = (235, 235, 235)
|
||||
TEXT_COLOR = (0, 0, 0)
|
||||
OUTPUT = "logo.gif"
|
||||
LAST_FRAME_PNG = "logo.png"
|
||||
|
||||
DIFFUSION_PORTION = 0.3 # fewer diffusion frames
|
||||
SEED = 8
|
||||
|
||||
|
||||
# ---------- Auto font size ----------
|
||||
def load_font_auto_size(text, w, h, target_width_ratio=0.95, target_height_ratio=0.95):
|
||||
lo, hi = 10, 2000
|
||||
best_font, best_size = None, lo
|
||||
while lo <= hi:
|
||||
mid = (lo + hi) // 2
|
||||
try:
|
||||
font = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=mid
|
||||
)
|
||||
except Exception:
|
||||
font = ImageFont.load_default()
|
||||
dummy = Image.new("L", (w, h), 0)
|
||||
d = ImageDraw.Draw(dummy)
|
||||
bbox = d.textbbox((0, 0), text, font=font)
|
||||
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
|
||||
width_ok = tw <= w * target_width_ratio
|
||||
height_ok = th <= h * target_height_ratio
|
||||
|
||||
if width_ok and height_ok:
|
||||
best_font, best_size = font, mid
|
||||
lo = mid + 1
|
||||
else:
|
||||
hi = mid - 1
|
||||
|
||||
return best_font if best_font is not None else font
|
||||
|
||||
|
||||
# ---------- Text rendering ----------
|
||||
def render_text_mask(w, h, text, font):
|
||||
img = Image.new("L", (w, h), 0)
|
||||
d = ImageDraw.Draw(img)
|
||||
bbox = d.textbbox((0, 0), text, font=font)
|
||||
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
x = (w - tw) // 2 - bbox[0]
|
||||
y = (h - th) // 2 - bbox[1]
|
||||
d.text((x, y), text, font=font, fill=255)
|
||||
return np.asarray(img, np.float32) / 255.0
|
||||
|
||||
|
||||
# ---------- Initialization ----------
|
||||
font = load_font_auto_size(TEXT, W, H)
|
||||
mask = render_text_mask(W, H, TEXT, font)
|
||||
|
||||
num_frames = int(TOTAL_DURATION * FPS)
|
||||
diffusion_frames = max(1, int(num_frames * DIFFUSION_PORTION))
|
||||
hold_ms = int((TOTAL_DURATION - diffusion_frames / FPS) * 1000)
|
||||
|
||||
rng = np.random.default_rng(SEED)
|
||||
frames = []
|
||||
|
||||
# ---------- Diffusion stage ----------
|
||||
for i in range(diffusion_frames):
|
||||
t = i / max(1, diffusion_frames - 1)
|
||||
progress = t**0.9
|
||||
noise_sigma = (1.0 - progress) ** 2.2
|
||||
|
||||
noise = rng.standard_normal((H, W, 1)).astype(np.float32)
|
||||
noise_img = 1.0 - noise_sigma * 0.5 * np.abs(noise)
|
||||
np.clip(noise_img, 0.0, 1.0, out=noise_img)
|
||||
|
||||
alpha = progress**2.0
|
||||
alpha_map = (mask * alpha).astype(np.float32)[..., None]
|
||||
|
||||
text_rgb = np.zeros((H, W, 3), dtype=np.float32)
|
||||
for c in range(3):
|
||||
text_rgb[..., c] = (mask > 0).astype(np.float32) * (TEXT_COLOR[c] / 255.0)
|
||||
|
||||
frame = (1.0 - alpha_map) * noise_img + alpha_map * text_rgb
|
||||
frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
|
||||
frames.append(Image.fromarray(frame, mode="RGB"))
|
||||
|
||||
# ---------- Last frame ----------
|
||||
final_frame = frames[-1]
|
||||
|
||||
# ---------- Save last frame as PNG ----------
|
||||
final_frame.save(LAST_FRAME_PNG)
|
||||
print(f"🖼️ Last frame saved as: {LAST_FRAME_PNG}")
|
||||
|
||||
# ---------- Quantization (reduce size) ----------
|
||||
pal_frames = [f.convert("P", palette=Image.ADAPTIVE, colors=64) for f in frames]
|
||||
pal_final = final_frame.convert("P", palette=Image.ADAPTIVE, colors=64)
|
||||
|
||||
# ---------- Save GIF ----------
|
||||
normal_ms = int(1000 / FPS)
|
||||
durations = [normal_ms] * len(pal_frames) + [hold_ms]
|
||||
|
||||
pal_frames[0].save(
|
||||
OUTPUT,
|
||||
save_all=True,
|
||||
append_images=pal_frames[1:] + [pal_final],
|
||||
duration=durations,
|
||||
loop=0,
|
||||
optimize=True,
|
||||
)
|
||||
|
||||
print(f"✅ GIF saved: {OUTPUT}")
|
||||
print(
|
||||
f"Frames (diffusion only): {len(pal_frames)} at {FPS} FPS, final hold {hold_ms} ms, resolution {W}x{H}"
|
||||
)
|
||||
1
dllm/dllm/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import core, data, pipelines, utils
|
||||
1
dllm/dllm/core/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from dllm.core import trainers, schedulers, generation
|
||||
1
dllm/dllm/core/generation/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import generator, visualizer
|
||||
49
dllm/dllm/core/generation/generator.py
Normal file
@ -0,0 +1,49 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer, PreTrainedModel
|
||||
|
||||
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorOutput:
|
||||
sequences: torch.Tensor
|
||||
histories: list[torch.Tensor] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig:
|
||||
return_dict_in_generate: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseGenerator(ABC):
|
||||
model: PreTrainedModel
|
||||
tokenizer: PreTrainedTokenizer
|
||||
scheduler: BaseAlphaScheduler | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.scheduler is None:
|
||||
self.scheduler = LinearAlphaScheduler()
|
||||
|
||||
@abstractmethod
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[torch.Tensor | list],
|
||||
config: GeneratorConfig | None = None,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@torch.no_grad()
|
||||
def infill(
|
||||
self,
|
||||
inputs: list[torch.Tensor | list],
|
||||
config: GeneratorConfig | None = None,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput:
|
||||
raise NotImplementedError
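# Usage sketch: concrete generators (e.g. the LLaDA pipeline's generator) subclass
# BaseGenerator and implement `generate` / `infill`. A minimal hypothetical subclass,
# where the class name and the decoding loop are placeholders rather than the actual
# pipeline code:
#
#   @dataclass
#   class MyGenerator(BaseGenerator):
#       @torch.no_grad()
#       def generate(self, prompts, config=None, **kwargs) -> GeneratorOutput:
#           config = config or GeneratorConfig()
#           # ... iteratively unmask tokens using self.model and self.scheduler ...
#           return GeneratorOutput(sequences=sequences, histories=None)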
|
||||
427
dllm/dllm/core/generation/visualizer.py
Normal file
@ -0,0 +1,427 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseVisualizer(ABC):
|
||||
tokenizer: PreTrainedTokenizer
|
||||
|
||||
@abstractmethod
|
||||
def visualize(self, history: list[torch.Tensor], **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoVisualizer(BaseVisualizer):
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
history: list[torch.Tensor],
|
||||
output_path: str = "visualization.gif",
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class TerminalVisualizer(BaseVisualizer):
|
||||
|
||||
# Configuration (adjust as needed)
|
||||
HEADER_SIZE = 3 # Fixed number of lines for the header (0 if show_header is False)
|
||||
PROGRESS_SIZE = 3 # Fixed number of lines for the progress bar
|
||||
PANEL_PADDING_TOP = 1 # Top padding of the Panel (padding=(top, side))
|
||||
PANEL_PADDING_BOTTOM = 1 # Bottom padding of the Panel
|
||||
PANEL_PADDING_SIDE = 1 # Number of characters used for left and right padding
|
||||
PANEL_BORDER = 2 # Number of columns taken by the Panel border (usually 2)
|
||||
MIN_TOTAL_HEIGHT = 10 # Minimum terminal height (in lines)
|
||||
MAX_TOTAL_HEIGHT = 60 # Maximum terminal height to prevent overflowing the terminal
|
||||
DEFAULT_TERM_WIDTH = 120 # Default terminal width (in columns)
|
||||
ansi_escape = re.compile(r"\x1b\[[0-9;]*m") # Regex to match ANSI escape codes
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
|
||||
fps: int = 16,
|
||||
rich: bool = True,
|
||||
title: str = "dllm",
|
||||
max_chars: int = None,
|
||||
every_n_steps: int = 1,
|
||||
show_header: bool = True,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Visualize a masked-diffusion decoding trajectory stored in `history`.
|
||||
If items have batch dimension [B, T], visualize each sequence separately.
|
||||
"""
|
||||
try:
|
||||
# detect batch size
|
||||
first_step = history[0]
|
||||
if first_step.dim() > 1 and first_step.shape[0] > 1:
|
||||
B = first_step.shape[0]
|
||||
for b_idx in range(B):
|
||||
# build per-sequence history
|
||||
seq_history = [step[b_idx].unsqueeze(0) for step in history]
|
||||
self.visualize_one_history(
|
||||
seq_history,
|
||||
fps,
|
||||
rich,
|
||||
title=f"{title} (Batch {b_idx})",
|
||||
max_chars=max_chars,
|
||||
every_n_steps=every_n_steps,
|
||||
show_header=show_header,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
)
|
||||
else:
|
||||
# no batch, just visualize normally
|
||||
self.visualize_one_history(
|
||||
history,
|
||||
fps,
|
||||
rich,
|
||||
title,
|
||||
max_chars,
|
||||
every_n_steps,
|
||||
show_header,
|
||||
skip_special_tokens,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"(Visualization skipped due to error: {e})")
|
||||
|
||||
def visualize_one_history(
|
||||
self,
|
||||
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
|
||||
fps: int = 16,
|
||||
rich: bool = True,
|
||||
title: str = "dllm",
|
||||
max_chars: int = None,
|
||||
every_n_steps: int = 1, # re-render frequency (perf knob)
|
||||
show_header: bool = True,
|
||||
skip_special_tokens: bool = False, # NEW ARGUMENT
|
||||
) -> None:
|
||||
"""
|
||||
Visualize a masked-diffusion decoding trajectory stored in `history`.
|
||||
|
||||
Args:
|
||||
history: Sequence of token tensors for each step. Each item is [T] or [B,T].
|
||||
fps: Frames per second for the live UI (Rich) or sleep cadence for tqdm fallback.
|
||||
title: Header title.
|
||||
max_chars: Cap on rendered characters to keep terminal snappy.
|
||||
every_n_steps: Only redraw text every N steps (progress still updates every step).
|
||||
show_header: Show the magenta header bar (Rich path).
|
||||
skip_special_tokens: Whether to skip special/pad/eos tokens when rendering (default: False).
|
||||
Notes:
|
||||
- Masked positions are detected via `self.tokenizer.mask_token_id`.
|
||||
- Special tokens are determined via `self.tokenizer.all_special_ids`.
|
||||
- All layout, styling, and progress are encapsulated here.
|
||||
"""
|
||||
# --------- imports & env checks ----------
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
TextColumn,
|
||||
TimeRemainingColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
)
|
||||
from rich.layout import Layout
|
||||
|
||||
_RICH_IMPORTED = True
|
||||
except Exception:
|
||||
_RICH_IMPORTED = False
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
_TQDM_IMPORTED = True
|
||||
except Exception:
|
||||
_TQDM_IMPORTED = False
|
||||
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(
|
||||
"TerminalVisualizer.tokenizer must be set to a valid tokenizer."
|
||||
)
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
specials: set[int] = set(getattr(tokenizer, "all_special_ids", []) or [])
|
||||
self._specials = specials # store for helpers
|
||||
self._mask_token_id: Optional[int] = getattr(tokenizer, "mask_token_id", None)
|
||||
self._pad_token_id: Optional[int] = getattr(tokenizer, "pad_token_id", None)
|
||||
self._eos_token_id: Optional[int] = getattr(tokenizer, "eos_token_id", None)
|
||||
|
||||
# --------- helpers inside class scope ----------
|
||||
# (keep everything inside this class as requested)
|
||||
|
||||
# throttle settings
|
||||
sleep_s = 0.0 if fps <= 0 else 1.0 / float(max(1, fps))
|
||||
total_steps = len(history)
|
||||
every_n_steps = max(1, int(every_n_steps))
|
||||
|
||||
# decode final text up-front (used after render)
|
||||
final_text = self._detok(history[-1], skip_special_tokens=skip_special_tokens)
|
||||
final_text = self._truncate(final_text, max_chars)
|
||||
|
||||
# ------------------ new: estimate height from final_text ------------------
|
||||
import textwrap
|
||||
import shutil
|
||||
|
||||
def strip_ansi(s: str) -> str:
|
||||
return self.ansi_escape.sub("", s) if s else ""
|
||||
|
||||
def estimate_height_from_text(text: str, console_width: int) -> int:
|
||||
"""
|
||||
Estimate how many terminal rows the panel with `text` will need given console_width.
|
||||
Uses class constants for paddings/borders and header/progress sizes.
|
||||
"""
|
||||
plain = strip_ansi(text or "")
|
||||
# inner width = console width minus left/right panel paddings & border
|
||||
inner_width = max(
|
||||
10, console_width - 2 * self.PANEL_PADDING_SIDE - self.PANEL_BORDER
|
||||
)
|
||||
lines = 0
|
||||
# preserve existing newlines: wrap each paragraph separately
|
||||
for para in plain.splitlines() or [""]:
|
||||
if para.strip() == "":
|
||||
lines += 1
|
||||
continue
|
||||
wrapped = textwrap.wrap(
|
||||
para,
|
||||
width=inner_width,
|
||||
replace_whitespace=False,
|
||||
drop_whitespace=False,
|
||||
)
|
||||
lines += max(1, len(wrapped))
|
||||
text_block_lines = (
|
||||
lines + self.PANEL_PADDING_TOP + self.PANEL_PADDING_BOTTOM
|
||||
)
|
||||
extra = 2 # for panel title / subtitle / small margin
|
||||
header_h = self.HEADER_SIZE if show_header else 0
|
||||
total = header_h + text_block_lines + self.PROGRESS_SIZE + extra
|
||||
# clamp
|
||||
total = max(self.MIN_TOTAL_HEIGHT, min(total, self.MAX_TOTAL_HEIGHT))
|
||||
return int(total)
|
||||
|
||||
# try to detect terminal width; fallback to 100
|
||||
try:
|
||||
term_width = shutil.get_terminal_size().columns
|
||||
if not isinstance(term_width, int) or term_width <= 0:
|
||||
term_width = self.DEFAULT_TERM_WIDTH
|
||||
except Exception:
|
||||
term_width = self.DEFAULT_TERM_WIDTH
|
||||
|
||||
est_height = estimate_height_from_text(final_text, console_width=term_width)
|
||||
# ------------------ end new ----------------------------------------------
|
||||
|
||||
# choose rich or tqdm
|
||||
use_rich = bool(rich and _RICH_IMPORTED)
|
||||
|
||||
if not use_rich or not _RICH_IMPORTED:
|
||||
# ---------- tqdm fallback ----------
|
||||
if not _TQDM_IMPORTED:
|
||||
for i, toks in enumerate(history, start=1):
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
print("\n✨ Generation complete!\n")
|
||||
print(final_text)
|
||||
return
|
||||
|
||||
pbar = tqdm(total=total_steps, desc="Diffusion", leave=True)
|
||||
for i, toks in enumerate(history, start=1):
|
||||
pbar.update(1)
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"masks": self._count_masks(toks),
|
||||
"pct": f"{int(100 * i / max(total_steps, 1))}%",
|
||||
}
|
||||
)
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
if final_text:
|
||||
print(final_text)
|
||||
return
|
||||
|
||||
# ---------- rich live UI ----------
|
||||
# replaced fixed height=100 with the estimated height from history[-1]
|
||||
console = Console(
|
||||
force_terminal=True,
|
||||
color_system="truecolor",
|
||||
width=term_width,
|
||||
height=est_height,
|
||||
)
|
||||
layout = Layout()
|
||||
layout.split_column(
|
||||
(
|
||||
Layout(name="header", size=3)
|
||||
if show_header
|
||||
else Layout(name="header", size=0)
|
||||
),
|
||||
Layout(name="text", ratio=1),
|
||||
Layout(name="progress", size=3),
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]Diffusion"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TextColumn("[cyan]Masks: {task.fields[masks]}"),
|
||||
TextColumn("•"),
|
||||
TextColumn("[magenta]{task.fields[pct]:>4s}"),
|
||||
TimeRemainingColumn(),
|
||||
expand=True,
|
||||
)
|
||||
|
||||
init_masks = self._count_masks(history[0]) if history else 0
|
||||
task_id = progress.add_task(
|
||||
"Generating", total=total_steps, masks=init_masks, pct="0%"
|
||||
)
|
||||
|
||||
with Live(layout, console=console, refresh_per_second=max(1, fps)):
|
||||
for step_idx, toks in enumerate(history, start=1):
|
||||
if show_header:
|
||||
header = Text(title, style="bold magenta", justify="center")
|
||||
layout["header"].update(Panel(header, border_style="bright_blue"))
|
||||
|
||||
# progress bar
|
||||
masks_remaining = self._count_masks(toks)
|
||||
pct = f"{int(100 * step_idx / max(total_steps, 1))}%"
|
||||
progress.update(task_id, advance=1, masks=masks_remaining, pct=pct)
|
||||
|
||||
# text panel: decode whole sequence (avoids Ġ/Ċ artifacts)
|
||||
if (
|
||||
every_n_steps <= 1
|
||||
or (step_idx % every_n_steps == 0)
|
||||
or step_idx in (1, total_steps)
|
||||
):
|
||||
text_str = self._detok(
|
||||
toks, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
text_str = self._truncate(text_str, max_chars)
|
||||
text_rich = Text.from_ansi(text_str) if text_str else Text("")
|
||||
layout["text"].update(
|
||||
Panel(
|
||||
(
|
||||
text_rich
|
||||
if text_rich.plain
|
||||
else Text("[dim]— no tokens —[/dim]")
|
||||
),
|
||||
title="[bold]Generated Text",
|
||||
subtitle=f"[dim]Step {step_idx}/{total_steps}[/dim]",
|
||||
border_style="cyan",
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
|
||||
layout["progress"].update(Panel(progress))
|
||||
if sleep_s > 0:
|
||||
time.sleep(sleep_s)
|
||||
|
||||
console.print("\n[bold green]✨ Generation complete![/bold green]\n")
|
||||
# console.print(
|
||||
# Panel(
|
||||
# final_text if final_text else "[dim]— no decodable text —[/dim]",
|
||||
# title="[bold]Final Generated Text",
|
||||
# border_style="green",
|
||||
# padding=(1, 2),
|
||||
# )
|
||||
# )
|
||||
|
||||
# ======================== helpers (kept inside class) ========================
|
||||
|
||||
def _has_tty(self) -> bool:
|
||||
return sys.stdout.isatty() and os.environ.get("TERM", "") not in ("", "dumb")
|
||||
|
||||
def _first_item(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x[0] if x.dim() > 1 else x
|
||||
|
||||
def _count_masks(self, toks: torch.Tensor) -> int:
|
||||
if getattr(self, "_mask_token_id", None) is None:
|
||||
return 0
|
||||
t = self._first_item(toks)
|
||||
return int((t == self._mask_token_id).sum().item())
|
||||
|
||||
def _detok(self, ids_or_tensor, *, skip_special_tokens: bool) -> str:
|
||||
"""
|
||||
Robust detokenize for list[int] / torch.Tensor([T]) / torch.Tensor([B,T]).
|
||||
Decode the whole sequence to avoid byte-level artifacts like Ġ/Ċ.
|
||||
"""
|
||||
tokenizer = self.tokenizer
|
||||
# normalize to python list[int]
|
||||
if isinstance(ids_or_tensor, torch.Tensor):
|
||||
t = self._first_item(ids_or_tensor).long()
|
||||
ids = t.tolist()
|
||||
elif isinstance(ids_or_tensor, (list, tuple)):
|
||||
ids = list(ids_or_tensor)
|
||||
else:
|
||||
# unknown type
|
||||
return ""
|
||||
|
||||
# Optionally drop specials/pad/eos *before* decode if desired
|
||||
if skip_special_tokens:
|
||||
keep = []
|
||||
specials = getattr(self, "_specials", set())
|
||||
pad_id = getattr(self, "_pad_token_id", None)
|
||||
eos_id = getattr(self, "_eos_token_id", None)
|
||||
for tid in ids:
|
||||
if tid in specials:
|
||||
continue
|
||||
if pad_id is not None and tid == pad_id:
|
||||
continue
|
||||
if eos_id is not None and tid == eos_id:
|
||||
continue
|
||||
keep.append(tid)
|
||||
ids = keep
|
||||
|
||||
# Prefer tokenizer.decode (handles Ġ/Ċ, merges properly)
|
||||
text = ""
|
||||
try:
|
||||
if hasattr(tokenizer, "decode"):
|
||||
text = tokenizer.decode(
|
||||
ids,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
else:
|
||||
# fallback: tokens -> string
|
||||
toks = tokenizer.convert_ids_to_tokens(ids)
|
||||
if hasattr(tokenizer, "convert_tokens_to_string"):
|
||||
text = tokenizer.convert_tokens_to_string(toks)
|
||||
else:
|
||||
text = " ".join(map(str, toks))
|
||||
except Exception:
|
||||
# extremely defensive fallback
|
||||
try:
|
||||
text = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
# sanitize control chars for terminal
|
||||
if text:
|
||||
text = text.replace("\r", "")
|
||||
return text
|
||||
|
||||
def _truncate(self, s: str, max_chars: Optional[int]) -> str:
|
||||
if max_chars is None or (isinstance(max_chars, int) and max_chars < 0):
|
||||
return s
|
||||
return s[:max_chars]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
2
dllm/dllm/core/schedulers/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .alpha import *
|
||||
from .kappa import *
|
||||
132
dllm/dllm/core/schedulers/alpha.py
Normal file
@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import math
|
||||
from typing import ClassVar, Dict, Type, Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
Number = Union[float, torch.Tensor]
|
||||
|
||||
|
||||
# ---------------- Registry-enabled Base ---------------- #
|
||||
@dataclasses.dataclass
|
||||
class BaseAlphaScheduler:
|
||||
__registry__: ClassVar[dict[str, type[BaseAlphaScheduler]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
BaseAlphaScheduler.__registry__[cls.__name__] = cls
|
||||
BaseAlphaScheduler.__registry__[cls.__name__.lower()] = cls
|
||||
|
||||
# Make instances callable (sched(i) -> alpha(i))
|
||||
def __call__(self, t: Number) -> Number:
|
||||
return self.alpha(t)
|
||||
|
||||
# ---- common API ----
|
||||
def alpha(self, i: Number) -> Number:
|
||||
i_t = torch.as_tensor(
|
||||
i,
|
||||
dtype=torch.float32,
|
||||
device=i.device if isinstance(i, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
|
||||
raise ValueError(f"i={i} not in [0,1]")
|
||||
out = self._alpha(i_t)
|
||||
return out.item() if isinstance(i, float) else out
|
||||
|
||||
def alpha_derivative(self, i: Number) -> Number:
|
||||
i_t = torch.as_tensor(
|
||||
i,
|
||||
dtype=torch.float32,
|
||||
device=i.device if isinstance(i, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
|
||||
raise ValueError(f"i={i} not in [0,1]")
|
||||
out = self._alpha_derivative(i_t)
|
||||
return out.item() if isinstance(i, float) else out
|
||||
|
||||
def reverse_mask_prob(self, s: Number, t: Number) -> Number:
|
||||
t_t = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
s_t = torch.as_tensor(
|
||||
s,
|
||||
dtype=torch.float32,
|
||||
device=s.device if isinstance(s, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= s_t) & (s_t < 1.0) & (0.0 < t_t) & (t_t <= 1.0)):
|
||||
raise ValueError(f"(t={t}, s={s}) out of range")
|
||||
if not torch.all(s_t < t_t):
|
||||
raise ValueError(f"Require s < t elementwise, but got (t={t}, s={s})")
|
||||
out = (1 - self(s_t)) / (1 - self(t_t))
|
||||
return out.item() if isinstance(t, float) and isinstance(s, float) else out
|
||||
|
||||
def weight(self, i: Number) -> Number:
|
||||
# w(t) = - α'(t) / (1 - α(t))
|
||||
return - self.alpha_derivative(i) / (1 - self.alpha(i) + 1e-6)
|
||||
|
||||
# ---- hooks implemented by subclasses ----
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------- Implementations ---------------- #
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LinearAlphaScheduler(BaseAlphaScheduler):
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return 1 - i
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return -torch.ones_like(i)
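# Illustrative check for this schedule: alpha(t) = 1 - t and alpha'(t) = -1, so the
# base-class weight becomes w(t) = -alpha'(t) / (1 - alpha(t) + 1e-6) = 1 / (t + 1e-6),
# i.e. lightly-masked timesteps (small t) receive larger per-token loss weights.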
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CosineAlphaScheduler(BaseAlphaScheduler):
|
||||
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return 1 - torch.cos((math.pi / 2) * (1 - i))
|
||||
|
||||
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
|
||||
return -(math.pi / 2) * torch.sin((math.pi / 2) * (1 - i))
|
||||
|
||||
|
||||
# ---------------- Factory helpers ---------------- #
|
||||
|
||||
|
||||
def get_alpha_scheduler_class(name: str) -> type[BaseAlphaScheduler]:
|
||||
"""Return the scheduler class by name (case-insensitive)."""
|
||||
cls = BaseAlphaScheduler.__registry__.get(
|
||||
name
|
||||
) or BaseAlphaScheduler.__registry__.get(name.lower())
|
||||
if cls is None:
|
||||
available = sorted(k for k in BaseAlphaScheduler.__registry__ if k[0].isupper())
|
||||
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
|
||||
return cls
|
||||
|
||||
|
||||
def make_alpha_scheduler(name: str, **kwargs: Any) -> BaseAlphaScheduler:
|
||||
"""Instantiate a scheduler by name with optional kwargs."""
|
||||
cls = get_alpha_scheduler_class(name)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
# ---------------- Example usage ---------------- #
|
||||
|
||||
if __name__ == "__main__":
|
||||
lin_sched = make_alpha_scheduler("LinearalphaScheduler")
|
||||
print("Linear α(0.5):", lin_sched.alpha(0.5))
|
||||
print("Linear w(0.5):", lin_sched.weight(0.5))
|
||||
print("Linear α([.25,.5,.75]):", lin_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("==========================================")
|
||||
cos_sched = make_alpha_scheduler("CosinealphaScheduler")
|
||||
print("Cosine α(0.5):", cos_sched.alpha(0.5))
|
||||
print("Cosine w(0.5):", cos_sched.weight(0.5))
|
||||
print("Cosine α([.25,.5,.75]):", cos_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
128
dllm/dllm/core/schedulers/kappa.py
Normal file
@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import math
|
||||
from typing import ClassVar, Dict, Type, Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
Number = Union[float, torch.Tensor]
|
||||
|
||||
|
||||
# ---------------- Registry-enabled Base ---------------- #
|
||||
@dataclasses.dataclass
|
||||
class BaseKappaScheduler:
|
||||
__registry__: ClassVar[dict[str, type[BaseKappaScheduler]]] = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
BaseKappaScheduler.__registry__[cls.__name__] = cls
|
||||
BaseKappaScheduler.__registry__[cls.__name__.lower()] = cls
|
||||
|
||||
# Make instances callable (sched(t) -> kappa(t))
|
||||
def __call__(self, t: Number) -> Number:
|
||||
return self.kappa(t)
|
||||
|
||||
# ---- common API ----
|
||||
def kappa(self, t: Number) -> Number:
|
||||
t_tensor = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
|
||||
raise ValueError(f"t={t} not in [0,1]")
|
||||
out = self._kappa(t_tensor)
|
||||
return out.item() if isinstance(t, float) else out
|
||||
|
||||
def kappa_derivative(self, t: Number) -> Number:
|
||||
t_tensor = torch.as_tensor(
|
||||
t,
|
||||
dtype=torch.float32,
|
||||
device=t.device if isinstance(t, torch.Tensor) else None,
|
||||
)
|
||||
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
|
||||
raise ValueError(f"t={t} not in [0,1]")
|
||||
out = self._kappa_derivative(t_tensor)
|
||||
return out.item() if isinstance(t, float) else out
|
||||
|
||||
def weight(self, t: Number) -> Number:
|
||||
# w(t) = κ'(t) / (1 - κ(t))
|
||||
return self.kappa_derivative(t) / (1 - self.kappa(t) + 1e-6)
|
||||
|
||||
# ---- hooks implemented by subclasses ----
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------- Implementations ---------------- #
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CubicKappaScheduler(BaseKappaScheduler):
|
||||
a: float = 1.0
|
||||
b: float = 1.0
|
||||
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ(t) = (a+1) t^3 - (a+b+1) t^2 + (b+1) t
|
||||
return (self.a + 1) * (t**3) - (self.a + self.b + 1) * (t**2) + (self.b + 1) * t
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ'(t) = 3(a+1) t^2 - 2(a+b+1) t + (b+1)
|
||||
return 3 * (self.a + 1) * (t**2) - 2 * (self.a + self.b + 1) * t + (self.b + 1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LinearKappaScheduler(CubicKappaScheduler):
|
||||
# Special case: κ(t) = t corresponds to a=-1, b=0
|
||||
a: float = -1.0
|
||||
b: float = 0.0
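# Sanity check: with a = -1 and b = 0 the cubic above collapses to
#   kappa(t) = 0*t**3 - 0*t**2 + 1*t = t,
# so this subclass is exactly the linear schedule kappa(t) = t.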
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CosineKappaScheduler(BaseKappaScheduler):
|
||||
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ(t) = 1 - cos((π/2) * t)
|
||||
return 1.0 - torch.cos(0.5 * math.pi * t)
|
||||
|
||||
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
|
||||
# κ'(t) = (π/2) * sin((π/2) * t)
|
||||
return 0.5 * math.pi * torch.sin(0.5 * math.pi * t)
|
||||
|
||||
|
||||
# ---------------- Factory helpers ---------------- #
|
||||
|
||||
|
||||
def get_kappa_scheduler_class(name: str) -> type[BaseKappaScheduler]:
|
||||
"""Return the scheduler class by name (case-insensitive)."""
|
||||
cls = BaseKappaScheduler.__registry__.get(
|
||||
name
|
||||
) or BaseKappaScheduler.__registry__.get(name.lower())
|
||||
if cls is None:
|
||||
available = sorted(k for k in BaseKappaScheduler.__registry__ if k[0].isupper())
|
||||
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
|
||||
return cls
|
||||
|
||||
|
||||
def make_kappa_scheduler(name: str, **kwargs: Any) -> BaseKappaScheduler:
|
||||
"""Instantiate a scheduler by name with optional kwargs."""
|
||||
cls = get_kappa_scheduler_class(name)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
# ---------------- Example usage ---------------- #
|
||||
|
||||
if __name__ == "__main__":
|
||||
lin_sched = make_kappa_scheduler("LinearKappaScheduler")
|
||||
print("Linear κ(0.5):", lin_sched.kappa(0.5))
|
||||
print("Linear w(0.5):", lin_sched.weight(0.5))
|
||||
print("Linear κ([.25,.5,.75]):", lin_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("==========================================")
|
||||
cos_sched = make_kappa_scheduler("CosineKappaScheduler")
|
||||
print("Cosine κ(0.5):", cos_sched.kappa(0.5))
|
||||
print("Cosine w(0.5):", cos_sched.weight(0.5))
|
||||
print("Cosine κ([.25,.5,.75]):", cos_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
|
||||
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
|
||||
1
dllm/dllm/core/trainers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from dllm.core.trainers.mdlm import MDLMTrainer
|
||||
140
dllm/dllm/core/trainers/mdlm.py
Normal file
@ -0,0 +1,140 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from typing import Any
|
||||
|
||||
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
|
||||
|
||||
|
||||
class MDLMTrainer(transformers.Trainer):
|
||||
"""
|
||||
Masked Diffusion Language Model Trainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
scheduler: BaseAlphaScheduler | None = None,
|
||||
time_epsilon: float = 1e-3,
|
||||
loss_weight_type: str = "scheduler",  # one of {"scheduler", "ones"}
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.scheduler = scheduler or LinearAlphaScheduler()
|
||||
if not (0.0 < time_epsilon < 1.0):
|
||||
raise ValueError("time_epsilon must be in (0, 1)")
|
||||
self.time_epsilon = time_epsilon
|
||||
self.loss_weight_type = loss_weight_type
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
pass
|
||||
|
||||
def _postprocess_outputs(self, outputs):
|
||||
pass
|
||||
|
||||
def _compute_loss_weights(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
inputs: dict[str, Any],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Compute loss weights given timestep t and other arguments."""
|
||||
b, l = inputs["input_ids"].shape
|
||||
if self.loss_weight_type == "scheduler":
|
||||
loss_weights = self.scheduler.weight(t).unsqueeze(1).repeat(1, l)  # (b, l)
|
||||
elif self.loss_weight_type == "ones":
|
||||
loss_weights = torch.ones_like(inputs["input_ids"])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss_weights
|
||||
|
||||
@torch.no_grad()
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
if prediction_loss_only:
|
||||
return (loss.detach(), None, None)
|
||||
|
||||
logits = getattr(outputs, "logits", outputs)
|
||||
if isinstance(logits, torch.Tensor):
|
||||
logits = logits.detach().contiguous()
|
||||
|
||||
labels = inputs.get("labels")
|
||||
if isinstance(labels, torch.Tensor):
|
||||
labels = labels.detach().contiguous()
|
||||
|
||||
return (loss.detach(), logits, labels)
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: transformers.PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert self.processing_class.padding_side == "right"
|
||||
self._preprocess_inputs(inputs)
|
||||
input_ids, labels, attention_mask = (
|
||||
inputs["input_ids"],
|
||||
inputs["labels"],
|
||||
inputs.get("attention_mask", None),
|
||||
)
|
||||
b, l = input_ids.shape
|
||||
|
||||
# === 1. Sample diffusion timesteps ===
|
||||
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
|
||||
# The scheduler defines α(t), the probability that a token remains unmasked at time t; the masking probability is p_mask = 1 - α(t).
|
||||
t = self.time_epsilon + (1 - self.time_epsilon) * torch.rand(
|
||||
b, device=input_ids.device
|
||||
)
|
||||
p_mask = 1 - self.scheduler(t).unsqueeze(1).expand(b, l)
|
||||
|
||||
# === 2. Apply stochastic masking ===
|
||||
# Tokens are masked independently according to p_mask(t).
|
||||
# Positions with label = -100 are excluded (ignored in loss).
|
||||
masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
|
||||
labels != -100
|
||||
)
|
||||
# Replace masked tokens with the special [MASK] token.
|
||||
noised_input_ids = torch.where(
|
||||
masked_indices, self.processing_class.mask_token_id, input_ids
|
||||
)
|
||||
|
||||
# === 3. Forward pass through the model ===
|
||||
# The model predicts clean tokens given noised inputs.
|
||||
outputs = model(input_ids=noised_input_ids, attention_mask=attention_mask)
|
||||
self._postprocess_outputs(outputs)
|
||||
logits = outputs.logits
|
||||
|
||||
# === 4. Handle degenerate cases (no tokens masked) ===
|
||||
# If no positions were masked, return a zero loss to keep gradients valid.
|
||||
# This step is necessary for Deepspeed Zero-{2,3}
|
||||
if not masked_indices.any():
|
||||
return (
|
||||
(logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
|
||||
)
|
||||
|
||||
# === 5. Compute per-token loss weights ===
|
||||
# Depending on the configuration, weights may depend on timestep t
|
||||
# (e.g., scheduler-based) or be uniform (ones).
|
||||
loss_weights = self._compute_loss_weights(
|
||||
t=t, inputs=inputs, masked_indices=masked_indices
|
||||
)
|
||||
|
||||
# === 6. Compute weighted cross-entropy ===
|
||||
# Only masked tokens contribute to the loss.
|
||||
assert (input_ids[masked_indices] == labels[masked_indices]).all()
|
||||
token_loss = F.cross_entropy(
|
||||
logits[masked_indices], input_ids[masked_indices], reduction="none"
|
||||
)
|
||||
token_loss = token_loss * loss_weights[masked_indices]
|
||||
|
||||
# === 7. Normalize loss per effective token length ===
|
||||
# Normalize each sequence’s contribution by its number of valid tokens,
|
||||
# then average over the batch for stability across variable-length inputs.
|
||||
effective_lengths = torch.sum(labels != -100, dim=1, keepdim=True).expand(b, l)
|
||||
loss = torch.sum(token_loss / effective_lengths[masked_indices]) / b
|
||||
|
||||
# === 8. Return final loss (and optionally model outputs) ===
|
||||
return (loss, outputs) if return_outputs else loss
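# The loss assembled above, written out (this matches the code; no extra terms):
#   loss = (1/B) * sum_b sum_{i in masked_b} w(t_b) * CE(logits[b, i], x[b, i]) / |valid_b|
# where w(t) comes from the alpha scheduler (all-ones when loss_weight_type == "ones")
# and |valid_b| counts positions of example b with label != -100.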
|
||||
1
dllm/dllm/data/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .utils import load_sft_dataset, load_pt_dataset
|
||||
63
dllm/dllm/data/alpaca.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import Optional
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
|
||||
"""Construct a clean text prompt from Alpaca fields.
|
||||
|
||||
We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:")
|
||||
in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes
|
||||
those tags from the prompt it returns.
|
||||
"""
|
||||
instruction = (instruction or "").strip()
|
||||
input_text = (input_text or "").strip()
|
||||
|
||||
if input_text:
|
||||
# Keep instruction and input separated by a blank line for readability.
|
||||
return f"{instruction}\n\n{input_text}"
|
||||
else:
|
||||
return instruction
|
||||
|
||||
|
||||
def load_dataset_alpaca(dataset_name_or_path: str) -> DatasetDict:
|
||||
"""Load the Alpaca dataset (tatsu-lab/alpaca) and expose unified fields.
|
||||
|
||||
Returns a `DatasetDict` where each split contains:
|
||||
- prompt: Combined instruction (+ optional input), with clean formatting
|
||||
- response: The target output (model answer)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_name_or_path : str
|
||||
Usually "tatsu-lab/alpaca" or a local path.
|
||||
"""
|
||||
dataset = load_dataset(dataset_name_or_path)
|
||||
|
||||
def map_fn(example):
|
||||
prompt = _build_alpaca_prompt(
|
||||
example.get("instruction", ""), example.get("input", "")
|
||||
)
|
||||
response = (example.get("output", "") or "").strip()
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": response},
|
||||
]
|
||||
}
|
||||
|
||||
dataset = dataset.map(
|
||||
map_fn, remove_columns=dataset["train"].column_names, num_proc=4
|
||||
)
|
||||
# make train test split
|
||||
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"tatsu-lab/alpaca", "BASE_DATASETS_DIR"
|
||||
)
|
||||
dataset = load_dataset_alpaca(dataset_name_or_path)
|
||||
breakpoint()
|
||||
133
dllm/dllm/data/opc.py
Normal file
@ -0,0 +1,133 @@
|
||||
from typing import Optional, Text, List, Dict
|
||||
from datasets import (
|
||||
load_dataset,
|
||||
get_dataset_config_names,
|
||||
concatenate_datasets,
|
||||
DatasetDict,
|
||||
Dataset,
|
||||
IterableDatasetDict,
|
||||
)
|
||||
from dllm.data.utils import (
|
||||
_merge_datasetdicts,
|
||||
_merge_iterabledatasetdicts,
|
||||
_ensure_datasetdict,
|
||||
_ensure_iterabledatasetdict,
|
||||
|
||||
)
|
||||
|
||||
|
||||
def load_dataset_opc_sft(
|
||||
dataset_name_or_path: str, name: str | None = None, lang: str | None = None
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Load OpenCoder OPC SFT dataset(s) and produce a DatasetDict with a train/test split.
|
||||
- If `name` is provided: load that specific config.
|
||||
- If `name` is None: load *all* available configs and concatenate them.
|
||||
"""
|
||||
|
||||
def _map_to_messages(ds: Dataset) -> Dataset:
|
||||
def map_fn(example):
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": example["instruction"]},
|
||||
{"role": "assistant", "content": example["output"]},
|
||||
]
|
||||
}
|
||||
|
||||
# Remove all original columns after mapping
|
||||
remove_cols = ds.column_names
|
||||
return ds.map(map_fn, remove_columns=remove_cols, num_proc=4)
|
||||
|
||||
def _load_one_config(dataset_name_or_path: str, cfg_name: str) -> Dataset:
|
||||
ds = load_dataset(dataset_name_or_path, cfg_name, split="train")
|
||||
return _map_to_messages(ds)
|
||||
|
||||
if name is not None:
|
||||
train_ds = _load_one_config(dataset_name_or_path, name)
|
||||
else:
|
||||
# Enumerate and load all configs, then concatenate
|
||||
cfgs: list[str] = get_dataset_config_names(dataset_name_or_path)
|
||||
if not cfgs:
|
||||
raise ValueError(f"No configs found for dataset: {dataset_name_or_path}")
|
||||
parts = [_load_one_config(dataset_name_or_path, c) for c in cfgs]
|
||||
train_ds = concatenate_datasets(parts)
|
||||
|
||||
# Final split
|
||||
ds_dict = train_ds.train_test_split(test_size=0.1, seed=42)
|
||||
if lang is not None:
|
||||
ds_dict = ds_dict.filter(lambda row: lang in row["messages"][1]["content"])
|
||||
|
||||
return DatasetDict(ds_dict)
|
||||
|
||||
|
||||
def load_dataset_opc_annealing(
|
||||
dataset_name_or_path: str,
|
||||
name: str | None = None,
|
||||
lang: str | None = None,
|
||||
streaming: bool = True,
|
||||
) -> DatasetDict:
|
||||
def _load_one_config(_name):
|
||||
ds = load_dataset(
|
||||
dataset_name_or_path, _name, split="train", streaming=streaming
|
||||
)
|
||||
if lang:
|
||||
if _name in ["synthetic_code_snippet", "algorithmic_corpus"]:
|
||||
ds = ds.filter(lambda row: row["lang"] == lang)
|
||||
elif _name in ["synthetic_qa"]:
|
||||
ds = ds.filter(lambda row: row["program_lang"] == lang)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# return IterableDatasetDict({"train": ds})
|
||||
if streaming:
|
||||
return _ensure_iterabledatasetdict(ds)
|
||||
return _ensure_datasetdict(ds)
|
||||
|
||||
if name is not None:
|
||||
return _load_one_config(name)
|
||||
|
||||
if streaming:
|
||||
parts = [
|
||||
_load_one_config(name)
|
||||
for name in get_dataset_config_names(dataset_name_or_path)
|
||||
]
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_iterabledatasetdicts(merged, p)
|
||||
return merged
|
||||
else:
|
||||
parts = [
|
||||
_load_one_config(name)
|
||||
for name in get_dataset_config_names(dataset_name_or_path)
|
||||
]
|
||||
if len(parts) == 1:
|
||||
return _ensure_datasetdict(parts[0])
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, p)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"OpenCoder-LLM/opc-sft-stage1", "BASE_DATASETS_DIR"
|
||||
)
|
||||
# If you want a specific config:
|
||||
dataset_edu = load_dataset_opc_sft(dataset_name_or_path, "realuser_instruct")
|
||||
# Otherwise, all configs concatenated:
|
||||
dataset_all = load_dataset_opc_sft(dataset_name_or_path, None)
|
||||
dataset_all_python = load_dataset_opc_sft(dataset_name_or_path, None, "python")
|
||||
breakpoint()
|
||||
|
||||
# streaming = True
|
||||
# dataset_name_or_path = resolve_with_base_env(
|
||||
# "OpenCoder-LLM/opc-annealing-corpus", "BASE_DATASETS_DIR"
|
||||
# )
|
||||
# # If you want a specific config:
|
||||
# dataset_alg_all = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus")
|
||||
# dataset_alg_python = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus", "python")
|
||||
# # Otherwise, all configs concatenated:
|
||||
# dataset_all_python = load_dataset_opc_annealing(dataset_name_or_path, None, "python")
|
||||
# dataset_all_all = load_dataset_opc_annealing(dataset_name_or_path, None)
|
||||
# breakpoint()
|
||||
108
dllm/dllm/data/ultrachat.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, List, Dict
|
||||
from datasets import load_dataset, DatasetDict
|
||||
|
||||
|
||||
def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
|
||||
"""
|
||||
Given a list of chat messages like:
|
||||
[{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
...]
|
||||
return a dict with the first user/assistant exchange as:
|
||||
{"prompt": <user content>, "response": <assistant content>}
|
||||
If no valid first turn exists, return None.
|
||||
"""
|
||||
if not isinstance(messages, list) or len(messages) < 2:
|
||||
return None
|
||||
|
||||
# Find the first user message and the first assistant *after* that user msg
|
||||
# (Most entries start as [user, assistant, ...], but we guard anyway.)
|
||||
user_idx = None
|
||||
for i, m in enumerate(messages):
|
||||
if (
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "user"
|
||||
and isinstance(m.get("content"), str)
|
||||
):
|
||||
user_idx = i
|
||||
break
|
||||
if user_idx is None:
|
||||
return None
|
||||
|
||||
# Find first assistant after that user
|
||||
for j in range(user_idx + 1, len(messages)):
|
||||
m = messages[j]
|
||||
if (
|
||||
isinstance(m, dict)
|
||||
and m.get("role") == "assistant"
|
||||
and isinstance(m.get("content"), str)
|
||||
):
|
||||
user_text = messages[user_idx]["content"].strip()
|
||||
assistant_text = m["content"].strip()
|
||||
if user_text and assistant_text:
|
||||
return {"prompt": user_text, "response": assistant_text}
|
||||
return None
|
||||
return None
|
||||
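# A quick illustration of the contract above (hedged, example values only):
#   _extract_first_turn([{"role": "user", "content": "Hi"},
#                        {"role": "assistant", "content": "Hello!"}])
#   -> {"prompt": "Hi", "response": "Hello!"}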
|
||||
|
||||
def load_dataset_ultrachat(dataset_name_or_path: str) -> DatasetDict:
|
||||
"""
|
||||
Load the UltraChat 200k dataset (HuggingFaceH4/ultrachat_200k) and keep only the *first turn*
|
||||
(first user message and the assistant reply).
|
||||
|
||||
Returns a `DatasetDict` where each split contains:
|
||||
- prompt: first user message content
|
||||
- response: first assistant reply content
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_name_or_path : str
|
||||
Typically "HuggingFaceH4/ultrachat_200k" or a local path.
|
||||
|
||||
"""
|
||||
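    # Resulting schema, sketched with made-up values:
    #   ds["train"][0] -> {"prompt": "How do I sort a list?", "response": "Use sorted(), e.g. ..."}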
dataset = load_dataset(dataset_name_or_path)
|
||||
|
||||
# We only keep examples that have a valid first (user, assistant) turn.
|
||||
def has_first_turn(example):
|
||||
messages = example.get("messages")
|
||||
return _extract_first_turn(messages) is not None
|
||||
|
||||
dataset = dataset.filter(has_first_turn, num_proc=4)
|
||||
|
||||
def map_fn(example):
|
||||
first = _extract_first_turn(example["messages"])
|
||||
# Fallbacks for robustness (shouldn't be hit after filter, but just in case)
|
||||
if first is None:
|
||||
first = {"prompt": (example.get("prompt") or "").strip(), "response": ""}
|
||||
return {"prompt": first["prompt"], "response": first["response"]}
|
||||
|
||||
# Remove original columns for a clean schema (infer from any available split)
|
||||
cols_to_remove = None
|
||||
for split_name in dataset.keys():
|
||||
cols_to_remove = dataset[split_name].column_names
|
||||
break
|
||||
|
||||
dataset = dataset.map(map_fn, remove_columns=cols_to_remove, num_proc=4)
|
||||
dataset = DatasetDict(
|
||||
{
|
||||
new: dataset[old]
|
||||
for old, new in {
|
||||
"train_sft": "train",
|
||||
"test_sft": "test",
|
||||
}.items()
|
||||
if old in dataset
|
||||
}
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Mirrors the style from your previous loaders: resolve path via env helper if available.
|
||||
from dllm.utils import resolve_with_base_env
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
"HuggingFaceH4/ultrachat_200k", "BASE_DATASETS_DIR"
|
||||
)
|
||||
dataset = load_dataset_ultrachat(dataset_name_or_path)
|
||||
breakpoint()
|
||||
377
dllm/dllm/data/utils.py
Normal file
@ -0,0 +1,377 @@
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDatasetDict,
|
||||
IterableDataset,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
|
||||
from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger
|
||||
|
||||
|
||||
logger = get_default_logger(__name__)
|
||||
|
||||
|
||||
def load_sft_dataset(
|
||||
dataset_args: str, load_preprocessed_data: bool = False
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Examples of dataset_args:
|
||||
- "tatsu-lab/alpaca"
|
||||
- "OpenCoder-LLM/opc-sft-stage2[name:educational_instruct]"
|
||||
- "tatsu-lab/alpaca[train:5000]"
|
||||
- "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
|
||||
"""
|
||||
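    # Spec-handling sketch (hedged: assumes parse_spec returns (path, kv_dict) as it is used below):
    #   parse_spec("tatsu-lab/alpaca[train:5000]") -> ("tatsu-lab/alpaca", {"train": 5000})
    # Specs joined with "|" are loaded independently and merged split-wise at the end.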
from dllm.data.alpaca import load_dataset_alpaca
|
||||
from dllm.data.opc import load_dataset_opc_sft
|
||||
|
||||
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
|
||||
all_parts = []
|
||||
|
||||
for raw in specs:
|
||||
dataset_name_or_path, kvs = parse_spec(raw)
|
||||
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
dataset_name_or_path, "BASE_DATASETS_DIR"
|
||||
)
|
||||
|
||||
if load_preprocessed_data:
|
||||
logger.info("Load preprocessed data from disk.")
|
||||
ds = load_from_disk(dataset_name_or_path)
|
||||
# Implement your customized dataset here
|
||||
elif _match(dataset_name_or_path, "tatsu-lab/alpaca"):
|
||||
ds = load_dataset_alpaca(dataset_name_or_path)
|
||||
elif _match(dataset_name_or_path, "allenai/tulu-3-sft-mixture"):
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
|
||||
elif _match(dataset_name_or_path, "HuggingFaceTB/smoltalk"):
|
||||
name = kvs.pop("name", "all")
|
||||
ds = load_dataset(dataset_name_or_path, name=name)
|
||||
elif _match(dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage1") or _match(
|
||||
dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage2"
|
||||
):
|
||||
name = kvs.pop("name", None)
|
||||
lang = kvs.pop("lang", None)
|
||||
ds = load_dataset_opc_sft(dataset_name_or_path, name=name, lang=lang)
|
||||
elif _match(dataset_name_or_path, "HuggingFaceH4/ultrachat_200k"):
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
ds = DatasetDict({"train": ds["train_sft"], "test": ds["test_sft"]})
|
||||
else:
|
||||
ds = load_dataset(dataset_name_or_path)
|
||||
|
||||
# Normalize to DatasetDict and apply per-split limits
|
||||
ds = _ensure_datasetdict(ds)
|
||||
ds = _truncate_dataset(ds, kvs)
|
||||
all_parts.append(ds)
|
||||
|
||||
# If only one part, return as DatasetDict
|
||||
if len(all_parts) == 1:
|
||||
return _ensure_datasetdict(all_parts[0])
|
||||
|
||||
# Merge all parts into a single DatasetDict
|
||||
merged = all_parts[0]
|
||||
for part in all_parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, part)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
def load_pt_dataset(
|
||||
dataset_args: str, streaming: bool = True, load_preprocessed_data: bool = False
|
||||
) -> DatasetDict | IterableDatasetDict:
|
||||
"""
|
||||
Examples of dataset_args:
|
||||
- "mlfoundations/dclm-baseline-1.0"
|
||||
- "OpenCoder-LLM/opc-fineweb-code-corpus"
|
||||
- "OpenCoder-LLM/opc-fineweb-math-corpus"
|
||||
- "OpenCoder-LLM/opc-annealing-corpus[lang:python]"
|
||||
- "wikitext[name:wikitext-103-v1}]"
|
||||
"""
|
||||
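    # Usage sketch (illustrative arguments, following the spec syntax above):
    #   ds = load_pt_dataset("OpenCoder-LLM/opc-annealing-corpus[lang:python]", streaming=True)
    #   first_row = next(iter(ds["train"]))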
from dllm.data.opc import load_dataset_opc_annealing
|
||||
|
||||
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
|
||||
if not specs:
|
||||
raise ValueError("Empty dataset_args for load_pt_dataset.")
|
||||
|
||||
# ---------- Shared loader (only differs by streaming flag) ----------
|
||||
def _load_base_dataset(
|
||||
raw: str, *, streaming: bool
|
||||
) -> tuple[DatasetDict | IterableDatasetDict, dict, str]:
|
||||
"""
|
||||
Returns: (base, kvs, dataset_name_or_path)
|
||||
- Pops 'name' from kvs when applicable (e.g., wikitext).
|
||||
- Applies identical matching logic for both streaming/non-streaming.
|
||||
"""
|
||||
dataset_name_or_path, kvs = parse_spec(raw)
|
||||
dataset_name_or_path = resolve_with_base_env(
|
||||
dataset_name_or_path, "BASE_DATASETS_DIR"
|
||||
)
|
||||
name = kvs.pop("name", None)
|
||||
|
||||
if load_preprocessed_data:
|
||||
base = load_from_disk(dataset_name_or_path)
|
||||
elif _match(dataset_name_or_path, ["OpenCoder-LLM/opc-annealing-corpus"]):
|
||||
lang = kvs.pop("lang", None)
|
||||
base = load_dataset_opc_annealing(
|
||||
dataset_name_or_path, name=name, lang=lang, streaming=streaming
|
||||
)
|
||||
else:
|
||||
base = load_dataset(dataset_name_or_path, name=name, streaming=streaming)
|
||||
|
||||
return base, kvs, dataset_name_or_path
|
||||
|
||||
# ---------- Streaming path ----------
|
||||
def _load_one_streaming_spec(raw: str) -> IterableDatasetDict:
|
||||
base, kvs, dataset_name_or_path = _load_base_dataset(raw, streaming=True)
|
||||
|
||||
split_names = list(base.keys())
|
||||
single_split = len(split_names) == 1
|
||||
single_split_name = split_names[0] if single_split else None
|
||||
|
||||
n_train = kvs.get("train")
|
||||
n_test = kvs.get("test")
|
||||
|
||||
if (n_train is not None) or (n_test is not None):
|
||||
if (n_train is not None) and (n_test is not None):
|
||||
if single_split:
|
||||
stream = base[single_split_name]
|
||||
head = stream.take(n_train + n_test)
|
||||
test = head.take(n_test)
|
||||
train = head.skip(n_test).take(n_train)
|
||||
return IterableDatasetDict({"train": train, "test": test})
|
||||
else:
|
||||
if "train" not in base or "test" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: require 'train' and 'test' splits for train+test limits."
|
||||
)
|
||||
train = base["train"].take(n_train)
|
||||
test = base["test"].take(n_test)
|
||||
return IterableDatasetDict({"train": train, "test": test})
|
||||
|
||||
if n_train is not None:
|
||||
if single_split:
|
||||
train = base[single_split_name].take(n_train)
|
||||
else:
|
||||
if "train" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: missing 'train' split for train limit."
|
||||
)
|
||||
train = base["train"].take(n_train)
|
||||
return IterableDatasetDict({"train": train})
|
||||
|
||||
if n_test is not None:
|
||||
if single_split:
|
||||
test = base[single_split_name].take(n_test)
|
||||
else:
|
||||
if "test" not in base:
|
||||
raise ValueError(
|
||||
f"{dataset_name_or_path}: missing 'test' split for test limit."
|
||||
)
|
||||
test = base["test"].take(n_test)
|
||||
return IterableDatasetDict({"test": test})
|
||||
|
||||
return base # already an IterableDatasetDict
|
||||
|
||||
# ---------- Non-streaming path (mirror load_sft_dataset; NO shuffle) ----------
|
||||
def _load_one_nonstreaming_spec(raw: str) -> DatasetDict:
|
||||
base, kvs, _ = _load_base_dataset(raw, streaming=False)
|
||||
ds = _ensure_datasetdict(base) # normalize
|
||||
ds = _truncate_dataset(ds, kvs) # apply limits (train/test/...)
|
||||
return ds
|
||||
|
||||
# ---------- Load & Merge ----------
|
||||
if streaming:
|
||||
logger.info("Loading dataset in streaming mode.")
|
||||
parts = [_load_one_streaming_spec(raw) for raw in specs]
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_iterabledatasetdicts(merged, p)
|
||||
# repeat streaming dataset infinitely
|
||||
merged = IterableDatasetDict(
|
||||
{k: (v.repeat(None) if k == "train" else v) for k, v in merged.items()}
|
||||
)
|
||||
return merged
|
||||
else:
|
||||
logger.info("Loading dataset in non-streaming mode.")
|
||||
parts = [_load_one_nonstreaming_spec(raw) for raw in specs]
|
||||
if len(parts) == 1:
|
||||
return _ensure_datasetdict(parts[0])
|
||||
merged = parts[0]
|
||||
for p in parts[1:]:
|
||||
merged = _merge_datasetdicts(merged, p)
|
||||
return _ensure_datasetdict(merged)
|
||||
|
||||
|
||||
def _truncate_split(split_data, n: int):
|
||||
if n is None:
|
||||
return split_data
|
||||
try:
|
||||
if hasattr(split_data, "select"):
|
||||
# Hugging Face Dataset path
|
||||
total = getattr(split_data, "num_rows", None)
|
||||
if total is None:
|
||||
# some Dataset types expose len(...)
|
||||
total = len(split_data)
|
||||
idx = list(range(min(n, total)))
|
||||
return split_data.select(idx)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return split_data[:n]
|
||||
except Exception:
|
||||
# Last resort: iterate
|
||||
return type(split_data)(item for i, item in enumerate(split_data) if i < n)
|
||||
|
||||
|
||||
def _truncate_dataset(ds, limits: dict):
|
||||
"""
|
||||
Ensure and return a DatasetDict, truncating splits mentioned in `limits`.
|
||||
"""
|
||||
ds = _ensure_datasetdict(ds) # normalize first
|
||||
out = {}
|
||||
for split, data in ds.items():
|
||||
n = limits.get(split, None)
|
||||
out[split] = _truncate_split(data, n) if n is not None else data
|
||||
return DatasetDict(out)
|
||||
|
||||
|
||||
def _concat_splits(a, b):
|
||||
"""
|
||||
Concatenate two split objects (prefer 🤗 datasets).
|
||||
"""
|
||||
if a is b:
|
||||
return a
|
||||
if a is None:
|
||||
return b
|
||||
if b is None:
|
||||
return a
|
||||
|
||||
# Prefer datasets' concatenate_datasets when both are Datasets
|
||||
try:
|
||||
from datasets import concatenate_datasets
|
||||
|
||||
if isinstance(a, Dataset) and isinstance(b, Dataset):
|
||||
return concatenate_datasets([a, b])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallbacks
|
||||
try:
|
||||
return a + b
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return type(a)(list(a) + list(b))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise TypeError(
|
||||
f"Cannot concatenate split objects of types {type(a)} and {type(b)}"
|
||||
)
|
||||
|
||||
|
||||
def _merge_datasetdicts(d1, d2):
|
||||
"""
|
||||
Merge two DatasetDict-like mappings by concatenating splits present in either.
|
||||
Always returns a DatasetDict.
|
||||
"""
|
||||
d1 = _ensure_datasetdict(d1)
|
||||
d2 = _ensure_datasetdict(d2)
|
||||
all_splits = set(d1.keys()) | set(d2.keys())
|
||||
out = {}
|
||||
for split in all_splits:
|
||||
a = d1.get(split, None)
|
||||
b = d2.get(split, None)
|
||||
if a is None:
|
||||
out[split] = b
|
||||
elif b is None:
|
||||
out[split] = a
|
||||
else:
|
||||
out[split] = _concat_splits(a, b)
|
||||
return DatasetDict(out)
|
||||
|
||||
|
||||
def _ensure_datasetdict(ds):
|
||||
"""
|
||||
Normalize various loader outputs into a DatasetDict.
|
||||
- If loader returns a DatasetDict, return as is.
|
||||
- If loader returns a mapping (e.g., dict of splits), wrap into DatasetDict.
|
||||
- If loader returns a single Dataset/list/etc., assume it's 'train'.
|
||||
"""
|
||||
if isinstance(ds, DatasetDict):
|
||||
return ds
|
||||
if isinstance(ds, dict):
|
||||
# Try to convert each split value to a Dataset if they aren't already.
|
||||
# If they are already Datasets, DatasetDict will accept them directly.
|
||||
return DatasetDict(ds)
|
||||
# Single split -> assume train
|
||||
return DatasetDict({"train": ds})
|
||||
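# Normalization sketch (illustrative):
#   _ensure_datasetdict(Dataset.from_list([{"text": "a"}])) -> DatasetDict({"train": <that Dataset>})
#   _ensure_datasetdict({"train": ds_train, "test": ds_test}) -> DatasetDict with both splits kept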
|
||||
|
||||
def _match(name: str, needle) -> bool:
|
||||
"""
|
||||
Returns True if `name` matches any of the provided needles.
|
||||
Accepts a single string or a list/tuple of strings.
|
||||
Match condition: name endswith(needle) or needle in name.
|
||||
"""
|
||||
if isinstance(needle, (list, tuple)):
|
||||
return any(name.endswith(n) or n in name for n in needle)
|
||||
return name.endswith(needle) or needle in name
|
||||
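# e.g. (illustrative):
#   _match("/data/OpenCoder-LLM/opc-sft-stage2", "OpenCoder-LLM/opc-sft-stage2") -> True (endswith)
#   _match("wikitext", ["OpenCoder-LLM/opc-annealing-corpus"]) -> False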
|
||||
|
||||
def _concat_iterable_datasets(parts: list[IterableDataset]) -> IterableDataset:
|
||||
"""
|
||||
Concatenate IterableDatasets sequentially without materialization.
|
||||
Preserves streaming nature; supports downstream .take()/.skip()/.shuffle().
|
||||
"""
|
||||
if not parts:
|
||||
raise ValueError("No IterableDatasets to concatenate.")
|
||||
# Try to reuse features from the first dataset when available
|
||||
features = getattr(parts[0], "features", None)
|
||||
|
||||
def _gen():
|
||||
for ds in parts:
|
||||
yield from ds
|
||||
|
||||
return IterableDataset.from_generator(_gen, features=features)
|
||||
|
||||
|
||||
def _ensure_iterabledatasetdict(obj) -> IterableDatasetDict:
|
||||
if isinstance(obj, IterableDatasetDict):
|
||||
return obj
|
||||
if isinstance(obj, dict):
|
||||
return IterableDatasetDict(obj)
|
||||
# Single stream -> assume train
|
||||
return IterableDatasetDict({"train": obj})
|
||||
|
||||
|
||||
def _merge_iterabledatasetdicts(
|
||||
d1: IterableDatasetDict, d2: IterableDatasetDict
|
||||
) -> IterableDatasetDict:
|
||||
"""
|
||||
Merge by concatenating any overlapping splits (streaming-safe).
|
||||
"""
|
||||
d1 = _ensure_iterabledatasetdict(d1)
|
||||
d2 = _ensure_iterabledatasetdict(d2)
|
||||
all_splits = set(d1.keys()) | set(d2.keys())
|
||||
out = {}
|
||||
for split in all_splits:
|
||||
a = d1.get(split, None)
|
||||
b = d2.get(split, None)
|
||||
if a is None:
|
||||
out[split] = b
|
||||
elif b is None:
|
||||
out[split] = a
|
||||
else:
|
||||
out[split] = _concat_iterable_datasets([a, b])
|
||||
return IterableDatasetDict(out)
|
||||
|
||||
|
||||
def _truncate_stream(ds: IterableDataset, n: int | None) -> IterableDataset:
|
||||
if n is None:
|
||||
return ds
|
||||
return ds.take(n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
breakpoint()
|
||||
1
dllm/dllm/pipelines/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import llada, dream, rnd, editflow
|
||||
362
dllm/dllm/pipelines/bert/eval.py
Normal file
@ -0,0 +1,362 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/bert/eval.py \
|
||||
--tasks gsm8k \
|
||||
--batch_size 1 \
|
||||
--model bert \
|
||||
--device cuda \
|
||||
--num_fewshot 8 \
|
||||
--model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BERTEvalConfig(LLaDAGeneratorConfig):
|
||||
max_new_tokens: int = 128
|
||||
max_length: int = 512
|
||||
steps: int = 128
|
||||
block_length: int = 128
|
||||
|
||||
pretrained: str = ""
|
||||
dtype: str | torch.dtype = "auto"
|
||||
batch_size: int = 32
|
||||
mc_num: int = 128
|
||||
is_check_greedy: bool = True
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@register_model("bert")
|
||||
class BERTEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: BERTEvalConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Initialize config if not provided
|
||||
if config is None:
|
||||
config = BERTEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
|
||||
device = kwargs.get("device", config.device)
|
||||
cfg = kwargs.get("cfg", config.cfg_scale)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_length = int(max_length)
|
||||
self.max_new_tokens = int(max_new_tokens)
|
||||
self.block_length = int(block_length)
|
||||
self.steps = int(steps)
|
||||
self.cfg = float(cfg)
|
||||
self.remasking = remasking
|
||||
self.is_check_greedy = is_check_greedy
|
||||
|
||||
# loglikelihood params
|
||||
self.mc_num = int(mc_num)
|
||||
assert mc_num % self.batch_size == 0
|
||||
self.sampling_eps = 0.0
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Method to apply a chat template to a list of chat history between user and model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
|
||||
target_len = (l - prompt_index.sum()).item()
|
||||
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
||||
|
||||
x = torch.round(
|
||||
torch.linspace(
|
||||
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
|
||||
)
|
||||
).long()
|
||||
x = ((x - 1) % target_len) + 1
|
||||
assert x.min() >= 1 and x.max() <= target_len
|
||||
|
||||
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
||||
is_mask = indices < x.unsqueeze(1)
|
||||
|
||||
for i in range(b):
|
||||
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
||||
|
||||
is_mask = torch.cat(
|
||||
(
|
||||
torch.zeros(
|
||||
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
|
||||
),
|
||||
is_mask,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
||||
|
||||
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfg > 0.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
logits = self.model(batch).logits
|
||||
|
||||
if self.cfg > 0.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(self.mc_num // self.batch_size):
|
||||
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return -sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def suffix_greedy_prediction(
|
||||
self, prefix: torch.Tensor, target: torch.Tensor
|
||||
) -> bool:
|
||||
if not self.is_check_greedy:
|
||||
return False
|
||||
|
||||
seq = torch.full(
|
||||
(1, len(prefix) + len(target)), self.mask_id, device=self.device
|
||||
)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
prefix, target = prefix.to(self.device), target.to(self.device)
|
||||
seq[0, : len(prefix)] = prefix
|
||||
|
||||
for i in range(len(target)):
|
||||
mask_index = seq == self.mask_id
|
||||
logits = self.get_logits(seq, prompt_index)[mask_index]
|
||||
x0 = torch.argmax(logits, dim=-1)
|
||||
|
||||
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
||||
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
_, index = torch.sort(confidence, descending=True)
|
||||
x0[index[1:]] = self.mask_id
|
||||
seq[mask_index] = x0.clone()
|
||||
correct = target == seq[0, len(prefix) :]
|
||||
correct = torch.all(correct)
|
||||
return correct
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
||||
context_enc = self.tokenizer(context)["input_ids"]
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
return context_enc, continuation_enc
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
||||
|
||||
assert max(prompt_len) <= 4096
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
|
||||
ll = self.get_loglikelihood(prefix, target)
|
||||
|
||||
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
torch.cuda.empty_cache()
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_until(self, requests: list[Instance]):
|
||||
def _tokenize(e):
|
||||
return {
|
||||
"question": self.tokenizer(e["question"])["input_ids"],
|
||||
"question_text": e["question"],
|
||||
"until": e["until"],
|
||||
}
|
||||
|
||||
ds = [
|
||||
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
|
||||
]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
|
||||
for elem in tqdm(ds, desc="Generating..."):
|
||||
prompt = [elem["question"][1:-1].to(self.device)]
|
||||
stop_tokens = elem["until"]
|
||||
generated_ids = generator.generate(
|
||||
inputs=prompt,
|
||||
steps=self.steps,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
block_length=self.block_length,
|
||||
temperature=0.0,
|
||||
cfg_scale=self.cfg,
|
||||
remasking=self.remasking,
|
||||
)
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
|
||||
)
|
||||
|
||||
for stop_seq in stop_tokens:
|
||||
if stop_seq in generated_answer:
|
||||
generated_answer = generated_answer.split(stop_seq)[0]
|
||||
|
||||
# remove special tokens
|
||||
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_answer_ids, skip_special_tokens=True
|
||||
)
|
||||
out.append(generated_answer)
|
||||
if self.accelerator is not None:
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
6
dllm/dllm/pipelines/dream/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from . import generator, models, trainer, utils
|
||||
from .models.modeling_dream import DreamModel
|
||||
from .models.configuration_dream import DreamConfig
|
||||
from .models.tokenization_dream import DreamTokenizer
|
||||
from .generator import DreamGeneratorConfig, DreamGenerator
|
||||
from .trainer import DreamTrainer
|
||||
533
dllm/dllm/pipelines/dream/eval.py
Normal file
@ -0,0 +1,533 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/dream/eval.py \
|
||||
--tasks gsm8k \
|
||||
--batch_size 1 \
|
||||
--model dream \
|
||||
    --device cuda \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=Dream-org/Dream-v0-Base-7B,mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
|
||||
"""
|
||||
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.dream import DreamGenerator, DreamGeneratorConfig
|
||||
|
||||
eval_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamEvalConfig(DreamGeneratorConfig):
|
||||
top_p: float | None = None
|
||||
top_k: float | None = None
|
||||
max_new_tokens: int = 128
|
||||
max_length: int = 2048
|
||||
steps: int = 128
|
||||
temperature: float = 0.0
|
||||
alg: str = "entropy"
|
||||
|
||||
pretrained: str = ""
|
||||
batch_size: int = 1
|
||||
device: str = "cuda"
|
||||
dtype: str | torch.dtype = "auto"
|
||||
add_bos_token: bool = False
|
||||
nll_type: str = "mc"
|
||||
log_type: str = "ftb"
|
||||
mc_num: int = 128
|
||||
classifier_free_guidance: float = 1.0
|
||||
sampling_eps: float = 1e-3
|
||||
escape_until: bool = False
|
||||
|
||||
|
||||
@register_model("dream")
|
||||
class DreamEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: DreamEvalConfig | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Initialize config if not provided
|
||||
if config is None:
|
||||
config = DreamEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
device = kwargs.get("device", config.device)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
add_bos_token = kwargs.get("add_bos_token", config.add_bos_token)
|
||||
nll_type = kwargs.get("nll_type", config.nll_type)
|
||||
log_type = kwargs.get("log_type", config.log_type)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
classifier_free_guidance = kwargs.get(
|
||||
"classifier_free_guidance", config.classifier_free_guidance
|
||||
)
|
||||
sampling_eps = kwargs.get("sampling_eps", config.sampling_eps)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
top_p = kwargs.get("top_p", config.top_p)
|
||||
top_k = kwargs.get("top_k", config.top_k)
|
||||
alg = kwargs.get("alg", config.alg)
|
||||
alg_temp = kwargs.get("alg_temp", config.alg_temp)
|
||||
escape_until = kwargs.get("escape_until", config.escape_until)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.max_length = max_length
|
||||
self.add_bos_token = add_bos_token
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.steps = steps
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.alg = alg
|
||||
self.alg_temp = alg_temp
|
||||
self.escape_until = escape_until
|
||||
|
||||
# loglikelihood params
|
||||
self.nll_type = nll_type
|
||||
self.log_type = log_type
|
||||
self.mc_num = mc_num
|
||||
self.classifier_free_guidance = classifier_free_guidance
|
||||
self.sampling_eps = sampling_eps
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def tok_decode(
|
||||
self, tokens: torch.Tensor | list[int], skip_special_tokens: bool = True
|
||||
) -> str:
|
||||
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
def tok_encode(self, text: str, add_special_tokens: bool = True) -> torch.Tensor:
|
||||
return self.tokenizer(
|
||||
text, return_tensors="pt", add_special_tokens=add_special_tokens
|
||||
).input_ids
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Method to apply a chat template to a list of chat history between user and model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
def generate_until(
|
||||
self, requests: list[Instance], disable_tqdm: bool = False
|
||||
) -> list[str]:
|
||||
res = []
|
||||
pbar = tqdm(
|
||||
total=len(requests),
|
||||
disable=(disable_tqdm or (self.rank != 0)),
|
||||
desc="Running generate_until requests",
|
||||
)
|
||||
generator = DreamGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
for batch_idx in range(0, len(requests), self.batch_size):
|
||||
batch_requests = requests[batch_idx : batch_idx + self.batch_size]
|
||||
contexts, gen_args = zip(*[req.arguments for req in batch_requests])
|
||||
|
||||
# ====== BEGIN merged _generate_batch logic ======
|
||||
prompts = list(contexts)
|
||||
if self.add_bos_token:
|
||||
prompts = [self.tokenizer.bos_token + p for p in prompts]
|
||||
|
||||
# tokenize
|
||||
prompt_ids = [
|
||||
self.tokenizer(
|
||||
p, return_tensors="pt", padding=False
|
||||
).input_ids.squeeze()
|
||||
for p in prompts
|
||||
]
|
||||
prompt_lens = [len(p_id) for p_id in prompt_ids]
|
||||
|
||||
if max(prompt_lens) > self.max_length - self.max_new_tokens:
|
||||
cutoff_len = self.max_length - self.max_new_tokens
|
||||
eval_logger.warning(
|
||||
f"Prompt length {max(prompt_lens)} exceeds {cutoff_len}, cutoff on the left side"
|
||||
)
|
||||
# ✅ Correct: trim from the left side (keep the last cutoff_len tokens)
|
||||
prompt_ids = [p_id[-cutoff_len:] for p_id in prompt_ids]
|
||||
|
||||
# generation
|
||||
generation_ids = generator.generate(
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
inputs=prompt_ids,
|
||||
steps=self.steps,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
alg=self.alg,
|
||||
alg_temp=self.alg_temp,
|
||||
output_history=False,
|
||||
return_dict_in_generate=False,
|
||||
)
|
||||
# decode and cleanup
|
||||
cleaned_generation_ids = [
|
||||
(
|
||||
seq[seq.ne(self.tokenizer.eos_token_id).float().argmax().long() :]
|
||||
if (seq != self.tokenizer.eos_token_id).any()
|
||||
else seq[-1:]
|
||||
)
|
||||
for seq in generation_ids
|
||||
]
|
||||
truncated_generation_ids = [
|
||||
seq[prompt_lens[i] :] for i, seq in enumerate(cleaned_generation_ids)
|
||||
]
|
||||
responses = [
|
||||
                g.removeprefix("<|endoftext|>").split(self.tokenizer.eos_token, 1)[0]
|
||||
for g in self.tokenizer.batch_decode(truncated_generation_ids)
|
||||
]
|
||||
|
||||
# ====== END merged _generate_batch logic ======
|
||||
|
||||
# handle "until" truncation
|
||||
if not self.escape_until:
|
||||
for i, r in enumerate(responses):
|
||||
for s in gen_args[0]["until"]:
|
||||
r = r.split(s)[0]
|
||||
responses[i] = r
|
||||
|
||||
res.extend(responses)
|
||||
pbar.update(len(contexts))
|
||||
|
||||
return res
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
|
||||
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
|
||||
indices = torch.arange(b, device=batch.device).float()
|
||||
t = (u0 + indices / b) % 1
|
||||
|
||||
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
|
||||
|
||||
p_mask = p_mask[:, None].repeat(1, l)
|
||||
|
||||
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
|
||||
# always unmask bos and eos
|
||||
mask_indices[:, 0] = False
|
||||
mask_indices[:, -1] = False
|
||||
|
||||
noisy_batch = torch.where(mask_indices, self.mask_id, batch)
|
||||
return noisy_batch, p_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
prompt_index : 1D bool tensor, length=batch.shape[1]
|
||||
"""
|
||||
if self.classifier_free_guidance > 1.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
input = batch
|
||||
|
||||
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
|
||||
logits = self.model(input).logits
|
||||
# since bos always unmask, the first logits will not be used
|
||||
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
|
||||
if self.classifier_free_guidance > 1.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
            logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def _eval_target_nll_mc(
|
||||
self, prefix: torch.Tensor | None, target: torch.Tensor
|
||||
) -> float:
|
||||
if prefix is None:
|
||||
seq = target[None, :]
|
||||
else:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
else:
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(max(self.mc_num // self.batch_size, 1)):
|
||||
perturbed_seq = seq.clone()
|
||||
# eval_logger.info("before noising")
|
||||
perturbed_seq_, p_mask = self._forward_process(seq)
|
||||
# eval_logger.info("end noising")
|
||||
if self.log_type == "ftb":
|
||||
perturbed_seq[:, -len(target) :] = perturbed_seq_[:, -len(target) :]
|
||||
elif self.log_type == "btf":
|
||||
perturbed_seq[:, : len(prefix)] = perturbed_seq_[:, : len(prefix)]
|
||||
elif self.log_type == "union":
|
||||
perturbed_seq = perturbed_seq_
|
||||
else:
|
||||
raise NotImplementedError(self.log_type)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def _eval_target_nll_ar(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
|
||||
assert self.log_type in ["ftb", "btf"]
|
||||
assert self.nll_type in ["ar_ftb", "ar_btf"]
|
||||
|
||||
if self.log_type == "ftb":
|
||||
prompt_index = (
|
||||
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
|
||||
< prefix.shape[1]
|
||||
)
|
||||
else:
|
||||
prompt_index = (
|
||||
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
|
||||
>= prefix.shape[1]
|
||||
)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
|
||||
else:
|
||||
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
|
||||
|
||||
mask_index = torch.ones(
|
||||
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
|
||||
)
|
||||
if self.nll_type == "ar_ftb":
|
||||
mask_index = torch.triu(mask_index)
|
||||
else:
|
||||
mask_index = torch.tril(mask_index)
|
||||
perturbed_[mask_index] = self.mask_id
|
||||
if self.log_type == "ftb":
|
||||
perturbed_seq = torch.cat(
|
||||
[prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1
|
||||
)
|
||||
else:
|
||||
perturbed_seq = torch.cat(
|
||||
[perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1
|
||||
)
|
||||
|
||||
logits_ = []
|
||||
num = (
|
||||
len(perturbed_seq) // self.batch_size
|
||||
if len(perturbed_seq) % self.batch_size == 0
|
||||
else len(perturbed_seq) // self.batch_size + 1
|
||||
)
|
||||
for i in range(num):
|
||||
end = (
|
||||
(i + 1) * self.batch_size
|
||||
if (i + 1) * self.batch_size < len(perturbed_seq)
|
||||
else len(perturbed_seq)
|
||||
)
|
||||
perturbed_seq_ = perturbed_seq[i * self.batch_size : end]
|
||||
perturbed_seq_ = perturbed_seq_.to(self.device)
|
||||
if len(perturbed_seq_.shape) == 1:
|
||||
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
|
||||
logits = self.get_logits(perturbed_seq_, prompt_index)
|
||||
logits_.append(logits.cpu())
|
||||
logits = torch.cat(logits_, dim=0)
|
||||
|
||||
temp_index = torch.ones(
|
||||
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
|
||||
)
|
||||
if self.nll_type == "ar_ftb":
|
||||
temp_index = torch.triu(temp_index, diagonal=1)
|
||||
else:
|
||||
temp_index = torch.tril(temp_index, diagonal=-1)
|
||||
mask_index[temp_index] = False
|
||||
if self.log_type == "ftb":
|
||||
logits_index = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool
|
||||
),
|
||||
mask_index,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
logits_index = torch.cat(
|
||||
[
|
||||
mask_index,
|
||||
torch.zeros(
|
||||
(perturbed_.shape[1], target.shape[1]), dtype=torch.bool
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if self.log_type == "ftb":
|
||||
loss = (
|
||||
F.cross_entropy(logits[logits_index], target[0], reduction="sum")
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
else:
|
||||
loss = (
|
||||
F.cross_entropy(logits[logits_index], prefix[0], reduction="sum")
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
return loss
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.add_bos_token:
|
||||
context = self.tokenizer.bos_token + context
|
||||
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer.encode(context + continuation) + [
|
||||
self.tokenizer.eos_token_id
|
||||
]
|
||||
context_enc = self.tokenizer.encode(context)
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
# by default truncate on the left
|
||||
cutoff_length = max(len(whole_enc) - self.max_length, 0)
|
||||
if cutoff_length > 0:
|
||||
eval_logger.warning(
|
||||
f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side"
|
||||
)
|
||||
context_remain = context_enc_len - cutoff_length
|
||||
if context_remain > 0:
|
||||
context_enc = context_enc[-context_remain:]
|
||||
else:
|
||||
eval_logger.warning(f"All context (prompt) is truncated.")
|
||||
context_enc = ""
|
||||
continuation_enc = whole_enc[-self.max_length :]
|
||||
return context_enc, continuation_enc
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
|
||||
if self.nll_type == "mc":
|
||||
ll = -self._eval_target_nll_mc(prefix, target)
|
||||
if self.log_type == "union":
|
||||
ll = ll / (len(target) + len(prefix))
|
||||
elif self.nll_type == "ar_ftb" or self.nll_type == "ar_btf":
|
||||
ll = -self._eval_target_nll_ar(prefix, target)
|
||||
else:
|
||||
raise NotImplementedError(self.nll_type)
|
||||
|
||||
# TODO: greedy decoding
|
||||
is_target_greedy_dec = False
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
426
dllm/dllm/pipelines/dream/generator.py
Normal file
@ -0,0 +1,426 @@
|
||||
"""
|
||||
reference: https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/generation_utils.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributions as dists
|
||||
|
||||
from dllm.utils.generation_utils import get_num_transfer_tokens
|
||||
from dllm.pipelines.dream.utils import top_p_logits, top_k_logits
|
||||
from dllm.core.generation.generator import (
|
||||
GeneratorOutput,
|
||||
GeneratorConfig,
|
||||
BaseGenerator,
|
||||
)
|
||||
|
||||
|
||||
def sample_tokens(
|
||||
logits: torch.Tensor,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
margin_confidence: bool = False,
|
||||
neg_entropy: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if temperature > 0:
|
||||
logits = logits / temperature
|
||||
if top_p is not None and top_p < 1:
|
||||
logits = top_p_logits(logits, top_p)
|
||||
if top_k is not None:
|
||||
logits = top_k_logits(logits, top_k)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
||||
if temperature > 0:
|
||||
try:
|
||||
x0 = dists.Categorical(probs=probs).sample()
|
||||
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
|
||||
except Exception:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
else:
|
||||
confidence, x0 = probs.max(dim=-1)
|
||||
|
||||
if margin_confidence:
|
||||
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
|
||||
top1_probs = sorted_probs[:, 0]
|
||||
top2_probs = sorted_probs[:, 1]
|
||||
confidence = top1_probs - top2_probs
|
||||
|
||||
if neg_entropy:
|
||||
epsilon = 1e-10
|
||||
log_probs = torch.log(probs + epsilon)
|
||||
confidence = torch.sum(probs * log_probs, dim=-1)
|
||||
|
||||
return confidence, x0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamGeneratorConfig(GeneratorConfig):
|
||||
max_new_tokens: int = 20
|
||||
max_length: int = (
|
||||
None # The max_length is set as input_ids.shape[1] + 20: generation_config.max_length = generation_config.max_length + input_ids_length
|
||||
)
|
||||
steps: int = 512
|
||||
eps: float = 1e-3
|
||||
alg: str = "origin"
|
||||
alg_temp: float = 0.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = 50
|
||||
stochastic_transfer: bool = False
|
||||
|
||||
|
||||
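# Usage sketch (hedged: assumes a Dream model/tokenizer are already loaded and that
# BaseGenerator needs nothing beyond model and tokenizer here; names are illustrative):
#   generator = DreamGenerator(model=model, tokenizer=tokenizer)
#   config = DreamGeneratorConfig(max_new_tokens=64, steps=64, temperature=0.2, top_p=0.95, alg="entropy")
#   sequences = generator.generate(inputs=[input_ids], config=config)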
@dataclass
|
||||
class DreamGenerator(BaseGenerator):
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
        inputs: list[torch.Tensor] | list[list[int]],
|
||||
config: DreamGeneratorConfig | None = None,
|
||||
generation_tokens_hook_func=lambda step, x, logits: x,
|
||||
generation_logits_hook_func=lambda step, x, logits: logits,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
"""
|
||||
Diffusion-style masked decoding for *generation from inputs*.
|
||||
(docstring unchanged)
|
||||
"""
|
||||
if config is None:
|
||||
config = DreamGeneratorConfig()
|
||||
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
eps = kwargs.get("eps", config.eps)
|
||||
alg = kwargs.get("alg", config.alg)
|
||||
alg_temp = kwargs.get("eps", config.alg_temp)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
top_p = kwargs.get("top_p", config.top_p)
|
||||
top_k = kwargs.get("top_k", config.top_k)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
# generation_tokens_hook_func = kwargs.get("generation_tokens_hook_func", config.generation_tokens_hook_func)
|
||||
# generation_logits_hook_func = kwargs.get("generation_logits_hook_func", config.generation_logits_hook_func)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
# --- Initialization ---
|
||||
mask_token_id = self.tokenizer.mask_token_id
|
||||
        eos_token_id = self.tokenizer.eos_token_id

        if isinstance(inputs[0], list):
            inputs = [
                torch.as_tensor(p, dtype=torch.long, device=self.model.device)
                for p in inputs
            ]
        prompt_lens = [p.shape[0] for p in inputs]
        if max_new_tokens:
            max_length = max_new_tokens + max(prompt_lens)
        else:
            max_new_tokens = max_length - max(prompt_lens)

        B = len(inputs)
        T = max_length
        x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)

        seq_length = []
        for i, p in enumerate(inputs):
            total_len = prompt_lens[i] + max_new_tokens
            seq_length.append(total_len)
            start = T - total_len
            x[i, start : start + prompt_lens[i]] = p
            x[i, start + prompt_lens[i] : T] = mask_token_id

        attention_mask = torch.zeros(
            (B, T), dtype=torch.float32, device=self.model.device
        )
        for j, L in enumerate(seq_length):
            if L > 0:
                attention_mask[j, -L:] = 1.0  # sequences are required to be left-padded

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            pos_id = attention_mask.long().cumsum(-1) - 1
            pos_id.masked_fill_(attention_mask == 0, 1)
        else:
            pos_id = None

        mask_index = x == mask_token_id
        num_transfer_tokens_list = get_num_transfer_tokens(
            mask_index=mask_index,
            steps=steps,
            scheduler=self.scheduler,
            stochastic=stochastic_transfer,
        )
        effective_steps = num_transfer_tokens_list.size(1)

        # --- Iterative refinement ---
        x = generation_tokens_hook_func(None, x, None)
        histories = [x.clone()] if return_dict_in_generate else None
        for i in range(effective_steps):
            mask_index = x == mask_token_id

            logits = self.model(x, attention_mask, pos_id).logits
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            mask_logits = logits[mask_index]

            if alg == "maskgit_plus":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "topk_margin":
                confidence, x0 = sample_tokens(
                    mask_logits,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    margin_confidence=True,
                )
            elif alg == "entropy":
                confidence, x0 = sample_tokens(
                    mask_logits,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    neg_entropy=True,
                )
            else:
                raise RuntimeError(f"Unknown alg: {alg}")

            full_confidence = torch.full_like(
                x, -torch.inf, device=self.model.device, dtype=logits.dtype
            )
            full_confidence[mask_index] = confidence

            for j in range(full_confidence.shape[0]):
                number_transfer_tokens = num_transfer_tokens_list[j, i]
                if number_transfer_tokens > 0:
                    if alg_temp is None or alg_temp == 0:
                        _, transfer_index = torch.topk(
                            full_confidence[j], number_transfer_tokens
                        )
                    else:
                        fc = full_confidence[j] / alg_temp
                        fc = F.softmax(fc, dim=-1)
                        transfer_index = torch.multinomial(
                            fc, num_samples=number_transfer_tokens
                        )

                    x_ = torch.full_like(x, mask_token_id, device=self.model.device)
                    x_[mask_index] = x0.clone()
                    x[j, transfer_index] = x_[j, transfer_index]

            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if not return_dict_in_generate:
            return x
        else:
            return GeneratorOutput(sequences=x, histories=histories)

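    # Worked example (sketch, not part of the original code): with two prompts of
    # lengths 3 and 5 and max_new_tokens=4, max_length = 5 + 4 = 9 and the canvas
    # above is laid out right-aligned, with EOS used as left padding:
    #
    #   row 0: [EOS, EOS, p0, p1, p2, M, M, M, M]    total_len = 3 + 4 = 7
    #   row 1: [q0,  q1,  q2, q3, q4, M, M, M, M]    total_len = 5 + 4 = 9
    #
    # attention_mask is 1.0 over the rightmost total_len positions of each row, so
    # the position ids derived from its cumulative sum start at 0 on the first real
    # token of every sequence, regardless of how much left padding it received.
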
    @torch.no_grad()
    def infill(
        self,
        inputs: list[torch.Tensor] | list[list[int]],
        config,
        generation_tokens_hook_func=lambda step, x, logits: x,
        generation_logits_hook_func=lambda step, x, logits: logits,
        **kwargs,
    ) -> GeneratorOutput | torch.Tensor:
        """
        Fill in-place the tokenizer's `<mask>` tokens contained in `inputs`.
        The whole (right-aligned) canvas is denoised iteratively: at each step, a scheduler
        decides how many masked positions to commit, and a confidence rule (`alg`)
        selects *which* positions to reveal (MaskGIT-style). Non-mask tokens are never changed.

        High-level:
            1) Build a right-aligned canvas per sample (left side padded with EOS).
            2) Compute a per-sample transfer schedule via `scheduler` and `steps`.
            3) At each step: forward pass -> AR-shift logits -> score masked positions
               via `alg` -> choose indices to commit (top-k or soft sampling) -> write tokens.

        Notes:
            - Left padding uses EOS (EOS serves as the pad token here).
            - Only `[MASK]` positions are updated; original tokens remain intact.
            - Logits are AR-shifted to preserve next-token prediction alignment.

        Args:
            model:
                Mask predictor; returns logits of shape [B, T, V] when called as
                `model(x, attention_mask, pos_id)`.
            tokenizer:
                Must provide `mask_token_id` and `eos_token_id`.
            inputs:
                List of 1D LongTensors (token ids). Each may contain `<mask>` tokens
                to be filled; other tokens are treated as fixed context.
            scheduler (BaseAlphaScheduler):
                Controls how many masks to commit per step (deterministic or stochastic).
            generation_tokens_hook_func / generation_logits_hook_func:
                Optional hooks to intercept tokens/logits at each step.
            output_history (bool):
                If True, save intermediate canvases at each step.
            return_dict_in_generate (bool):
                If True, return `GeneratorOutput(sequences, histories)`, else only `[B, T]`.
            steps (int):
                Total reverse-diffusion steps (quality-speed trade-off).
            alg (str):
                Confidence rule to rank masked positions:
                    - "maskgit_plus": softmax probs
                    - "topk_margin": top1 - top2 margin
                    - "entropy": negative entropy
            alg_temp (float):
                Temperature for *confidence-based index sampling* (when > 0, soft selection).
            temperature / top_p / top_k:
                Token sampling hyperparameters within `sample_tokens`.
            stochastic_transfer (bool):
                If True, sample the number of transfers per step (Binomial); else use expectation.

        Returns:
            GeneratorOutput | torch.LongTensor:
                If `return_dict_in_generate=True`, returns
                    - sequences: `[B, T]` final tokens
                    - histories: optional list of intermediate canvases
                Otherwise returns only `[B, T]`.
        """
        # ----- pull args from config, allow kwargs to override -----
        steps = kwargs.get("steps", config.steps)
        eps = kwargs.get("eps", config.eps)
        alg = kwargs.get("alg", config.alg)
        alg_temp = kwargs.get("alg_temp", config.alg_temp)
        temperature = kwargs.get("temperature", config.temperature)
        top_p = kwargs.get("top_p", config.top_p)
        top_k = kwargs.get("top_k", config.top_k)
        stochastic_transfer = kwargs.get(
            "stochastic_transfer", config.stochastic_transfer
        )
        return_dict_in_generate = kwargs.get(
            "return_dict_in_generate", config.return_dict_in_generate
        )

        # --- Initialization ---
        mask_token_id = self.tokenizer.mask_token_id
        eos_token_id = self.tokenizer.eos_token_id

        if isinstance(inputs[0], list):
            inputs = [
                torch.as_tensor(p, dtype=torch.long, device=self.model.device)
                for p in inputs
            ]

        B = len(inputs)
        seq_lens = [t.shape[0] for t in inputs]
        T = max(seq_lens)

        # Build right-aligned canvas; left side padded with EOS (used as pad)
        x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
        for i, t in enumerate(inputs):
            L = seq_lens[i]
            x[i, -L:] = t

        # Build 1D attention mask (valid tokens on the right)
        attention_mask = torch.zeros((B, T), dtype=torch.bool, device=self.model.device)
        for j, L in enumerate(seq_lens):
            if L > 0:
                attention_mask[j, -L:] = True

        # Derive position ids if left padding is present; otherwise attend to everything
        if torch.any(attention_mask == 0.0):
            pos_id = attention_mask.long().cumsum(-1) - 1
            pos_id.masked_fill_(attention_mask == 0, 1)
        else:
            pos_id = None
            attention_mask = "full"

        # Precompute per-sample transfer schedule (how many to commit per step)
        mask_index = x == mask_token_id
        num_transfer_tokens_list = get_num_transfer_tokens(
            mask_index=mask_index,
            steps=steps,
            scheduler=self.scheduler,
            stochastic=stochastic_transfer,
        )
        effective_steps = num_transfer_tokens_list.size(1)

        # Optional initial token hook
        x = generation_tokens_hook_func(None, x, None)
        histories = [x.clone()] if return_dict_in_generate else None
        for i in range(effective_steps):
            mask_index = x == mask_token_id

            # Forward pass, then AR-shift to predict token at position i+1
            logits = self.model(x, attention_mask, pos_id).logits
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            logits = generation_logits_hook_func(i, x, logits)

            # Logits restricted to current `[MASK]` positions
            mask_logits = logits[mask_index]

            # Confidence scoring for masked positions
            if alg == "maskgit_plus":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "topk_margin":
                confidence, x0 = sample_tokens(
                    mask_logits,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    margin_confidence=True,
                )
            elif alg == "entropy":
                confidence, x0 = sample_tokens(
                    mask_logits,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    neg_entropy=True,
                )
            else:
                raise RuntimeError(f"Unknown alg: {alg}")

            # Scatter per-position confidence back to full canvas
            full_confidence = torch.full_like(
                x, -torch.inf, device=self.model.device, dtype=logits.dtype
            )
            full_confidence[mask_index] = confidence

            # Commit the scheduled number of tokens per sample
            for j in range(B):
                number_transfer_tokens = num_transfer_tokens_list[j, i]
                if number_transfer_tokens > 0:
                    if alg_temp is None or alg_temp == 0:
                        _, transfer_index = torch.topk(
                            full_confidence[j], number_transfer_tokens
                        )
                    else:
                        fc = full_confidence[j] / alg_temp
                        fc = F.softmax(fc, dim=-1)
                        transfer_index = torch.multinomial(
                            fc, num_samples=number_transfer_tokens
                        )

                    # Candidate tokens at masked positions only
                    x_ = torch.full_like(x, mask_token_id, device=self.model.device)
                    x_[mask_index] = x0.clone()
                    x[j, transfer_index] = x_[j, transfer_index]

            # Optional token hook + history logging
            x = generation_tokens_hook_func(i, x, logits)
            if histories is not None:
                histories.append(x.clone())

        if not return_dict_in_generate:
            return x
        else:
            return GeneratorOutput(sequences=x, histories=histories)
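    # Usage sketch (illustrative only; `generator`, `tok`, and `gen_config` are
    # placeholders, not names defined in this file): fill masked spans in a prompt.
    #
    #   ids = tok("The <mask> sat on the <mask>.")["input_ids"]
    #   out = generator.infill(
    #       inputs=[torch.as_tensor(ids)],
    #       config=gen_config,          # carries steps / alg / temperature defaults
    #       steps=32,                   # kwargs override the config values
    #       alg="entropy",
    #   )
    #   print(tok.decode(out[0]))       # non-mask tokens come back unchanged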
13  dllm/dllm/pipelines/dream/models/__init__.py  Normal file
@ -0,0 +1,13 @@
from .configuration_dream import DreamConfig
from .modeling_dream import DreamModel

# Register with HuggingFace Auto classes for local usage
try:
    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

    AutoConfig.register("Dream", DreamConfig)
    AutoModel.register(DreamConfig, DreamModel)
    AutoModelForMaskedLM.register(DreamConfig, DreamModel)
except ImportError:
    # transformers not available or Auto classes not imported
    pass
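
# Usage sketch (assumption: run from an environment where this package is importable);
# after the registrations above, the local classes resolve through the Auto API:
#
#   from transformers import AutoConfig, AutoModel
#   config = AutoConfig.for_model("Dream")    # -> DreamConfig with default fields
#   model = AutoModel.from_config(config)     # -> DreamModel built from that config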
85  dllm/dllm/pipelines/dream/models/configuration_dream.py  Normal file
@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dream model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class DreamConfig(PretrainedConfig):
    model_type = "Dream"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=False,  # cache not used in diffusion
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        mask_token_id=151666,
        pad_token_id=151643,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
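
# Example (sketch): a deliberately tiny configuration, convenient for unit tests.
# The numbers below are illustrative and are not the released Dream-7B hyperparameters.
#
#   tiny = DreamConfig(
#       vocab_size=1000, hidden_size=64, intermediate_size=128,
#       num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
#       max_position_embeddings=256,
#   )
#   assert tiny.mask_token_id == 151666   # special-token ids keep their defaults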
465  dllm/dllm/pipelines/dream/models/generation_utils.py  Normal file
@ -0,0 +1,465 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import (
    GenerationConfig
)
from transformers.utils import (
    ModelOutput,
    is_torchdynamo_compiling,
    logging,
)

logger = logging.get_logger(__name__)


def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):

    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0


@dataclass
class DreamModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", 'origin')
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # Special tokens that can be used at generation time
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Validate the values of the attributes
        self.validate(is_init=True)

    def validate(self, is_init=False, **kwargs):
        pass

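# Example (sketch): how the helpers above rank masked positions. With greedy decoding
# (temperature=0) `sample_tokens` returns the argmax token and its probability as the
# confidence score; `margin_confidence` and `neg_entropy` only change the score, not x0.
#
#   logits = torch.tensor([[2.0, 1.0, 0.0],     # fairly peaked row
#                          [0.1, 0.0, -0.1]])   # nearly flat row
#   conf, x0 = sample_tokens(logits)            # x0 == tensor([0, 0])
#   conf_m, _ = sample_tokens(logits, margin_confidence=True)  # top1 - top2 gap
#   conf_e, _ = sample_tokens(logits, neg_entropy=True)        # higher = more certain
#   # The peaked row scores higher than the flat row under all three rules, so a
#   # confidence-based schedule would commit its token earlier.
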
class DreamGenerationMixin:
|
||||
@staticmethod
|
||||
def _expand_inputs_for_generation(
|
||||
expand_size: int = 1,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None
|
||||
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
|
||||
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
|
||||
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
|
||||
# the input tensor and thus requires more memory although no change is applied
|
||||
if expand_size == 1:
|
||||
return input_ids, attention_mask
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
|
||||
return input_ids, attention_mask
|
||||
|
||||
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
|
||||
"""Performs validation related to the resulting generated length"""
|
||||
|
||||
# Can't throw warnings/exceptions during compilation
|
||||
if is_torchdynamo_compiling():
|
||||
return
|
||||
|
||||
# 1. Max length warnings related to poor parameterization
|
||||
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
|
||||
# 20 is the default max_length of the generation config
|
||||
warnings.warn(
|
||||
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
|
||||
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
|
||||
"generation.",
|
||||
UserWarning,
|
||||
)
|
||||
if input_ids_length >= generation_config.max_length:
|
||||
input_ids_string = "input_ids"
|
||||
raise ValueError(
|
||||
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
" increasing `max_length` or, better yet, setting `max_new_tokens`."
|
||||
)
|
||||
|
||||
def _prepare_generated_length(
|
||||
self,
|
||||
generation_config,
|
||||
has_default_max_length,
|
||||
input_ids_length,
|
||||
):
|
||||
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
||||
|
||||
if generation_config.max_new_tokens is not None:
|
||||
if not has_default_max_length and generation_config.max_length is not None:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
elif has_default_max_length:
|
||||
if generation_config.max_length == DreamGenerationConfig().max_length:
|
||||
generation_config.max_length = generation_config.max_length + input_ids_length
|
||||
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
|
||||
if max_position_embeddings is not None:
|
||||
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
|
||||
|
||||
return generation_config
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
|
||||
) -> DreamGenerationConfig:
|
||||
"""
|
||||
Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
||||
function handles retrocompatibility with respect to configuration files.
|
||||
"""
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
using_model_generation_config = False
|
||||
if generation_config is None:
|
||||
generation_config = DreamGenerationConfig.from_model_config(self.config)
|
||||
using_model_generation_config = True
|
||||
|
||||
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
|
||||
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
|
||||
# exception will be raised in `_validate_model_kwargs`
|
||||
if not is_torchdynamo_compiling():
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
_kwargs = generation_config.update(**kwargs)
|
||||
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
||||
if not using_model_generation_config:
|
||||
if generation_config.bos_token_id is None:
|
||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
||||
if generation_config.eos_token_id is None:
|
||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
||||
if generation_config.pad_token_id is None:
|
||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
||||
if generation_config.mask_token_id is None:
|
||||
generation_config.mask_token_id = self.generation_config.mask_token_id
|
||||
|
||||
return generation_config
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
generation_config: DreamGenerationConfig,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
):
|
||||
"""
|
||||
Prepares the special tokens for generation, overwriting the generation config with their processed versions
|
||||
converted to tensor.
|
||||
|
||||
Note that `generation_config` is changed in place and stops being serializable after this method is called.
|
||||
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
|
||||
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
||||
"""
|
||||
|
||||
# Convert special tokens to tensors
|
||||
def _tensor_or_none(token, device=None):
|
||||
if token is None:
|
||||
return token
|
||||
|
||||
device = device if device is not None else self.device
|
||||
if isinstance(token, torch.Tensor):
|
||||
return token.to(device)
|
||||
return torch.tensor(token, device=device, dtype=torch.long)
|
||||
|
||||
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
|
||||
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
|
||||
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
|
||||
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
|
||||
|
||||
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
|
||||
eos_token_tensor = eos_token_tensor.unsqueeze(0)
|
||||
|
||||
# Set pad token if unset (and there are conditions to do so)
|
||||
if pad_token_tensor is None and eos_token_tensor is not None:
|
||||
pad_token_tensor = eos_token_tensor[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
||||
|
||||
# Update generation config with the updated special tokens tensors
|
||||
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
|
||||
# (in their non-tensor form), in order to enable end-to-end compilation. See
|
||||
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
|
||||
generation_config._bos_token_tensor = bos_token_tensor
|
||||
generation_config._eos_token_tensor = eos_token_tensor
|
||||
generation_config._pad_token_tensor = pad_token_tensor
|
||||
generation_config._mask_token_tensor = mask_token_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def diffusion_generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[DreamGenerationConfig] = None,
|
||||
**kwargs,
|
||||
) -> Union[DreamModelOutput, torch.LongTensor]:
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
generation_config = self._prepare_generation_config(generation_config, **kwargs)
|
||||
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
|
||||
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
|
||||
|
||||
# 2. Define model inputs
|
||||
assert inputs is not None
|
||||
input_ids = inputs
|
||||
device = input_ids.device
|
||||
attention_mask = kwargs.pop("attention_mask", None)
|
||||
self._prepare_special_tokens(generation_config, device=device)
|
||||
|
||||
# 3. Prepare `max_length`.
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
generation_config = self._prepare_generated_length(
|
||||
generation_config=generation_config,
|
||||
has_default_max_length=has_default_max_length,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 4. Check input_ids
|
||||
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
|
||||
warnings.warn(
|
||||
"You are calling .generate() with the `input_ids` being on a device type different"
|
||||
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
||||
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
||||
" Please make sure that you have put `input_ids` to the"
|
||||
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
||||
" running `.generate()`.",
|
||||
UserWarning,
|
||||
)
|
||||
if (
|
||||
hasattr(generation_config, "pad_token_id") and
|
||||
torch.any(input_ids == generation_config.pad_token_id) and
|
||||
attention_mask is None
|
||||
):
|
||||
warnings.warn(
|
||||
"Padding was detected but no attention mask is passed here. For correct "
|
||||
"generation results, please set `attention_mask` when batch-padding inputs.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
input_ids, attention_mask = self._expand_inputs_for_generation(
|
||||
expand_size=generation_config.num_return_sequences,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask
|
||||
)
|
||||
|
||||
result = self._sample(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
generation_tokens_hook_func=generation_tokens_hook_func,
|
||||
generation_logits_hook_func=generation_logits_hook_func
|
||||
)
|
||||
return result
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor],
|
||||
generation_config: DreamGenerationConfig,
|
||||
generation_tokens_hook_func,
|
||||
generation_logits_hook_func
|
||||
) -> Union[DreamModelOutput, torch.LongTensor]:
|
||||
# init values
|
||||
|
||||
output_history = generation_config.output_history
|
||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||
max_length = generation_config.max_length
|
||||
mask_token_id = generation_config.mask_token_id
|
||||
steps = generation_config.steps
|
||||
eps = generation_config.eps
|
||||
alg = generation_config.alg
|
||||
alg_temp = generation_config.alg_temp
|
||||
temperature = generation_config.temperature
|
||||
top_p = generation_config.top_p
|
||||
top_k = generation_config.top_k
|
||||
|
||||
histories = [] if (return_dict_in_generate and output_history) else None
|
||||
|
||||
# pad input_ids to max_length
|
||||
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
|
||||
|
||||
if attention_mask is not None and torch.any(attention_mask == 0.0):
|
||||
# we do not mask the [MASK] tokens so value = 1.0
|
||||
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
|
||||
tok_idx = attention_mask.long().cumsum(-1) - 1
|
||||
tok_idx.masked_fill_(attention_mask == 0, 1)
|
||||
# attention_mask is of shape [B, N]
|
||||
# broadcast to [B, 1, N, N]
|
||||
attention_mask = torch.logical_and(
|
||||
attention_mask.unsqueeze(1).unsqueeze(-2),
|
||||
attention_mask.unsqueeze(1).unsqueeze(-1),
|
||||
)
|
||||
else:
|
||||
tok_idx = None
|
||||
attention_mask = "full"
|
||||
|
||||
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
|
||||
|
||||
# this allows user-defined token control of the intermediate steps
|
||||
x = generation_tokens_hook_func(None, x, None)
|
||||
for i in range(steps):
|
||||
|
||||
mask_index = (x == mask_token_id)
|
||||
logits = self(x, attention_mask, tok_idx).logits
|
||||
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
|
||||
# this allows user-defined logits control of the intermediate steps
|
||||
logits = generation_logits_hook_func(i, x, logits)
|
||||
|
||||
mask_logits = logits[mask_index]
|
||||
t = timesteps[i]
|
||||
s = timesteps[i + 1]
|
||||
|
||||
if alg == 'origin':
|
||||
p_transfer = 1 - s / t if i < steps - 1 else 1
|
||||
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
|
||||
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
|
||||
_, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
x[mask_index] = x0.clone()
|
||||
else:
|
||||
if alg == 'maskgit_plus':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
elif alg == 'topk_margin':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
|
||||
elif alg == 'entropy':
|
||||
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown alg: {alg}")
|
||||
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
||||
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
||||
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
||||
full_confidence[mask_index] = confidence
|
||||
if number_transfer_tokens > 0:
|
||||
if alg_temp is None or alg_temp == 0:
|
||||
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
||||
else:
|
||||
full_confidence = full_confidence / alg_temp
|
||||
full_confidence = F.softmax(full_confidence, dim=-1)
|
||||
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
||||
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
||||
x_[mask_index] = x0.clone()
|
||||
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
||||
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
||||
|
||||
# this allows user-defined token control of the intermediate steps
|
||||
x = generation_tokens_hook_func(i, x, logits)
|
||||
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
if return_dict_in_generate:
|
||||
return DreamModelOutput(
|
||||
sequences=x,
|
||||
history=histories,
|
||||
)
|
||||
else:
|
||||
return x
|
||||
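# Usage sketch (illustrative; `model_path` is a placeholder): the mixin above is meant
# to be reached through `DreamModel.diffusion_generate`, roughly as follows.
#
#   from transformers import AutoTokenizer
#   model = DreamModel.from_pretrained(model_path)
#   tok = AutoTokenizer.from_pretrained(model_path)
#   ids = tok("Question: 2 + 2 = ?", return_tensors="pt").input_ids
#   out = model.diffusion_generate(
#       ids,
#       max_new_tokens=64, steps=64,          # more steps = slower but higher quality
#       temperature=0.2, alg="entropy",
#       return_dict_in_generate=True, output_history=True,
#   )
#   print(tok.decode(out.sequences[0], skip_special_tokens=True))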
850  dllm/dllm/pipelines/dream/models/modeling_dream.py  Normal file
@ -0,0 +1,850 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT and Qwen implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Dream model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
MaskedLMOutput,
|
||||
)
|
||||
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
from transformers import PretrainedConfig
|
||||
from .configuration_dream import DreamConfig
|
||||
from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "Dream-7B"
|
||||
_CONFIG_FOR_DOC = "DreamConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
|
||||
class DreamRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
DreamRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
|
||||
class DreamRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[DreamConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
def reset_parameters(self):
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
|
||||
class DreamMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_state):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class DreamAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
and "Generating Long Sequences with Sparse Transformers".
|
||||
"""
|
||||
|
||||
def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = False
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = DreamRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class DreamSdpaAttention(DreamAttention):
|
||||
"""
|
||||
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from DreamAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# causal_mask = attention_mask
|
||||
# if attention_mask is not None: # no matter the length, we just slice it
|
||||
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
# is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=False, # hard coded
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class DreamDecoderLayer(nn.Module):
|
||||
def __init__(self, config: DreamConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_once(
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
|
||||
# self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
self.self_attn = DreamSdpaAttention(config, layer_idx)
|
||||
|
||||
self.mlp = DreamMLP(config)
|
||||
self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
class DreamPreTrainedModel(PreTrainedModel):
|
||||
config_class = DreamConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DreamDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
*model_args,
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
use_safetensors: Optional[bool] = None,
|
||||
weights_only: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
_model = super().from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
*model_args,
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
use_safetensors=use_safetensors,
|
||||
weights_only=weights_only,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE(Lin): we need to override the generation config
|
||||
# because the generation config loaded in `from_pretrained`
|
||||
# does not include all the attributes of DreamGenerationConfig
|
||||
resume_download = kwargs.get("resume_download", None)
|
||||
proxies = kwargs.get("proxies", None)
|
||||
subfolder = kwargs.get("subfolder", "")
|
||||
from_auto_class = kwargs.get("_from_auto", False)
|
||||
from_pipeline = kwargs.get("_from_pipeline", None)
|
||||
_model.generation_config = DreamGenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
)
|
||||
return _model
|
||||
|
||||
class DreamBaseModel(DreamPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: DreamConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: DreamConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = DreamRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = DreamBaseModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def reset_rope_parameters(self):
|
||||
self.model.rotary_emb.reset_parameters()
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.rotary_emb.reset_parameters()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, MaskedLMOutput]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if attention_mask is None or (isinstance(attention_mask, str) and attention_mask == "full"):
|
||||
# whether attention_mask is full
|
||||
pass
|
||||
|
||||
elif isinstance(attention_mask, torch.Tensor):
|
||||
if not torch.any(attention_mask == 0.0):
|
||||
attention_mask = 'full'
|
||||
elif attention_mask.dim() == 2:
|
||||
# [B, L] → [B, 1, L, L]
|
||||
attention_mask = torch.logical_and(
|
||||
attention_mask.unsqueeze(1).unsqueeze(-2),
|
||||
attention_mask.unsqueeze(1).unsqueeze(-1),
|
||||
)
|
||||
attention_mask = attention_mask.to(torch.bool)
|
||||
|
||||
elif attention_mask.dim() in (3, 4):
|
||||
# already extended/broadcasted form
|
||||
if attention_mask.dtype != torch.bool:
|
||||
attention_mask = attention_mask.to(torch.bool)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}")
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unsupported attention_mask type: {type(attention_mask)}")
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
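For reference, a toy illustration (hypothetical values) of the `[B, L]` to `[B, 1, L, L]` boolean mask expansion performed in `forward` above: a position is kept only when both its query row and key column correspond to real (non-padding) tokens.

```python
import torch

am = torch.tensor([[1, 1, 0]])          # one sequence of length 3 with one padded position
mask4d = torch.logical_and(
    am.unsqueeze(1).unsqueeze(-2),       # [B, 1, 1, L]
    am.unsqueeze(1).unsqueeze(-1),       # [B, 1, L, 1]
).to(torch.bool)
print(mask4d.shape)   # torch.Size([1, 1, 3, 3])
print(mask4d[0, 0])   # True only where both query and key positions are real tokens
```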
|
||||
dllm/dllm/pipelines/dream/models/tokenization_dream.py (new file, 346 lines)
@ -0,0 +1,346 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on Qwen's implementations in this library.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for Dream."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
|
||||
|
||||
MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768}
|
||||
|
||||
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
|
||||
|
||||
|
||||
@lru_cache()
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
characters that the bpe code barfs on.
|
||||
|
||||
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
||||
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
||||
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
||||
tables between utf-8 bytes and unicode strings.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
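A quick sanity check (not part of the file) of what `bytes_to_unicode` guarantees: every one of the 256 byte values maps to a distinct printable unicode character, so the table is invertible and can be used to build `byte_decoder` below.

```python
b2u = bytes_to_unicode()
assert len(b2u) == 256
assert len(set(b2u.values())) == 256   # bijective, hence reversible
```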
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class DreamTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
|
||||
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True)
|
||||
>>> tokenizer("Hello world")["input_ids"]
|
||||
[9707, 1879]
|
||||
|
||||
>>> tokenizer(" Hello world")["input_ids"]
|
||||
[21927, 1879]
|
||||
```
|
||||
This is expected.
|
||||
|
||||
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
Path to the merges file.
|
||||
errors (`str`, *optional*, defaults to `"replace"`):
|
||||
Paradigm to follow when decoding bytes to UTF-8. See
|
||||
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token. Not applicable for this tokenizer.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
|
||||
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
|
||||
split_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the special tokens should be split during the tokenization process. The default behavior is
|
||||
to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
`['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
'|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
merges_file,
|
||||
errors="replace",
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token=None,
|
||||
eos_token="<|endoftext|>",
|
||||
pad_token="<|endoftext|>",
|
||||
clean_up_tokenization_spaces=False,
|
||||
split_special_tokens=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Dream vocab does not contain control tokens; added tokens need to be special
|
||||
bos_token = (
|
||||
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(bos_token, str)
|
||||
else bos_token
|
||||
)
|
||||
eos_token = (
|
||||
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(eos_token, str)
|
||||
else eos_token
|
||||
)
|
||||
unk_token = (
|
||||
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(unk_token, str)
|
||||
else unk_token
|
||||
)
|
||||
pad_token = (
|
||||
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
|
||||
if isinstance(pad_token, str)
|
||||
else pad_token
|
||||
)
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
bpe_merges = []
|
||||
with open(merges_file, encoding="utf-8") as merges_handle:
|
||||
for i, line in enumerate(merges_handle):
|
||||
line = line.strip()
|
||||
if (i == 0 and line.startswith("#version:")) or not line:
|
||||
continue
|
||||
bpe_merges.append(tuple(line.split()))
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
# NOTE: the cache can grow without bound and will get really large for long running processes
|
||||
# (esp. for texts of language that do not use space between word, e.g. Chinese); technically
|
||||
# not a memory leak but appears as one.
|
||||
# GPT2Tokenizer has the same problem, so let's be consistent.
|
||||
self.cache = {}
|
||||
|
||||
self.pat = re.compile(PRETOKENIZE_REGEX)
|
||||
|
||||
if kwargs.get("add_prefix_space", False):
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
errors=errors,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
unk_token=unk_token,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
split_special_tokens=split_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.encoder)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
|
||||
def get_vocab(self):
|
||||
return dict(self.encoder, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
else:
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
"""Tokenize a string."""
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = "".join(
|
||||
self.byte_encoder[b] for b in token.encode("utf-8")
|
||||
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.decoder.get(index)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids,
|
||||
skip_special_tokens: bool = False,
|
||||
clean_up_tokenization_spaces: Optional[bool] = False,
|
||||
spaces_between_special_tokens: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
|
||||
# and cannot be configured elsewhere, but it should default to False for DreamTokenizer
|
||||
return super().decode(
|
||||
token_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
merge_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
||||
)
|
||||
|
||||
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
||||
|
||||
index = 0
|
||||
with open(merge_file, "w", encoding="utf-8") as writer:
|
||||
writer.write("#version: 0.2\n")
|
||||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
||||
" Please check that the tokenizer is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(" ".join(bpe_tokens) + "\n")
|
||||
index += 1
|
||||
|
||||
return vocab_file, merge_file
|
||||
|
||||
def prepare_for_tokenization(self, text, **kwargs):
|
||||
text = unicodedata.normalize("NFC", text)
|
||||
return (text, kwargs)
|
||||
|
||||
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||
from .configuration_dream import DreamConfig
|
||||
|
||||
TOKENIZER_MAPPING.register(DreamConfig, (DreamTokenizer, None))
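A hedged usage sketch (the `vocab.json` / `merges.txt` file names are assumed to exist locally; they are not shipped in this diff), showing the class defaults in action, in particular that `decode` keeps `spaces_between_special_tokens=False`:

```python
# Hypothetical local files; any GPT2/Qwen-style vocab.json + merges.txt pair would do.
tok = DreamTokenizer(vocab_file="vocab.json", merges_file="merges.txt")
ids = tok("Hello world")["input_ids"]
text = tok.decode(ids)  # no extra spaces inserted around special tokens by default
```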
|
||||
dllm/dllm/pipelines/dream/trainer.py (new file, 84 lines)
@ -0,0 +1,84 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
|
||||
def cart_weight(
|
||||
masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Optimized CART weight computation using matrix operations.
|
||||
|
||||
Args:
|
||||
masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions.
|
||||
t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART.
|
||||
p (float): Parameter of geometric distribution (0 < p <= 1).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: (b, l) float tensor of weights.
|
||||
"""
|
||||
b, l = masked_indices.shape
|
||||
device = masked_indices.device
|
||||
|
||||
idx = torch.arange(l, device=device)
|
||||
dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1
|
||||
dist_matrix = torch.clamp(dist_matrix, min=0) # (l, l)
|
||||
geo_matrix = (
|
||||
torch.log(torch.tensor(p, device=device))
|
||||
+ (dist_matrix - 1).clamp(min=0) * torch.log(torch.tensor(1 - p, device=device))
|
||||
).exp() * 0.5  # computed in log space for numerical stability
|
||||
geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0
|
||||
|
||||
valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked
|
||||
weights = valid_mask @ geo_matrix.T # (b, l)
|
||||
weights = weights * masked_indices.float()
|
||||
return weights
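A toy check (hypothetical values) of `cart_weight`: weights are returned only at masked positions, with unmasked positions contributing geometrically decaying mass to the masked ones.

```python
import torch

masked = torch.tensor([[False, False, True, True, True, False]])
w = cart_weight(masked, t=torch.rand(1), p=0.3)
print(w.shape)                   # torch.Size([1, 6])
print((w[~masked] == 0).all())   # unmasked positions carry zero weight
```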
|
||||
|
||||
|
||||
class DreamTrainer(MDLMTrainer):
|
||||
"""
|
||||
DreamTrainer: specialization of MDLMTrainer for Dream training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
loss_weight_type: str = "cart[geo_p:0.3]",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, loss_weight_type=loss_weight_type, **kwargs)
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
labels = inputs["labels"]
|
||||
assert (labels[:, 0] == -100).all()
|
||||
|
||||
def _postprocess_outputs(self, outputs):
|
||||
logits = outputs.logits
|
||||
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
|
||||
def _compute_loss_weights(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
inputs: dict[str, Any],
|
||||
masked_indices: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if self.loss_weight_type.startswith("cart"):
|
||||
# parse geo_p
|
||||
import re
|
||||
|
||||
match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type)
|
||||
geo_p = float(match.group(1)) if match else 0.3
|
||||
loss_weights = cart_weight(masked_indices, t, p=geo_p)
|
||||
else:
|
||||
loss_weights = super()._compute_loss_weights(
|
||||
t=t,
|
||||
inputs=inputs,
|
||||
masked_indices=masked_indices,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
return loss_weights
|
||||
dllm/dllm/pipelines/dream/utils.py (new file, 180 lines)
@ -0,0 +1,180 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
|
||||
|
||||
def top_p_logits(logits, top_p=None):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
|
||||
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
|
||||
return logits
|
||||
|
||||
|
||||
def top_k_logits(logits, top_k=None):
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
|
||||
return logits
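For context, a small sketch (toy shapes) of how these two filters are typically chained before sampling; entries removed by either filter are set to the dtype minimum, so they receive essentially zero probability under softmax.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 1000)               # [batch, vocab], toy sizes
filtered = top_k_logits(logits, top_k=50)
filtered = top_p_logits(filtered, top_p=0.9)
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)  # [2, 1]
```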
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamSFTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Randomly crop response length to reduce length bias during generation.
|
||||
|
||||
Reference: https://github.com/DreamLM/Dream/blob/main/src/trainer/fsdp_sft_trainer.py
|
||||
"""
|
||||
|
||||
perbatch_cutoff: bool = True  # Use per-batch (pre-collation) truncation if True
|
||||
resp_cutoff_ratio: float = 0.0 # Prob. of post-collation truncation
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 1) Pre-collation truncation (per-sample)
|
||||
# -------------------------------------------------------------------------
|
||||
def apply_perbatch_cutoff(self, features):
|
||||
"""
|
||||
Randomly pick a response length from batch (`kept_len`) and trim other responses.
|
||||
Before:
|
||||
[<--promptA----><------responseA------>]
|
||||
[<--promptB-><---responseB--->]
|
||||
[<---promptC----><--respC-->]
|
||||
After:
|
||||
[<--promptA----><---respA--->]
|
||||
[<--promptB-><--respB-->]
|
||||
[<---promptC----><--respC-->]
|
||||
kept_len = 10 → trim each response to ≤10 tokens (before padding)
|
||||
"""
|
||||
resp_lens = torch.tensor(
|
||||
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
|
||||
)
|
||||
kept_len = int(np.random.choice(resp_lens))
|
||||
for f, r_len in zip(features, resp_lens):
|
||||
remove_len = max(r_len - kept_len, 0)
|
||||
if remove_len > 0:
|
||||
# f["input_ids"] = f["input_ids"][:-remove_len]
|
||||
# f["attention_mask"] = f["attention_mask"][:-remove_len]
|
||||
# f["labels"] = f["labels"][:-remove_len]
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in f:
|
||||
f[key] = f[key][:-remove_len]
|
||||
return features
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 2) Post-collation truncation
|
||||
# -------------------------------------------------------------------------
|
||||
def apply_resp_cutoff(self, batch, features):
|
||||
"""
|
||||
Uniformly chop tail *after padding*. All sequences truncated to new_seq_len.
|
||||
Before:
|
||||
[<--promptA----><-----respA----->] 40
|
||||
[<--promptB-><respB><----pad---->] 40
|
||||
[<---promptC----><--respC--><pad>] 40
|
||||
cutoff_len = 5
|
||||
After:
|
||||
[<--promptA----><--respA--->] 35
|
||||
[<--promptB-><respB><--pad->] 35
|
||||
[<---promptC----><--respC-->] 35
|
||||
"""
|
||||
orig_seq_lens = [len(f["input_ids"]) for f in features]
|
||||
resp_lens = torch.tensor(
|
||||
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
|
||||
)
|
||||
min_resp_len = resp_lens.min().item()
|
||||
if min_resp_len <= 1:
|
||||
return batch
|
||||
|
||||
cutoff_len = int(np.random.randint(1, min_resp_len))
|
||||
new_seq_len = max(orig_seq_lens) - cutoff_len
|
||||
|
||||
for key in ["input_ids", "labels", "attention_mask"]:
|
||||
if key in batch:
|
||||
batch[key] = batch[key][:, :new_seq_len].contiguous()
|
||||
return batch
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# 3) Main call: pick truncation mode
|
||||
# -------------------------------------------------------------------------
|
||||
def __call__(self, features, return_tensors=None):
|
||||
# optional pre-collation truncation
|
||||
if self.perbatch_cutoff:
|
||||
features = self.apply_perbatch_cutoff(features)
|
||||
|
||||
# always collate only the needed fields
|
||||
base = [
|
||||
{k: f[k] for k in ("input_ids", "labels", "attention_mask") if k in f}
|
||||
for f in features
|
||||
]
|
||||
batch = super().__call__(base, return_tensors=return_tensors)
|
||||
|
||||
# optional post-collation truncation
|
||||
if (
|
||||
not self.perbatch_cutoff
|
||||
and self.resp_cutoff_ratio > 0
|
||||
and np.random.rand() < self.resp_cutoff_ratio
|
||||
):
|
||||
batch = self.apply_resp_cutoff(batch, features)
|
||||
|
||||
batch.pop("prompt_len", None)
|
||||
return batch
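A hedged wiring sketch (the tokenizer and feature dicts are placeholders): each feature must carry `prompt_len` in addition to the usual keys, since both truncation modes rely on it, and the key never reaches the returned batch.

```python
# Hypothetical usage; `tokenizer` is any HF tokenizer with padding configured.
collator = DreamSFTCollator(tokenizer=tokenizer, perbatch_cutoff=True)
features = [
    {"input_ids": [1, 5, 6, 7], "labels": [-100, 5, 6, 7], "attention_mask": [1, 1, 1, 1], "prompt_len": 1},
    {"input_ids": [1, 8, 9],    "labels": [-100, 8, 9],    "attention_mask": [1, 1, 1],    "prompt_len": 1},
]
batch = collator(features)   # padded tensors; "prompt_len" does not appear in the output
```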
|
||||
|
||||
|
||||
@dataclass
|
||||
class DreamPTCollator(transformers.DataCollatorForSeq2Seq):
|
||||
random_length_ratio: float = 0.01
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors=return_tensors)
|
||||
input_ids, labels, attention_mask = (
|
||||
outputs["input_ids"],
|
||||
outputs["labels"],
|
||||
outputs["attention_mask"],
|
||||
)
|
||||
bsz, seq_len = input_ids.shape
|
||||
|
||||
# --- Random truncation for robustness ---
|
||||
if torch.rand(1).item() < self.random_length_ratio:
|
||||
random_len = torch.randint(1, seq_len + 1, (1,)).item()
|
||||
input_ids = input_ids[:, :random_len]
|
||||
labels = labels[:, :random_len]
|
||||
attention_mask = attention_mask[:, :random_len]
|
||||
|
||||
# --- Add BOS token to the beginning of input_ids ---
|
||||
bos = torch.full(
|
||||
(bsz, 1),
|
||||
self.tokenizer.bos_token_id,
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
input_ids = torch.cat([bos, input_ids], dim=1)
|
||||
|
||||
# --- Prepend zeros to labels instead of BOS ---
|
||||
ignore_labels = self.label_pad_token_id * torch.ones(
|
||||
(bsz, 1), dtype=labels.dtype, device=labels.device
|
||||
)
|
||||
labels = torch.cat([ignore_labels, labels], dim=1)
|
||||
|
||||
# --- Prepend ones to attention_mask ---
|
||||
bos_attention = torch.ones(
|
||||
(bsz, 1), dtype=attention_mask.dtype, device=attention_mask.device
|
||||
)
|
||||
attention_mask = torch.cat([bos_attention, attention_mask], dim=1)
|
||||
|
||||
# --- Update and return ---
|
||||
outputs["input_ids"] = input_ids
|
||||
outputs["labels"] = labels
|
||||
outputs["attention_mask"] = attention_mask
|
||||
# Check if attention_mask is all ones and set it to None
|
||||
if torch.all(outputs["attention_mask"] == 1):
|
||||
outputs.pop("attention_mask")
|
||||
return outputs
|
||||
dllm/dllm/pipelines/editflow/__init__.py (new file, 14 lines)
@ -0,0 +1,14 @@
|
||||
from . import trainer, utils
|
||||
from .models.dream.modelling_dream import (
|
||||
EditFlowDreamConfig,
|
||||
EditFlowDreamModel,
|
||||
)
|
||||
from .models.llada.modelling_llada import (
|
||||
EditFlowLLaDAConfig,
|
||||
EditFlowLLaDAModel,
|
||||
)
|
||||
from .models.bert.modelling_modernbert import (
|
||||
EditFlowModernBertConfig,
|
||||
EditFlowModernBertModel,
|
||||
)
|
||||
from dllm.pipelines.editflow.trainer import EditFlowTrainer
|
||||
dllm/dllm/pipelines/editflow/models/bert/modelling_modernbert.py (new file, 89 lines)
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
class EditFlowModernBertConfig(transformers.ModernBertConfig):
|
||||
model_type = "editflow-modernbert" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowModernBertModel(transformers.ModernBertForMaskedLM):
|
||||
config_class = EditFlowModernBertConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
# fa2 has bugs when forward(output_hidden_states=True)
|
||||
config._attn_implementation = "sdpa"
|
||||
super().__init__(config)
|
||||
in_lm, out_lm = self.decoder.in_features, self.decoder.out_features
|
||||
use_bias = self.decoder.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(in_lm, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
h = self.head(h)
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-modernbert", EditFlowModernBertConfig)
|
||||
AutoModel.register(EditFlowModernBertConfig, EditFlowModernBertModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"answerdotai/ModernBERT-base", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowModernBertConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowModernBertModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-modernbert")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-modernbert")
|
||||
dllm/dllm/pipelines/editflow/models/dream/modelling_dream.py (new file, 97 lines)
@ -0,0 +1,97 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dllm.pipelines import dream
|
||||
|
||||
|
||||
class EditFlowDreamConfig(dream.DreamConfig):
|
||||
model_type = "editflow-dream" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowDreamModel(dream.DreamModel):
|
||||
config_class = EditFlowDreamConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
in_lm, out_lm = self.lm_head.in_features, self.lm_head.out_features
|
||||
use_bias = self.lm_head.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
sub_log = torch.concatenate(
|
||||
[torch.zeros_like(sub_log)[:, :1], sub_log[:, :-1]], dim=1
|
||||
) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
sub_rate_hat = torch.concatenate(
|
||||
[torch.zeros_like(sub_rate_hat[:, :1]), sub_rate_hat[:, :-1]], dim=1
|
||||
) # [B, L]
|
||||
del_rate_hat = torch.concatenate(
|
||||
[torch.zeros_like(del_rate_hat[:, :1]), del_rate_hat[:, :-1]], dim=1
|
||||
) # [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-dream", EditFlowDreamConfig)
|
||||
AutoModel.register(EditFlowDreamConfig, EditFlowDreamModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"Dream-org/Dream-v0-Base-7B", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowDreamConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowDreamModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-dream")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-dream")
|
||||
dllm/dllm/pipelines/editflow/models/llada/modelling_llada.py (new file, 91 lines)
@ -0,0 +1,91 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
class EditFlowLLaDAConfig(llada.LLaDAConfig):
|
||||
model_type = "editflow-llada" # <- NEW model_type
|
||||
|
||||
|
||||
class EditFlowLLaDAModel(llada.LLaDAModelLM):
|
||||
config_class = EditFlowLLaDAConfig
|
||||
modules_to_save = {
|
||||
"rate_heads",
|
||||
"sub_logits",
|
||||
"ins_logits",
|
||||
} # fully finetuned even when using LoRA
|
||||
|
||||
def __init__(self, config):
|
||||
# TODO: time embedding
|
||||
super().__init__(config)
|
||||
ff = self.model.transformer.ff_out
|
||||
in_f, out_f = ff.in_features, ff.out_features
|
||||
use_bias = ff.bias is not None
|
||||
# Create new, independent heads (no deepcopy)
|
||||
self.sub_logits = nn.Linear(in_f, out_f, bias=use_bias)
|
||||
self.ins_logits = nn.Linear(in_f, out_f, bias=use_bias)
|
||||
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
t: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO: time embedding
|
||||
output = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
h = output["hidden_states"][-1] # final hidden states
|
||||
# Position heads
|
||||
sub_log = self.sub_logits(h) # [B, L, V]
|
||||
ins_log = self.ins_logits(h) # [B, L, V]
|
||||
|
||||
rates = self.rate_heads(h)
|
||||
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
|
||||
-1
|
||||
) # [B, L], [B, L], [B, L]
|
||||
return dict(
|
||||
sub_rate_hat=sub_rate_hat, # [B,L]
|
||||
del_rate_hat=del_rate_hat, # [B,L]
|
||||
ins_rate_hat=ins_rate_hat, # [B,L]
|
||||
ins_logits=ins_log, # [B,L,V]
|
||||
sub_logits=sub_log, # [B,L,V]
|
||||
)
|
||||
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoConfig
|
||||
|
||||
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
||||
AutoConfig.register("editflow-llada", EditFlowLLaDAConfig)
|
||||
AutoModel.register(EditFlowLLaDAConfig, EditFlowLLaDAModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dllm
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
# Load a config from a local path (either a directory containing config.json, or the file itself)
|
||||
config_path = dllm.utils.resolve_with_base_env(
|
||||
"GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR"
|
||||
)
|
||||
config = EditFlowLLaDAConfig.from_pretrained(config_path)
|
||||
if hasattr(config, "auto_map"):
|
||||
delattr(config, "auto_map")
|
||||
if hasattr(config, "architectures"):
|
||||
delattr(config, "architectures")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
model = EditFlowLLaDAModel(config)
|
||||
model.save_pretrained("models-tmp/editflow-llada")
|
||||
auto_model = AutoModel.from_pretrained("models-tmp/editflow-llada")
|
||||
dllm/dllm/pipelines/editflow/trainer.py (new file, 407 lines)
@ -0,0 +1,407 @@
|
||||
from typing import Any, Dict, Union, List, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import transformers
|
||||
|
||||
from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
|
||||
from dllm.pipelines.editflow.utils import pad_1d
|
||||
|
||||
|
||||
BLANK = -1
|
||||
|
||||
|
||||
def align_with_blanks(
|
||||
x0: List[int], x1: List[int], sub_cost: int = 1, gap_cost: int = 1
|
||||
) -> Dict:
|
||||
"""
|
||||
Needleman–Wunsch global alignment of two integer sequences with:
|
||||
match cost = 0, substitution cost = sub_cost, gap cost = gap_cost.
|
||||
Returns aligned sequences (z0, z1) of equal length containing BLANK = ε where gaps occur.
|
||||
"""
|
||||
n, m = len(x0), len(x1)
|
||||
# DP tables
|
||||
dp = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag', 'up', 'left'
|
||||
|
||||
for i in range(1, n + 1):
|
||||
dp[i][0] = i * gap_cost
|
||||
ptr[i][0] = "up"
|
||||
for j in range(1, m + 1):
|
||||
dp[0][j] = j * gap_cost
|
||||
ptr[0][j] = "left"
|
||||
|
||||
for i in range(1, n + 1):
|
||||
for j in range(1, m + 1):
|
||||
cost_diag = dp[i - 1][j - 1] + (0 if x0[i - 1] == x1[j - 1] else sub_cost)
|
||||
cost_up = dp[i - 1][j] + gap_cost
|
||||
cost_left = dp[i][j - 1] + gap_cost
|
||||
best = min(cost_diag, cost_up, cost_left)
|
||||
dp[i][j] = best
|
||||
if best == cost_diag:
|
||||
ptr[i][j] = "diag"
|
||||
elif best == cost_up:
|
||||
ptr[i][j] = "up"
|
||||
else:
|
||||
ptr[i][j] = "left"
|
||||
|
||||
# traceback
|
||||
z0, z1 = [], []
|
||||
i, j = n, m
|
||||
while i > 0 or j > 0:
|
||||
p = ptr[i][j]
|
||||
if p == "diag":
|
||||
z0.append(x0[i - 1])
|
||||
z1.append(x1[j - 1])
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif p == "up":
|
||||
z0.append(x0[i - 1])
|
||||
z1.append(BLANK)
|
||||
i -= 1
|
||||
else: # 'left'
|
||||
z0.append(BLANK)
|
||||
z1.append(x1[j - 1])
|
||||
j -= 1
|
||||
z0.reverse()
|
||||
z1.reverse()
|
||||
# return Alignment(z0=z0, z1=z1)
|
||||
# return {"z0": z0, "z1": z1}
|
||||
return dict(z0=z0, z1=z1)
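A toy check of the alignment (the token ids are arbitrary): inserting one token costs a single gap, which shows up as `BLANK` (-1) in `z0`.

```python
out = align_with_blanks([1, 2, 3], [1, 9, 2, 3])
print(out["z0"])  # [1, -1, 2, 3]  -> the BLANK marks where 9 must be inserted
print(out["z1"])  # [1, 9, 2, 3]
```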
|
||||
|
||||
|
||||
# def align_with_blanks(
|
||||
# x0: list[int], x1: list[int], sub_cost: int = 1, gap_cost: int = 1
|
||||
# ) -> dict:
|
||||
# """
|
||||
# Needleman–Wunsch with a secondary objective that defers gaps to the end:
|
||||
# - 'up' (gap in z1) is penalized if j < m
|
||||
# - 'left' (gap in z0) is penalized if i < n
|
||||
# This pushes blanks (-1) to the *right* whether x0 > x1 or x0 < x1.
|
||||
# """
|
||||
# n, m = len(x0), len(x1)
|
||||
|
||||
# dp_cost = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
# dp_pen = [[0] * (m + 1) for _ in range(n + 1)]
|
||||
# ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag' | 'up' | 'left'
|
||||
|
||||
# # Left edge: all 'up' moves with j=0 (< m) → penalize each step
|
||||
# for i in range(1, n + 1):
|
||||
# dp_cost[i][0] = i * gap_cost
|
||||
# dp_pen[i][0] = i # i early 'up' moves
|
||||
# ptr[i][0] = "up"
|
||||
|
||||
# # Top edge: all 'left' moves with i=0 (< n) → penalize each step
|
||||
# for j in range(1, m + 1):
|
||||
# dp_cost[0][j] = j * gap_cost
|
||||
# dp_pen[0][j] = j # j early 'left' moves
|
||||
# ptr[0][j] = "left"
|
||||
|
||||
# for i in range(1, n + 1):
|
||||
# xi = x0[i - 1]
|
||||
# for j in range(1, m + 1):
|
||||
# yj = x1[j - 1]
|
||||
|
||||
# # diag
|
||||
# cost_diag = dp_cost[i - 1][j - 1] + (0 if xi == yj else sub_cost)
|
||||
# pen_diag = dp_pen[i - 1][j - 1]
|
||||
# cand_diag = (cost_diag, pen_diag)
|
||||
|
||||
# # up: add blank to z1, penalize if j < m (early)
|
||||
# cost_up = dp_cost[i - 1][j] + gap_cost
|
||||
# pen_up = dp_pen[i - 1][j] + (1 if j < m else 0)
|
||||
# cand_up = (cost_up, pen_up)
|
||||
|
||||
# # left: add blank to z0, penalize if i < n (early)
|
||||
# cost_left = dp_cost[i][j - 1] + gap_cost
|
||||
# pen_left = dp_pen[i][j - 1] + (1 if i < n else 0)
|
||||
# cand_left = (cost_left, pen_left)
|
||||
|
||||
# # choose (cost,pen) min; deterministic tie-break: diag > left > up
|
||||
# best = min(cand_diag, cand_left, cand_up)
|
||||
# dp_cost[i][j], dp_pen[i][j] = best
|
||||
# if best == cand_diag:
|
||||
# ptr[i][j] = "diag"
|
||||
# elif best == cand_left:
|
||||
# ptr[i][j] = "left"
|
||||
# else:
|
||||
# ptr[i][j] = "up"
|
||||
|
||||
# # traceback
|
||||
# z0, z1 = [], []
|
||||
# i, j = n, m
|
||||
# while i > 0 or j > 0:
|
||||
# p = ptr[i][j]
|
||||
# if p == "diag":
|
||||
# z0.append(x0[i - 1])
|
||||
# z1.append(x1[j - 1])
|
||||
# i -= 1
|
||||
# j -= 1
|
||||
# elif p == "up":
|
||||
# z0.append(x0[i - 1])
|
||||
# z1.append(BLANK)
|
||||
# i -= 1
|
||||
# else: # 'left'
|
||||
# z0.append(BLANK)
|
||||
# z1.append(x1[j - 1])
|
||||
# j -= 1
|
||||
|
||||
# z0.reverse()
|
||||
# z1.reverse()
|
||||
# return dict(z0=z0, z1=z1)
|
||||
|
||||
|
||||
def strip_blanks(z: list[int]) -> list[int]:
|
||||
# IMPORTANT: do NOT strip BOS; we only remove BLANKs
|
||||
return [t for t in z if t != BLANK]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edit:
|
||||
kind: str # "SUB" | "DEL" | "INS"
|
||||
pos: int # position (for SUB/DEL) or token-row idx for INS (incl. BOS row 0)
|
||||
token: int | None # token for SUB/INS, else None
|
||||
|
||||
|
||||
def build_remaining_edits(zt: list[int], z1: list[int]) -> list[Edit]:
|
||||
edits: list[Edit] = []
|
||||
|
||||
def count_nonblank_prefix(z: list[int], j: int) -> int:
|
||||
c = 0
|
||||
for k in range(j):
|
||||
if z[k] != BLANK:
|
||||
c += 1
|
||||
return c
|
||||
|
||||
for j, (a, b) in enumerate(zip(zt, z1)):
|
||||
if a == b:
|
||||
continue
|
||||
nb = count_nonblank_prefix(
|
||||
zt, j
|
||||
) # counts BOS as 1, first content token will be nb=1 before its column
|
||||
|
||||
if a == BLANK and b != BLANK:
|
||||
# INSERT after row (nb-1): BOS insert => nb=1 -> gap=0; general case works too
|
||||
gap = max(nb - 1, 0)
|
||||
edits.append(Edit("INS", gap, b))
|
||||
|
||||
elif a != BLANK and b == BLANK:
|
||||
# DELETE token at row nb (first content token => nb=1, allowed; BOS is never BLANK so nb>=1)
|
||||
pos = nb
|
||||
# if pos > 0: # forbid BOS (row 0)
|
||||
edits.append(Edit("DEL", pos, None))
|
||||
|
||||
else: # a != BLANK, b != BLANK, a != b
|
||||
# SUB token at row nb
|
||||
pos = nb
|
||||
# if pos > 0: # forbid BOS (row 0)
|
||||
edits.append(Edit("SUB", pos, b))
|
||||
return edits
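A toy check (the BOS id and token ids are made up): when `z_t` is missing one aligned token, the only remaining edit is an insertion after the preceding surviving row.

```python
BOS = 0                      # hypothetical BOS id
zt = [BOS, 5, BLANK, 7]      # BLANK == -1
z1 = [BOS, 5, 6, 7]
print(build_remaining_edits(zt, z1))
# [Edit(kind='INS', pos=1, token=6)]  -> insert token 6 after row 1 of x_t
```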
|
||||
|
||||
|
||||
class EditFlowTrainer(transformers.Trainer):
|
||||
"""
|
||||
Trainer for Edit Flows where the model returns:
|
||||
- sub_logits: [B,L,V] (token dist for SUB)
|
||||
- ins_logits: [B,L,V] (token dist for INS)
|
||||
- sub_rate_hat: [B,L] (normalized rates; NO kappa factor)
|
||||
- del_rate_hat: [B,L]
|
||||
- ins_rate_hat: [B,L]
|
||||
True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
scheduler: BaseKappaScheduler | None = None,
|
||||
normalize_per_position: bool = True,
|
||||
time_epsilon: float = 1e-3,
|
||||
max_w: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.scheduler = scheduler or CubicKappaScheduler()
|
||||
self.normalize_per_position = normalize_per_position
|
||||
self.time_epsilon = time_epsilon
|
||||
self.max_w = max_w
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: transformers.PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
device = self.model.device
|
||||
B = len(inputs["x0_ids"])
|
||||
|
||||
# -------- 1) Align with blanks (z0,z1) and sample time t --------
|
||||
aligns = [
|
||||
align_with_blanks(x0, x1)
|
||||
for x0, x1 in zip(inputs["x0_ids"], inputs["x1_ids"])
|
||||
]
|
||||
z0_list = [a["z0"] for a in aligns]
|
||||
z1_list = [a["z1"] for a in aligns]
|
||||
assert all(len(z0) == len(z1) for z0, z1 in zip(z0_list, z1_list))
|
||||
assert all(z0[0] != BLANK for z0 in z0_list) # BOS must remain
|
||||
|
||||
t = (1 - self.time_epsilon) * torch.rand(B, 1, device=device) # [B,1]
|
||||
k = self.scheduler.kappa(t).to(device) # [B,1]
|
||||
w = self.scheduler.weight(t).squeeze(1).to(device) # [B]
|
||||
if self.max_w:
|
||||
w = w.clamp(max=self.max_w)
|
||||
|
||||
# -------- 2) Sample z_t by κ-mixing (vectorized per example) --------
|
||||
# Keep python lists -> tensors per-example to reuse build_remaining_edits
|
||||
zt_list: list[list[int]] = []
|
||||
for z0, z1, kb in zip(z0_list, z1_list, k.squeeze(1).tolist()):
|
||||
# per-column Bernoulli(κ) mix; BOS is equal in z0/z1 so it stays BOS
|
||||
choose_target = torch.rand(len(z0)) < kb
|
||||
zt = [b if choose_target[j] else a for j, (a, b) in enumerate(zip(z0, z1))]
|
||||
zt_list.append(zt)
|
||||
|
||||
# -------- 3) Strip blanks to x_t and compute remaining edits --------
|
||||
xt_list = [strip_blanks(zt) for zt in zt_list]
|
||||
edits_list: list[list[Edit]] = [
|
||||
build_remaining_edits(zt, z1) for zt, z1 in zip(zt_list, z1_list)
|
||||
]
|
||||
|
||||
# -------- 4) Collate x_t for the model --------
|
||||
x_tok, x_mask = pad_1d(
|
||||
xt_list, pad_val=self.processing_class.pad_token_id
|
||||
) # [B,Lmax], [B,Lmax]
|
||||
x_tok = x_tok.to(device)
|
||||
x_mask = x_mask.to(device)
|
||||
|
||||
# -------- 5) Forward pass --------
|
||||
out = model(input_ids=x_tok, attention_mask=x_mask, t=t.to(device))
|
||||
# Rename for clarity: model returns normalized rates (no kappa)
|
||||
sub_rate_hat = out["sub_rate_hat"] # [B,L]
|
||||
del_rate_hat = out["del_rate_hat"] # [B,L]
|
||||
ins_rate_hat = out["ins_rate_hat"] # [B,L]
|
||||
logQ_sub = F.log_softmax(out["sub_logits"], dim=-1) # [B,L,V]
|
||||
logQ_ins = F.log_softmax(out["ins_logits"], dim=-1) # [B,L,V]
|
||||
|
||||
# *** NEW: zero-cost anchor to "touch" every head even if unused this step ***
|
||||
# Using .sum() * 0.0 keeps a graph dependency without changing the loss value.
|
||||
# Include both logits (for SUB/INS heads) and rates (for SUB/DEL/INS heads).
|
||||
# This is important for Deepspeed ZeRO stage 2/3 to avoid skipping unused parameters.
|
||||
anchor = (
|
||||
sub_rate_hat.sum() * 0.0
|
||||
+ del_rate_hat.sum() * 0.0
|
||||
+ ins_rate_hat.sum() * 0.0
|
||||
+ logQ_sub.sum() * 0.0
|
||||
+ logQ_ins.sum() * 0.0
|
||||
)
|
||||
|
||||
# Utility
|
||||
def safe_log(x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.log(x.clamp_min(1e-12))
|
||||
|
||||
# -------- 6) Survival term --------
|
||||
# Survival = E[sum of true intensities over valid rows]
|
||||
# true intensity = w[b] * rate_hat[b, i]
|
||||
mask_f = x_mask.float()
|
||||
# L = mask_f.sum(dim=1).clamp_min(1.0) # [B] number of positions (incl. BOS)
|
||||
L1 = torch.tensor(
|
||||
[len(x1) for x1 in inputs["x1_ids"]], device=device, dtype=torch.float
|
||||
).clamp_min(1.0)
|
||||
denom = L1 if self.normalize_per_position else torch.ones_like(L1)
|
||||
|
||||
Lambda_hat = ((sub_rate_hat + del_rate_hat + ins_rate_hat) * mask_f).sum(
|
||||
dim=1
|
||||
) # [B]
|
||||
loss_surv = ((w * Lambda_hat) / denom).mean()
|
||||
|
||||
# -------- 7) Positive edit terms --------
|
||||
# For each remaining edit e: -log true rate(e) - log token prob(e) if tokenized
|
||||
# loss_pos_per = sub_rate_hat.new_zeros(B) # [B]
|
||||
# for b, edits in enumerate(edits_list):
|
||||
# if not edits:
|
||||
# continue
|
||||
# cur_len = int(x_mask[b].sum().item())
|
||||
# for e in edits:
|
||||
# pos = e.pos
|
||||
# assert 0 <= pos < cur_len, f"pos {pos} out of range {cur_len}"
|
||||
# if e.kind == "SUB":
|
||||
# loss_pos_per[b] -= logQ_sub[b, pos, e.token] + safe_log(
|
||||
# sub_rate_hat[b, pos]
|
||||
# )
|
||||
# elif e.kind == "DEL":
|
||||
# loss_pos_per[b] -= safe_log(del_rate_hat[b, pos])
|
||||
# else: # "INS"
|
||||
# loss_pos_per[b] -= logQ_ins[b, pos, e.token] + safe_log(
|
||||
# ins_rate_hat[b, pos]
|
||||
# )
|
||||
|
||||
# -------- 7) Positive edit terms (vectorized) --------
|
||||
pos_sub, tok_sub, pos_ins, tok_ins, pos_del = [], [], [], [], []
|
||||
for b, edits in enumerate(edits_list):
|
||||
cur_len = int(x_mask[b].sum().item())
|
||||
ps, ts, pi, ti, pd = [], [], [], [], []
|
||||
for e in edits:
|
||||
if not (0 <= e.pos < cur_len):
|
||||
raise AssertionError(
|
||||
f"pos {e.pos} out of range {cur_len} for b={b}"
|
||||
)
|
||||
if e.kind == "SUB":
|
||||
ps.append(e.pos)
|
||||
ts.append(e.token)
|
||||
elif e.kind == "INS":
|
||||
pi.append(e.pos)
|
||||
ti.append(e.token)
|
||||
else:
|
||||
pd.append(e.pos)
|
||||
pos_sub.append(
|
||||
torch.tensor(ps, device=x_tok.device, dtype=torch.long) if ps else None
|
||||
)
|
||||
tok_sub.append(
|
||||
torch.tensor(ts, device=x_tok.device, dtype=torch.long) if ts else None
|
||||
)
|
||||
pos_ins.append(
|
||||
torch.tensor(pi, device=x_tok.device, dtype=torch.long) if pi else None
|
||||
)
|
||||
tok_ins.append(
|
||||
torch.tensor(ti, device=x_tok.device, dtype=torch.long) if ti else None
|
||||
)
|
||||
pos_del.append(
|
||||
torch.tensor(pd, device=x_tok.device, dtype=torch.long) if pd else None
|
||||
)
|
||||
|
||||
loss_pos_terms = []
|
||||
for b in range(B):
|
||||
lp = x_tok.new_zeros(())
|
||||
if pos_sub[b] is not None:
|
||||
lp = (
|
||||
lp
|
||||
- (
|
||||
logQ_sub[b, pos_sub[b], tok_sub[b]]
|
||||
+ safe_log(sub_rate_hat[b, pos_sub[b]])
|
||||
).sum()
|
||||
)
|
||||
if pos_ins[b] is not None:
|
||||
lp = (
|
||||
lp
|
||||
- (
|
||||
logQ_ins[b, pos_ins[b], tok_ins[b]]
|
||||
+ safe_log(ins_rate_hat[b, pos_ins[b]])
|
||||
).sum()
|
||||
)
|
||||
if pos_del[b] is not None:
|
||||
lp = lp - safe_log(del_rate_hat[b, pos_del[b]]).sum()
|
||||
loss_pos_terms.append(lp)
|
||||
loss_pos_per = torch.stack(loss_pos_terms) # [B]
|
||||
|
||||
# # Average positive term per sequence (MC estimator across batch)
|
||||
loss_pos = ((w * loss_pos_per) / denom).mean()
|
||||
|
||||
# -------- 8) Total --------
|
||||
loss = loss_surv + loss_pos + anchor
|
||||
return (loss, out) if return_outputs else loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
218
dllm/dllm/pipelines/editflow/utils.py
Normal file
@ -0,0 +1,218 @@
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Text
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from dllm.utils.utils import parse_spec
|
||||
|
||||
|
||||
# ------------------------------- Collator (x0 source) --------------------------------
|
||||
@dataclass
|
||||
class X0Sampler:
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
raise NotImplementedError("Subclasses must implement __call__.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleX0Empty(X0Sampler):
|
||||
"""Return BOS-only (i.e., empty tail)."""
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleX0Masks(X0Sampler):
|
||||
"""Return a run of mask tokens of given length."""
|
||||
|
||||
length: int = 128
|
||||
tokenizer: transformers.PreTrainedTokenizer = None
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list[int]:
|
||||
mask_id = getattr(self.tokenizer, "mask_token_id", None)
|
||||
if mask_id is None:
|
||||
raise ValueError("tokenizer needs mask_token_id for mask-based sampler")
|
||||
return [int(mask_id)] * self.length
|
||||
|
||||
|
||||
# ---------------- Factory ---------------- #
|
||||
_X0_SAMPLER_CLASSES: dict[str, type[X0Sampler]] = {
|
||||
"empty": SampleX0Empty,
|
||||
"masks": SampleX0Masks,
|
||||
}
|
||||
|
||||
|
||||
def make_x0_sampler(name: str, tokenizer: Any, **kwargs) -> X0Sampler:
|
||||
try:
|
||||
name, kvs = parse_spec(name)
|
||||
cls = _X0_SAMPLER_CLASSES[name.lower()]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Unknown x0 sampler '{name}'. Available: {list(_X0_SAMPLER_CLASSES)}"
|
||||
)
|
||||
# merged_kwargs = {**kvs, **kwargs}
|
||||
return cls(tokenizer=tokenizer, **kvs, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditFlowCollator:
|
||||
tokenizer: transformers.PreTrainedTokenizer = None
|
||||
x0_sampler: Callable | str | None = X0Sampler # can be func OR name
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.x0_sampler, str):
|
||||
self.x0_sampler = make_x0_sampler(self.x0_sampler, self.tokenizer)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, list[Any]]:
|
||||
if not features:
|
||||
return {}
|
||||
|
||||
keys = features[0].keys()
|
||||
batch = {k: [ex[k] for ex in features] for k in keys}
|
||||
batch["x1_ids"] = batch["input_ids"]
|
||||
|
||||
if "prompt_len" not in batch:
|
||||
assert self.tokenizer.bos_token_id is not None
|
||||
bos = self.tokenizer.bos_token_id
|
||||
batch["x1_ids"] = [
|
||||
x if x and x[0] == bos else [bos] + x for x in batch["x1_ids"]
|
||||
]
|
||||
batch["x0_ids"] = [
|
||||
x1_ids[:1] + self.x0_sampler(x1_ids=x1_ids[1:])
|
||||
for x1_ids in batch["x1_ids"]
|
||||
]
|
||||
else:
|
||||
batch["x0_ids"] = [
|
||||
x1_ids[:prompt_len] + self.x0_sampler(x1_ids=x1_ids[prompt_len:])
|
||||
for x1_ids, prompt_len in zip(batch["x1_ids"], batch["prompt_len"])
|
||||
]
|
||||
|
||||
batch["return_loss"] = True
|
||||
return batch
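# Illustrative EditFlowCollator behaviour (token ids made up): with bos = 1,
# features = [{"input_ids": [5, 6, 7]}] and the "empty" x0 sampler, the collator
# returns x1_ids = [[1, 5, 6, 7]] and x0_ids = [[1]] (BOS only), so the model is
# trained to insert the whole target sequence starting from an empty tail.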
|
||||
|
||||
|
||||
# ------------------------------- Trainer utils --------------------------------
|
||||
def pad_1d(
|
||||
batch_lists: list[list[int]], pad_val: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pads a list of variable-length integer lists into:
|
||||
- out: tensor of shape [B, Lmax] with padding value `pad_val`
|
||||
- mask: tensor of shape [B, Lmax] with 1 for real tokens and 0 for padding (int mask)
|
||||
"""
|
||||
B = len(batch_lists)
|
||||
Lmax = max((len(x) for x in batch_lists), default=0)
|
||||
out = torch.full((B, Lmax), pad_val, dtype=torch.long)
|
||||
mask = torch.zeros((B, Lmax), dtype=torch.long) # 0/1 mask (int)
|
||||
|
||||
for b, x in enumerate(batch_lists):
|
||||
if not x:
|
||||
continue
|
||||
L = len(x)
|
||||
out[b, :L] = torch.tensor(x, dtype=torch.long)
|
||||
mask[b, :L] = 1 # mark valid positions with 1
|
||||
|
||||
return out, mask
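# Example (illustrative): pad_1d([[5, 6], [7]], pad_val=0) returns
#   out  = tensor([[5, 6], [7, 0]])
#   mask = tensor([[1, 1], [1, 0]])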
|
||||
|
||||
|
||||
def init_editflow_from_src(
|
||||
ef_model, src_model, lm_head_key: str = "lm_head", verbose: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize an EditFlowModel (ef_model) from a pretrained source model.
|
||||
|
||||
If DeepSpeed ZeRO-3 is enabled (detected via HF's `is_deepspeed_zero3_enabled()`),
|
||||
this function temporarily gathers full parameters for both models on rank 0,
|
||||
performs the copy there, and then returns to sharded mode automatically.
|
||||
Otherwise it behaves like a normal CPU/GPU single-process copy.
|
||||
|
||||
Returns (missing_keys, unexpected_keys) from load_state_dict(strict=False).
|
||||
"""
|
||||
import deepspeed
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
rank = torch.distributed.get_rank() if dist_ok else 0
|
||||
|
||||
def _copy_once():
|
||||
src_sd = src_model.state_dict()
|
||||
tgt_sd = ef_model.state_dict()
|
||||
new_sd = OrderedDict()
|
||||
|
||||
# 1) copy matching backbone tensors
|
||||
for k, v in src_sd.items():
|
||||
if k in tgt_sd and tgt_sd[k].shape == v.shape:
|
||||
new_sd[k] = v
|
||||
|
||||
# 2) duplicate lm_head -> sub_logits & ins_logits (weight + optional bias)
|
||||
lm_w = f"{lm_head_key}.weight"
|
||||
lm_b = f"{lm_head_key}.bias"
|
||||
|
||||
if lm_w in src_sd:
|
||||
if "sub_logits.weight" in tgt_sd:
|
||||
new_sd["sub_logits.weight"] = src_sd[lm_w]
|
||||
if "ins_logits.weight" in tgt_sd:
|
||||
new_sd["ins_logits.weight"] = src_sd[lm_w]
|
||||
if lm_b in src_sd:
|
||||
if "sub_logits.bias" in tgt_sd:
|
||||
new_sd["sub_logits.bias"] = src_sd[lm_b]
|
||||
if "ins_logits.bias" in tgt_sd:
|
||||
new_sd["ins_logits.bias"] = src_sd[lm_b]
|
||||
|
||||
# 3) non-strict load so new rate heads remain randomly initialized
|
||||
missing, unexpected = ef_model.load_state_dict(new_sd, strict=False)
|
||||
return new_sd, missing, unexpected
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
# All ranks enter/exit together; only rank 0 materializes full tensors.
|
||||
params = list(ef_model.parameters()) + list(src_model.parameters())
|
||||
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
||||
if rank == 0:
|
||||
new_sd, missing, unexpected = _copy_once()
|
||||
else:
|
||||
new_sd, missing, unexpected = OrderedDict(), [], []
|
||||
|
||||
if dist_ok:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if verbose and rank == 0:
|
||||
_p = getattr(globals().get("dllm", None), "utils", None)
|
||||
printer = getattr(_p, "print_main", print) if _p else print
|
||||
printer(
|
||||
f"[EditFlow init][ZeRO-3] Copied {len(new_sd)} tensors from Src Model."
|
||||
)
|
||||
if missing:
|
||||
printer(" Missing (expected for new rate heads, etc.):")
|
||||
for k in missing:
|
||||
printer(" -", k)
|
||||
if unexpected:
|
||||
printer(" Unexpected (check key names):")
|
||||
for k in unexpected:
|
||||
printer(" -", k)
|
||||
return missing, unexpected
|
||||
|
||||
# --- Non-ZeRO (or DS not present) path ---
|
||||
new_sd, missing, unexpected = _copy_once()
|
||||
if verbose:
|
||||
_p = getattr(globals().get("dllm", None), "utils", None)
|
||||
printer = getattr(_p, "print_main", print) if _p else print
|
||||
printer(f"[EditFlow init] Copied {len(new_sd)} tensors from Src Model.")
|
||||
if missing:
|
||||
printer(" Missing (expected for new rate heads, etc.):")
|
||||
for k in missing:
|
||||
printer(" -", k)
|
||||
if unexpected:
|
||||
printer(" Unexpected (check key names):")
|
||||
for k in unexpected:
|
||||
printer(" -", k)
|
||||
return missing, unexpected
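# Illustrative usage (all names below are assumptions, not part of this diff):
#   src = transformers.AutoModelForMaskedLM.from_pretrained("path/to/src")
#   ef  = EditFlowModel(src.config)  # hypothetical EditFlow wrapper class
#   missing, unexpected = init_editflow_from_src(ef, src, lm_head_key="lm_head")
#   # `missing` should only list the freshly initialized rate heads; a non-empty
#   # `unexpected` usually means the key names of the two models differ.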
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
7
dllm/dllm/pipelines/llada/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from . import generator, trainer
|
||||
from .models.modeling_llada import LLaDAModelLM
|
||||
from .models.configuration_llada import LLaDAConfig
|
||||
from .models.modeling_lladamoe import LLaDAMoEModelLM
|
||||
from .models.configuration_lladamoe import LLaDAMoEConfig
|
||||
from .generator import LLaDAGeneratorConfig, LLaDAGenerator
|
||||
from .trainer import LLaDATrainer
|
||||
357
dllm/dllm/pipelines/llada/eval.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""
|
||||
accelerate launch \
|
||||
--num_processes 2 \
|
||||
dllm/pipelines/llada/eval.py \
|
||||
--tasks gsm8k \
|
||||
--model llada \
|
||||
--num_fewshot 8 \
|
||||
--model_args "pretrained=GSAI-ML/LLaDA-8B-Base,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
from lm_eval.api.instance import Instance
|
||||
from lm_eval.api.model import LM
|
||||
from lm_eval.api.registry import register_model
|
||||
from lm_eval.models.utils import get_dtype
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAEvalConfig(LLaDAGeneratorConfig):
|
||||
max_new_tokens: int = 1024
|
||||
max_length: int = 4096
|
||||
steps: int = 1024
|
||||
block_length: int = 1024
|
||||
|
||||
pretrained: str = ""
|
||||
dtype: str | torch.dtype = "auto"
|
||||
batch_size: int = 32
|
||||
mc_num: int = 128
|
||||
is_check_greedy: bool = True
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
@register_model("llada")
|
||||
class LLaDAEvalHarness(LM):
|
||||
def __init__(
|
||||
self,
|
||||
config: LLaDAEvalConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = LLaDAEvalConfig()
|
||||
|
||||
# Pull args from config, allow kwargs to override
|
||||
pretrained = kwargs.get("pretrained", config.pretrained)
|
||||
dtype = kwargs.get("dtype", config.dtype)
|
||||
batch_size = kwargs.get("batch_size", config.batch_size)
|
||||
mc_num = kwargs.get("mc_num", config.mc_num)
|
||||
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
|
||||
device = kwargs.get("device", config.device)
|
||||
cfg = kwargs.get("cfg", config.cfg_scale)
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
|
||||
accelerator = accelerate.Accelerator()
|
||||
|
||||
# Get GLOBAL rank from torch.distributed (not accelerator)
|
||||
if torch.distributed.is_initialized():
|
||||
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
|
||||
self._world_size = (
|
||||
torch.distributed.get_world_size()
|
||||
) # ← GLOBAL world size (16)
|
||||
else:
|
||||
self._rank = 0
|
||||
self._world_size = 1
|
||||
|
||||
# Use accelerator for device placement
|
||||
self.model = dllm.utils.get_model(
|
||||
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
if accelerator.num_processes > 1:
|
||||
# Let accelerator handle device placement
|
||||
self.model = accelerator.prepare(self.model)
|
||||
self.device = (
|
||||
accelerator.device
|
||||
) # ← Accelerator figures out local device correctly
|
||||
self.accelerator = accelerator
|
||||
else:
|
||||
# Single GPU
|
||||
self.model = self.model.to(device)
|
||||
self.device = torch.device(device)
|
||||
self.accelerator = None
|
||||
|
||||
self.tokenizer = dllm.utils.get_tokenizer(
|
||||
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
|
||||
)
|
||||
|
||||
# generation params
|
||||
self.mask_id = self.tokenizer.mask_token_id
|
||||
self.batch_size = int(batch_size)
|
||||
self.max_length = max_length
|
||||
self.max_new_tokens = int(max_new_tokens)
|
||||
self.block_length = int(block_length)
|
||||
self.steps = int(steps)
|
||||
self.cfg = float(cfg)
|
||||
self.remasking = remasking
|
||||
self.is_check_greedy = is_check_greedy
|
||||
|
||||
# loglikelihood params
|
||||
self.mc_num = int(mc_num)
|
||||
assert mc_num % self.batch_size == 0
|
||||
self.sampling_eps = 0.0
|
||||
|
||||
def apply_chat_template(
|
||||
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Apply a chat template to a list of chat messages exchanged between the user and the model.
|
||||
"""
|
||||
chat_templated = self.tokenizer.apply_chat_template(
|
||||
chat_history,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=not add_generation_prompt,
|
||||
)
|
||||
return chat_templated
|
||||
|
||||
@property
|
||||
def tokenizer_name(self) -> str:
|
||||
return self.tokenizer.name_or_path.replace("/", "__")
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def _forward_process(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
b, l = batch.shape
|
||||
|
||||
target_len = (l - prompt_index.sum()).item()
|
||||
k = torch.randint(1, target_len + 1, (), device=batch.device)
|
||||
|
||||
x = torch.round(
|
||||
torch.linspace(
|
||||
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
|
||||
)
|
||||
).long()
|
||||
x = ((x - 1) % target_len) + 1
|
||||
assert x.min() >= 1 and x.max() <= target_len
|
||||
|
||||
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
|
||||
is_mask = indices < x.unsqueeze(1)
|
||||
|
||||
for i in range(b):
|
||||
is_mask[i] = is_mask[i][torch.randperm(target_len)]
|
||||
|
||||
is_mask = torch.cat(
|
||||
(
|
||||
torch.zeros(
|
||||
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
|
||||
),
|
||||
is_mask,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
noisy_batch = torch.where(is_mask, self.mask_id, batch)
|
||||
|
||||
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
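# Note: the linspace above spreads the number of masked target tokens roughly evenly
# over the `b` copies in the batch (a stratified variant of drawing an independent
# mask count per copy), which lowers the variance of the Monte Carlo log-likelihood
# estimate computed in `get_loglikelihood`.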
|
||||
|
||||
@torch.no_grad()
|
||||
def get_logits(
|
||||
self, batch: torch.Tensor, prompt_index: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
if self.cfg > 0.0:
|
||||
assert len(prompt_index) == batch.shape[1]
|
||||
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
|
||||
un_batch = batch.clone()
|
||||
un_batch[prompt_index] = self.mask_id
|
||||
batch = torch.cat([batch, un_batch])
|
||||
|
||||
logits = self.model(batch).logits
|
||||
|
||||
if self.cfg > 0.0:
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
|
||||
return logits[:, : batch.shape[1]]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
|
||||
seq = torch.concatenate([prefix, target])[None, :]
|
||||
seq = seq.repeat((self.batch_size, 1)).to(self.device)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
|
||||
loss_acc = []
|
||||
for _ in range(self.mc_num // self.batch_size):
|
||||
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
|
||||
|
||||
mask_indices = perturbed_seq == self.mask_id
|
||||
|
||||
logits = self.get_logits(perturbed_seq, prompt_index)
|
||||
|
||||
loss = (
|
||||
F.cross_entropy(
|
||||
logits[mask_indices], seq[mask_indices], reduction="none"
|
||||
)
|
||||
/ p_mask[mask_indices]
|
||||
)
|
||||
loss = loss.sum() / self.batch_size
|
||||
loss_acc.append(loss.item())
|
||||
|
||||
return -sum(loss_acc) / len(loss_acc)
|
||||
|
||||
@torch.no_grad()
|
||||
def suffix_greedy_prediction(
|
||||
self, prefix: torch.Tensor, target: torch.Tensor
|
||||
) -> bool:
|
||||
if not self.is_check_greedy:
|
||||
return False
|
||||
|
||||
seq = torch.full(
|
||||
(1, len(prefix) + len(target)), self.mask_id, device=self.device
|
||||
)
|
||||
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
|
||||
prefix, target = prefix.to(self.device), target.to(self.device)
|
||||
seq[0, : len(prefix)] = prefix
|
||||
|
||||
for i in range(len(target)):
|
||||
mask_index = seq == self.mask_id
|
||||
logits = self.get_logits(seq, prompt_index)[mask_index]
|
||||
x0 = torch.argmax(logits, dim=-1)
|
||||
|
||||
p = torch.softmax(logits.to(torch.float32), dim=-1)
|
||||
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
_, index = torch.sort(confidence, descending=True)
|
||||
x0[index[1:]] = self.mask_id
|
||||
seq[mask_index] = x0.clone()
|
||||
correct = target == seq[0, len(prefix) :]
|
||||
correct = torch.all(correct)
|
||||
return correct
|
||||
|
||||
def _encode_pair(
|
||||
self, context: str, continuation: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
n_spaces = len(context) - len(context.rstrip())
|
||||
if n_spaces > 0:
|
||||
continuation = context[-n_spaces:] + continuation
|
||||
context = context[:-n_spaces]
|
||||
|
||||
whole_enc = self.tokenizer(context + continuation)["input_ids"]
|
||||
context_enc = self.tokenizer(context)["input_ids"]
|
||||
|
||||
context_enc_len = len(context_enc)
|
||||
continuation_enc = whole_enc[context_enc_len:]
|
||||
|
||||
return context_enc, continuation_enc
|
||||
|
||||
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
|
||||
def _tokenize(e):
|
||||
prefix, target = self._encode_pair(e["prefix"], e["target"])
|
||||
return {
|
||||
"prefix_text": e["prefix"],
|
||||
"target_text": e["target"],
|
||||
"prefix": prefix,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
ds = []
|
||||
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
|
||||
|
||||
assert max(prompt_len) <= 4096
|
||||
|
||||
out = []
|
||||
with torch.no_grad():
|
||||
for elem in tqdm(ds, desc="Computing likelihood..."):
|
||||
prefix = elem["prefix"]
|
||||
target = elem["target"]
|
||||
|
||||
ll = self.get_loglikelihood(prefix, target)
|
||||
|
||||
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
|
||||
|
||||
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
|
||||
torch.cuda.empty_cache()
|
||||
return out
|
||||
|
||||
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_until(self, requests: list[Instance]) -> list[str]:
|
||||
def _tokenize(e):
|
||||
return {
|
||||
"question": self.tokenizer(e["question"])["input_ids"],
|
||||
"question_text": e["question"],
|
||||
"until": e["until"],
|
||||
}
|
||||
|
||||
ds = [
|
||||
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
|
||||
]
|
||||
ds = Dataset.from_list(ds)
|
||||
ds = ds.map(_tokenize)
|
||||
ds = ds.with_format("torch")
|
||||
|
||||
out = []
|
||||
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
|
||||
|
||||
for elem in tqdm(ds, desc="Generating..."):
|
||||
prompt = [elem["question"].to(self.device)]
|
||||
stop_tokens = elem["until"]
|
||||
generated_ids = generator.generate(
|
||||
inputs=prompt,
|
||||
steps=self.steps,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
block_length=self.block_length,
|
||||
temperature=0.0,
|
||||
cfg_scale=self.cfg,
|
||||
remasking=self.remasking,
|
||||
)
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
|
||||
)
|
||||
for stop_seq in stop_tokens:
|
||||
if stop_seq in generated_answer:
|
||||
generated_answer = generated_answer.split(stop_seq)[0]
|
||||
|
||||
# remove special tokens
|
||||
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
|
||||
generated_answer = self.tokenizer.decode(
|
||||
generated_answer_ids, skip_special_tokens=True
|
||||
)
|
||||
out.append(generated_answer)
|
||||
if self.accelerator is not None:
|
||||
self.accelerator.wait_for_everyone()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_evaluate()
|
||||
379
dllm/dllm/pipelines/llada/generator.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""
|
||||
reference: https://github.com/ML-GSAI/LLaDA/blob/main/generate.py
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from dllm.utils.generation_utils import get_num_transfer_tokens
|
||||
from dllm.core.generation.generator import (
|
||||
GeneratorOutput,
|
||||
GeneratorConfig,
|
||||
BaseGenerator,
|
||||
)
|
||||
|
||||
|
||||
def add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor:
|
||||
"""
|
||||
The Gumbel-Max trick is a method for sampling from categorical distributions.
|
||||
According to arXiv:2409.02908, for MDM, low-precision Gumbel-Max improves the perplexity score but reduces generation quality.
|
||||
Thus, we use float64.
|
||||
"""
|
||||
if temperature == 0:
|
||||
return logits
|
||||
logits = logits.to(torch.float64)
|
||||
noise = torch.rand_like(logits, dtype=torch.float64)
|
||||
gumbel_noise = (-torch.log(noise)) ** temperature
|
||||
return logits.exp() / gumbel_noise
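# Note: taking argmax of exp(logits) / (-log u) ** temperature, u ~ Uniform(0, 1),
# is equivalent to argmax(logits + temperature * g) with g ~ Gumbel(0, 1), i.e.
# sampling from softmax(logits / temperature); with temperature == 0 the function
# returns the raw logits, so argmax decoding below is purely greedy.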
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAGeneratorConfig(GeneratorConfig):
|
||||
max_new_tokens: int = 128
|
||||
max_length: int = (
|
||||
None # There's no explicit length_limit except for the tokenizer/model context
|
||||
)
|
||||
block_length: int = 128
|
||||
steps: int = 128
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
stochastic_transfer: bool = False
|
||||
cfg_scale: float = 0.0
|
||||
cfg_keep_tokens: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDAGenerator(BaseGenerator):
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: list[torch.Tensor | list],
|
||||
config: LLaDAGeneratorConfig | None = None,
|
||||
**kwargs,
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
if config is None:
|
||||
config = LLaDAGeneratorConfig()
|
||||
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
|
||||
max_length = kwargs.get("max_length", config.max_length)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
|
||||
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
assert 1 <= block_length
|
||||
assert 1 <= steps
|
||||
mask_id = self.tokenizer.mask_token_id
|
||||
eos_id = self.tokenizer.eos_token_id
|
||||
|
||||
# ----- Shape bookkeeping: per-sample prompt lengths and final canvas width -----
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
prompt_lens = [p.shape[0] for p in inputs]
|
||||
|
||||
if max_new_tokens:
|
||||
max_length = max_new_tokens + max(prompt_lens)
|
||||
else:
|
||||
max_new_tokens = max_length - max(prompt_lens)
|
||||
|
||||
B = len(inputs)
|
||||
T = max_length
|
||||
|
||||
# ----- Initialize canvas with EOS, copy inputs, and append mask tail -----
|
||||
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
|
||||
for i, p in enumerate(inputs):
|
||||
x[i, : prompt_lens[i]] = p # keep original prompt tokens
|
||||
x[i, prompt_lens[i] : prompt_lens[i] + max_new_tokens] = (
|
||||
mask_id # append `max_new_tokens` masks to be generated
|
||||
)
|
||||
attention_mask = (x != eos_id).long() if B > 1 else None
|
||||
|
||||
# Tokens that were *given* at the start (non-mask, non-EOS).
|
||||
# These will be masked in the unconditional forward pass for CFG.
|
||||
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
|
||||
unmasked_index = (x != mask_id) & (x != eos_id)
|
||||
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
|
||||
keep_mask = torch.isin(
|
||||
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
|
||||
)
|
||||
unmasked_index = unmasked_index & ~keep_mask
|
||||
|
||||
# ----- Block scheduling over the appended mask tail -----
|
||||
num_blocks = math.ceil(max_new_tokens / block_length)
|
||||
steps = math.ceil(steps / num_blocks) # per-block step budget
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
|
||||
for b in range(num_blocks):
|
||||
# Build a per-sample mask *within this block* (aligned to each prompt's tail)
|
||||
block_mask_index = torch.zeros(
|
||||
(B, block_length), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
for j in range(B):
|
||||
start = prompt_lens[j] + b * block_length
|
||||
end = min(start + block_length, prompt_lens[j] + max_new_tokens, T)
|
||||
if start < end:
|
||||
width = end - start
|
||||
block_mask_index[j, :width] = (
|
||||
x[j, start:end] == mask_id
|
||||
) # which positions in this block are still masked
|
||||
|
||||
# Decide how many tokens to reveal per step in this block
|
||||
num_transfer_tokens = get_num_transfer_tokens(
|
||||
mask_index=block_mask_index,
|
||||
steps=steps,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
|
||||
# Some steps may be skipped if there are no transfers
|
||||
effective_steps = num_transfer_tokens.size(1)
|
||||
|
||||
# ----- Iterative reveal inside the current block -----
|
||||
for i in range(effective_steps):
|
||||
mask_index = x == mask_id # current global mask map
|
||||
|
||||
# Optional CFG: second forward where original prompt tokens are masked out
|
||||
if cfg_scale > 0.0:
|
||||
un_x = x.clone()
|
||||
un_x[unmasked_index] = mask_id
|
||||
x_ = torch.cat([x, un_x], dim=0)
|
||||
logits = self.model(
|
||||
x_, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
||||
else:
|
||||
logits = self.model(
|
||||
x, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
|
||||
# Argmax decoding with optional Gumbel-Max noise for exploration
|
||||
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
||||
x0 = torch.argmax(
|
||||
logits_with_noise, dim=-1
|
||||
) # [B, T] predicted token ids
|
||||
|
||||
# Per-position confidence used to pick which masks to commit this step
|
||||
if remasking == "low_confidence":
|
||||
p = F.softmax(logits, dim=-1)
|
||||
x0_p = torch.squeeze(
|
||||
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
|
||||
) # [B, T] confidence of predicted token
|
||||
elif remasking == "random":
|
||||
x0_p = torch.rand(
|
||||
(x0.shape[0], x0.shape[1]), device=x0.device
|
||||
) # random scores
|
||||
else:
|
||||
raise NotImplementedError(remasking)
|
||||
|
||||
# Restrict selection window to the *current block's* tail region
|
||||
for j in range(B):
|
||||
x0_p[j, prompt_lens[j] + (b + 1) * block_length :] = -np.inf
|
||||
|
||||
# Only allow updates at currently masked positions; keep others fixed
|
||||
x0 = torch.where(mask_index, x0, x)
|
||||
confidence = torch.where(
|
||||
mask_index, x0_p, -np.inf
|
||||
) # consider masked positions only
|
||||
|
||||
# Pick exactly `num_transfer_tokens[j, i]` highest-confidence positions per sample
|
||||
transfer_index = torch.zeros_like(
|
||||
x0, dtype=torch.bool, device=x0.device
|
||||
)
|
||||
for j in range(confidence.shape[0]):
|
||||
_, select_index = torch.topk(
|
||||
confidence[j], k=num_transfer_tokens[j, i]
|
||||
)
|
||||
transfer_index[j, select_index] = True
|
||||
|
||||
# Commit chosen predictions into the canvas
|
||||
x[transfer_index] = x0[transfer_index]
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
# ----- Output format -----
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
|
||||
|
||||
@torch.no_grad()
|
||||
def infill(
|
||||
self, inputs: list[torch.Tensor | list], config, **kwargs
|
||||
) -> GeneratorOutput | torch.Tensor:
|
||||
"""
|
||||
Fill in-place the <|mdm_mask|> tokens contained in `inputs`.
|
||||
The whole (padded) sequence is split into block windows of length
|
||||
`block_length`; within each window we progressively "unmask" positions
|
||||
according to the scheduler and chosen remasking strategy.
|
||||
|
||||
Notes:
|
||||
- Right padding uses EOS.
|
||||
- CFG masks out *originally known* (non-mask, non-EOS) tokens in the
|
||||
unconditional branch, identical to `generate`.
|
||||
- Only masked positions are ever updated; non-mask tokens are left intact.
|
||||
"""
|
||||
# ----- pull args from config, allow kwargs to override -----
|
||||
steps = kwargs.get("steps", config.steps)
|
||||
block_length = kwargs.get("block_length", config.block_length)
|
||||
temperature = kwargs.get("temperature", config.temperature)
|
||||
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
|
||||
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
|
||||
remasking = kwargs.get("remasking", config.remasking)
|
||||
stochastic_transfer = kwargs.get(
|
||||
"stochastic_transfer", config.stochastic_transfer
|
||||
)
|
||||
return_dict_in_generate = kwargs.get(
|
||||
"return_dict_in_generate", config.return_dict_in_generate
|
||||
)
|
||||
|
||||
mask_id = self.tokenizer.mask_token_id
|
||||
eos_id = self.tokenizer.eos_token_id
|
||||
|
||||
# ----- Build canvas: right-pad with EOS to the max length in the batch -----
|
||||
if isinstance(inputs[0], list):
|
||||
inputs = [
|
||||
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
|
||||
for p in inputs
|
||||
]
|
||||
|
||||
B = len(inputs)
|
||||
seq_lens = [t.shape[0] for t in inputs]
|
||||
T = max(seq_lens)
|
||||
|
||||
# Default to a single block spanning the whole sequence
|
||||
if block_length is None:
|
||||
block_length = T
|
||||
|
||||
assert 1 <= block_length
|
||||
assert 1 <= steps
|
||||
|
||||
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
|
||||
for i, t in enumerate(inputs):
|
||||
x[i, : seq_lens[i]] = t
|
||||
attention_mask = (x != eos_id).long() if B > 1 else None
|
||||
|
||||
# Tokens that were *given* at the start (non-mask, non-EOS).
|
||||
# These will be masked in the unconditional forward pass for CFG.
|
||||
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
|
||||
unmasked_index = (x != mask_id) & (x != eos_id)
|
||||
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
|
||||
keep_mask = torch.isin(
|
||||
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
|
||||
)
|
||||
unmasked_index = unmasked_index & ~keep_mask
|
||||
|
||||
# ----- Blockwise schedule over the *entire* (padded) sequence -----
|
||||
num_blocks = math.ceil(T / block_length)
|
||||
steps_per_block = math.ceil(steps / num_blocks)
|
||||
histories = [x.clone()] if return_dict_in_generate else None
|
||||
|
||||
# Create attention mask where eos_token_id is masked (set to 0)
|
||||
attention_mask = (x != eos_id).long()
|
||||
|
||||
for b in range(num_blocks):
|
||||
start = b * block_length
|
||||
stop = min(start + block_length, T)
|
||||
|
||||
# Per-sample view of which positions in this block are masks
|
||||
block_mask_index = torch.zeros(
|
||||
(B, block_length), dtype=torch.bool, device=self.model.device
|
||||
)
|
||||
widths = []
|
||||
for j in range(B):
|
||||
# Width limited by sample's true length and sequence end
|
||||
width = max(0, min(seq_lens[j], stop) - start)
|
||||
widths.append(width)
|
||||
if width > 0:
|
||||
block_mask_index[j, :width] = x[j, start : start + width] == mask_id
|
||||
|
||||
# Decide how many tokens to reveal at each step in this block
|
||||
num_transfer_tokens = get_num_transfer_tokens(
|
||||
mask_index=block_mask_index,
|
||||
steps=steps_per_block,
|
||||
scheduler=self.scheduler,
|
||||
stochastic=stochastic_transfer,
|
||||
)
|
||||
|
||||
# Some blocks may have no masks => effective_steps == 0
|
||||
effective_steps = num_transfer_tokens.size(1)
|
||||
|
||||
for s in range(effective_steps):
|
||||
mask_index_full = x == mask_id
|
||||
|
||||
# ----- Forward pass (+ optional CFG) -----
|
||||
if cfg_scale > 0.0:
|
||||
un_x = x.clone()
|
||||
un_x[unmasked_index] = mask_id
|
||||
x_ = torch.cat([x, un_x], dim=0)
|
||||
logits = self.model(
|
||||
x_, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
logits, un_logits = torch.chunk(logits, 2, dim=0)
|
||||
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
|
||||
else:
|
||||
logits = self.model(
|
||||
x, attention_mask=attention_mask
|
||||
).logits # Use attention mask here
|
||||
|
||||
# Greedy with optional Gumbel-Max noise
|
||||
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
|
||||
x0 = torch.argmax(logits_with_noise, dim=-1) # [B, T]
|
||||
|
||||
# Confidence used for choosing which masks to commit this step
|
||||
if remasking == "low_confidence":
|
||||
p = F.softmax(logits, dim=-1)
|
||||
x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(
|
||||
-1
|
||||
) # [B, T]
|
||||
elif remasking == "random":
|
||||
x0_p = torch.rand((B, T), device=self.model.device)
|
||||
else:
|
||||
raise NotImplementedError(remasking)
|
||||
|
||||
# Restrict selection to the *current* block only
|
||||
for j in range(B):
|
||||
end_j = start + widths[j]
|
||||
# Outside current block => impossible to select
|
||||
x0_p[j, :start] = -np.inf
|
||||
x0_p[j, end_j:] = -np.inf
|
||||
|
||||
# Only consider currently-masked positions as candidates
|
||||
x0 = torch.where(mask_index_full, x0, x)
|
||||
confidence = torch.where(mask_index_full, x0_p, -np.inf)
|
||||
|
||||
# Pick exactly num_transfer_tokens[j, s] positions per sample
|
||||
transfer_index = torch.zeros_like(x, dtype=torch.bool)
|
||||
for j in range(B):
|
||||
k = int(num_transfer_tokens[j, s].item())
|
||||
if k > 0:
|
||||
_, select_idx = torch.topk(confidence[j], k=k)
|
||||
transfer_index[j, select_idx] = True
|
||||
|
||||
# Commit selected predictions into the canvas
|
||||
x[transfer_index] = x0[transfer_index]
|
||||
if histories is not None:
|
||||
histories.append(x.clone())
|
||||
|
||||
# ----- Output format -----
|
||||
if not return_dict_in_generate:
|
||||
return x
|
||||
else:
|
||||
return GeneratorOutput(sequences=x, histories=histories)
|
||||
19
dllm/dllm/pipelines/llada/models/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
from .configuration_llada import LLaDAConfig
|
||||
from .modeling_llada import LLaDAModelLM
|
||||
from .configuration_lladamoe import LLaDAMoEConfig
|
||||
from .modeling_lladamoe import LLaDAMoEModelLM
|
||||
|
||||
# Register with HuggingFace Auto classes for local usage
|
||||
try:
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
|
||||
|
||||
AutoConfig.register("llada", LLaDAConfig)
|
||||
AutoModel.register(LLaDAConfig, LLaDAModelLM)
|
||||
AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM)
|
||||
|
||||
AutoConfig.register("lladamoe", LLaDAMoEConfig)
|
||||
AutoModel.register(LLaDAMoEConfig, LLaDAMoEModelLM)
|
||||
AutoModelForMaskedLM.register(LLaDAMoEConfig, LLaDAMoEModelLM)
|
||||
except ImportError:
|
||||
# transformers not available or Auto classes not imported
|
||||
pass
|
||||
459
dllm/dllm/pipelines/llada/models/configuration_llada.py
Normal file
@ -0,0 +1,459 @@
|
||||
"""
|
||||
LLaDA configuration
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from enum import Enum
|
||||
from os import PathLike
|
||||
from typing import Union
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ActivationType",
|
||||
"ActivationCheckpointingStrategy",
|
||||
"BlockType",
|
||||
"LayerNormType",
|
||||
"InitFnType",
|
||||
"ModelConfig",
|
||||
]
|
||||
|
||||
PathOrStr = Union[str, PathLike]
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""
|
||||
This is equivalent to Python's :class:`enum.StrEnum`, which was added in version 3.11.
|
||||
We include this here for compatibility with older versions of Python.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"'{str(self)}'"
|
||||
|
||||
|
||||
class LayerNormType(StrEnum):
|
||||
default = "default"
|
||||
"""
|
||||
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
|
||||
"""
|
||||
|
||||
low_precision = "low_precision"
|
||||
"""
|
||||
A low-precision version of the default LayerNorm.
|
||||
"""
|
||||
|
||||
rms = "rms"
|
||||
"""
|
||||
An RMSNorm implementation. When using ``torch.compile`` this is
|
||||
probably the fastest implementation.
|
||||
"""
|
||||
|
||||
gemma_rms = "gemma_rms"
|
||||
"""
|
||||
An RMSNorm implementation by Gemma. When using ``torch.compile`` this is
|
||||
probably the fastest implementation.
|
||||
"""
|
||||
|
||||
amd_compatible = "amd_compatible"
|
||||
"""
|
||||
LayerNorm implemented manually to work around an issue with ROCm.
|
||||
"""
|
||||
|
||||
|
||||
class ActivationType(StrEnum):
|
||||
gelu = "gelu"
|
||||
relu = "relu"
|
||||
silu = "silu"
|
||||
swiglu = "swiglu"
|
||||
|
||||
|
||||
class BlockType(StrEnum):
|
||||
sequential = "sequential"
|
||||
parallel = "parallel"
|
||||
|
||||
llama = "llama"
|
||||
"""
|
||||
A block similar to the sequential block with slightly different
|
||||
implementations of operations like attention to imitate the behavior of Llama.
|
||||
"""
|
||||
|
||||
|
||||
class InitFnType(StrEnum):
|
||||
mitchell = "mitchell"
|
||||
"""
|
||||
The strategy suggested to us by Mitchell Wortsman from UW.
|
||||
This uses a truncated normal distribution with an adaptive standard deviation that depends
|
||||
on the size of the weights as well as the depth of the layer.
|
||||
"""
|
||||
|
||||
normal = "normal"
|
||||
"""
|
||||
All weights are initialized from the same normal distribution.
|
||||
"""
|
||||
|
||||
kaiming_normal = "kaiming_normal"
|
||||
"""
|
||||
All weights are initialized with the Kaiming method from a normal distribution.
|
||||
Note this currently won't work with FSDP.
|
||||
"""
|
||||
|
||||
fan_in = "fan_in"
|
||||
"""
|
||||
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
|
||||
is the input dimensionality of the kernel.
|
||||
"""
|
||||
|
||||
full_megatron = "full_megatron"
|
||||
"""
|
||||
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig():
|
||||
"""
|
||||
LLaDA (model) configuration.
|
||||
"""
|
||||
|
||||
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
|
||||
|
||||
d_model: int = 768
|
||||
"""
|
||||
The hidden size of the model.
|
||||
"""
|
||||
|
||||
n_heads: int = 12
|
||||
"""
|
||||
The number of self-attention heads.
|
||||
"""
|
||||
|
||||
n_kv_heads: Optional[int] = None
|
||||
"""
|
||||
The number of heads to use for keys and values. Defaults to `n_heads`.
|
||||
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
|
||||
Set this to 1 for multi-query attention.
|
||||
Set it to some in-between value for Llama2-style grouped query attention.
|
||||
"""
|
||||
|
||||
n_layers: int = 12
|
||||
"""
|
||||
The number of layers/blocks.
|
||||
"""
|
||||
|
||||
mlp_ratio: int = 4
|
||||
"""
|
||||
The ratio of the inner MLP dimensionality to ``d_model``.
|
||||
This is only used when ``mlp_hidden_size`` is not set.
|
||||
"""
|
||||
|
||||
mlp_hidden_size: Optional[int] = None
|
||||
"""
|
||||
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
|
||||
"""
|
||||
|
||||
activation_type: ActivationType = ActivationType.swiglu
|
||||
"""
|
||||
The activation function to use within the MLP layers.
|
||||
"""
|
||||
|
||||
block_type: BlockType = BlockType.sequential
|
||||
"""
|
||||
The transformer block implementation.
|
||||
"""
|
||||
|
||||
block_group_size: int = 1
|
||||
"""
|
||||
The number of blocks to group together into a single parent block.
|
||||
This has no effect on the number of parameters in the model and is only used to wrap groups
|
||||
of blocks together with a single FSDP wrapper during training.
|
||||
"""
|
||||
|
||||
alibi: bool = False
|
||||
"""
|
||||
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
|
||||
"""
|
||||
|
||||
alibi_bias_max: float = 8.0
|
||||
"""
|
||||
Maximum absolute value of ALiBi bias.
|
||||
"""
|
||||
|
||||
rope: bool = False
|
||||
"""
|
||||
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
|
||||
"""
|
||||
|
||||
rope_full_precision: bool = True
|
||||
"""
|
||||
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
|
||||
apply RoPE at the precision of the input.
|
||||
"""
|
||||
|
||||
flash_attention: bool = False
|
||||
"""
|
||||
If ``True``, use ``FlashAttention``.
|
||||
"""
|
||||
|
||||
attention_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability within the attention modules.
|
||||
"""
|
||||
|
||||
multi_query_attention: Optional[bool] = None
|
||||
"""
|
||||
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
|
||||
and is more efficient during inference.
|
||||
"""
|
||||
|
||||
attention_layer_norm: bool = False
|
||||
"""
|
||||
Apply layer norm to the keys and queries within the attention mechanism.
|
||||
This can help stabilize training.
|
||||
"""
|
||||
|
||||
residual_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability for the MLP and attention output within each block.
|
||||
"""
|
||||
|
||||
embedding_dropout: float = 0.1
|
||||
"""
|
||||
The dropout probability for embeddings.
|
||||
"""
|
||||
|
||||
input_emb_norm: bool = False
|
||||
"""
|
||||
An input hidden_states norm implementation by Gemma.
|
||||
"""
|
||||
|
||||
layer_norm_type: LayerNormType = LayerNormType.default
|
||||
"""
|
||||
The layernorm implementation to use.
|
||||
"""
|
||||
|
||||
layer_norm_with_affine: bool = True
|
||||
"""
|
||||
Whether to include bias and weight parameters for the layer norms.
|
||||
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
|
||||
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
|
||||
to ``False``.
|
||||
"""
|
||||
|
||||
rms_norm_eps: float = 1e-05
|
||||
"""
|
||||
The rms layernorm eps param.
|
||||
"""
|
||||
|
||||
attention_layer_norm_with_affine: bool = True
|
||||
"""
|
||||
Toggle affine transform for the QK norms.
|
||||
"""
|
||||
|
||||
max_sequence_length: int = 1024
|
||||
"""
|
||||
The maximum input sequence length supported by the model.
|
||||
"""
|
||||
|
||||
rope_theta: float = 10000.0
|
||||
"""
|
||||
The rope base param.
|
||||
"""
|
||||
|
||||
include_qkv_bias: Optional[bool] = False
|
||||
"""
|
||||
Whether or not to include bias parameters in qkv linear layers.
|
||||
"""
|
||||
|
||||
include_bias: bool = False
|
||||
"""
|
||||
Whether or not to include bias parameters in linear layers.
|
||||
In PaLM, they got rid of all bias terms because they found that large
|
||||
models tend to have near 0 bias terms anyway.
|
||||
"""
|
||||
|
||||
bias_for_layer_norm: Optional[bool] = None
|
||||
"""
|
||||
Whether or not to include bias parameters in layer norm.
|
||||
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
|
||||
layer norm.
|
||||
When this is None (the default), it inherits the setting from include_bias.
|
||||
"""
|
||||
|
||||
scale_logits: bool = False
|
||||
"""
|
||||
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
|
||||
"""
|
||||
|
||||
vocab_size: int = 50257
|
||||
"""
|
||||
Vocabulary size of the model.
|
||||
"""
|
||||
|
||||
embedding_size: Optional[int] = 50304
|
||||
"""
|
||||
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
|
||||
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
|
||||
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
|
||||
substantially.
|
||||
"""
|
||||
|
||||
weight_tying: bool = True
|
||||
"""
|
||||
Whether to tie output linear weights to the input embedding.
|
||||
"""
|
||||
|
||||
eos_token_id: int = 50256
|
||||
"""
|
||||
The ID of the end-of-sentence special token.
|
||||
"""
|
||||
|
||||
pad_token_id: int = 50256
|
||||
"""
|
||||
The ID of the token to use for padding. Defaults to the ID of the EOS token.
|
||||
"""
|
||||
|
||||
mask_token_id: Optional[int] = 50256
|
||||
"""
|
||||
The ID of the token to use for mask token. Defaults to the ID of the EOS token.
|
||||
"""
|
||||
|
||||
init_device: Optional[str] = None
|
||||
"""
|
||||
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
|
||||
"""
|
||||
|
||||
init_fn: InitFnType = InitFnType.normal
|
||||
"""
|
||||
The weight initialization strategy.
|
||||
"""
|
||||
|
||||
init_std: float = 0.02
|
||||
"""
|
||||
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
|
||||
as "normal".
|
||||
"""
|
||||
|
||||
init_cutoff_factor: Optional[float] = None
|
||||
"""
|
||||
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
|
||||
as "normal". Setting this to None means values are not cutoff.
|
||||
"""
|
||||
|
||||
precision: Optional[str] = None
|
||||
"""
|
||||
Precision used to train/evaluate with. You shouldn't set this directly.
|
||||
See :data:`TrainConfig.precision` instead.
|
||||
"""
|
||||
|
||||
@property
|
||||
def effective_n_kv_heads(self) -> int:
|
||||
if self.n_kv_heads is None:
|
||||
if self.multi_query_attention is True:
|
||||
return 1
|
||||
else:
|
||||
return self.n_heads
|
||||
else:
|
||||
if self.multi_query_attention is None:
|
||||
return self.n_kv_heads
|
||||
if self.multi_query_attention:
|
||||
n_kv_heads_should_be = 1
|
||||
else:
|
||||
n_kv_heads_should_be = self.n_heads
|
||||
if self.n_kv_heads == n_kv_heads_should_be:
|
||||
return n_kv_heads_should_be
|
||||
else:
|
||||
raise Exception(
|
||||
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
|
||||
)
|
||||
|
||||
class ActivationCheckpointingStrategy(StrEnum):
|
||||
whole_layer = "whole_layer"
|
||||
"""
|
||||
Checkpoint every transformer layer.
|
||||
"""
|
||||
|
||||
one_in_two = "one_in_two"
|
||||
"""
|
||||
Checkpoint one in two transformer layers.
|
||||
"""
|
||||
|
||||
one_in_three = "one_in_three"
|
||||
"""
|
||||
Checkpoint one in three transformer layers.
|
||||
"""
|
||||
|
||||
one_in_four = "one_in_four"
|
||||
"""
|
||||
Checkpoint one in four transformer layers.
|
||||
"""
|
||||
|
||||
two_in_three = "two_in_three"
|
||||
"""
|
||||
Checkpoint two out of every three transformer layers.
|
||||
"""
|
||||
|
||||
three_in_four = "three_in_four"
|
||||
"""
|
||||
Checkpoint three out of every four transformer layers.
|
||||
"""
|
||||
|
||||
four_in_five = "four_in_five"
|
||||
"""
|
||||
Checkpoint four out of every five transformer layers.
|
||||
"""
|
||||
|
||||
nine_in_ten = "nine_in_ten"
|
||||
"""
|
||||
Checkpoint nine out of every ten transformer layers.
|
||||
"""
|
||||
|
||||
fine_grained = "fine_grained"
|
||||
"""
|
||||
Focus checkpointing where recomputation is cheap and the memory savings are largest.
|
||||
"""
|
||||
|
||||
|
||||
class LLaDAConfig(PretrainedConfig):
|
||||
model_type = "llada"
|
||||
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
||||
|
||||
def __init__(self, use_cache: bool = False, **kwargs):
|
||||
model_config = ModelConfig()
|
||||
all_kwargs = model_config.__dict__
|
||||
all_kwargs.update(kwargs)
|
||||
all_kwargs.update({"use_cache": use_cache})
|
||||
all_kwargs.update(
|
||||
{
|
||||
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
|
||||
}
|
||||
)
|
||||
super().__init__(**all_kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.n_heads
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
return self.n_layers
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.d_model
|
||||
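A minimal usage sketch of the aliasing above (the import path is assumed, not shown in this diff): the HF-style properties simply forward to the LLaDA-native field names.

from dllm.pipelines.llada.models.configuration_llada import LLaDAConfig  # assumed path

cfg = LLaDAConfig(n_heads=16, n_layers=24, d_model=2048)
assert cfg.num_attention_heads == cfg.n_heads == 16
assert cfg.num_hidden_layers == cfg.n_layers == 24
assert cfg.hidden_size == cfg.d_model == 2048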
96  dllm/dllm/pipelines/llada/models/configuration_lladamoe.py  Normal file
@ -0,0 +1,96 @@
|
||||
"""
|
||||
LLaDA MoE configuration
|
||||
"""
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class LLaDAMoEConfig(PretrainedConfig):
|
||||
model_type = "lladamoe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=-1,
|
||||
hidden_size=-1,
|
||||
dense_intermediate_size=-1,
|
||||
expert_intermediate_size=-1,
|
||||
shared_expert_intermediate_size=-1,
|
||||
num_hidden_layers=-1,
|
||||
num_attention_heads=-1,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=False,
|
||||
pad_token_id=1,
|
||||
bos_token_id=None,
|
||||
eos_token_id=50279,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=-1,
|
||||
partial_rotary_factor=-1,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
clip_qkv=None,
|
||||
num_experts_per_tok=-1,
|
||||
num_experts=-1,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.01,
|
||||
norm_topk_prob=None,
|
||||
qk_layernorm=None,
|
||||
moe_layer_freq=[],
|
||||
moe_router_enable_expert_bias=None,
|
||||
moe_router_score_function=None,
|
||||
routed_scaling_factor=1,
|
||||
router_num_group=-2,
|
||||
router_topk_group=-2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.expert_intermediate_size = expert_intermediate_size
|
||||
self.dense_intermediate_size = dense_intermediate_size
|
||||
self.shared_expert_intermediate_size = shared_expert_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.clip_qkv = clip_qkv
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.num_experts = num_experts
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.qk_layernorm = qk_layernorm
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
|
||||
self.moe_router_score_function = moe_router_score_function
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.router_num_group = router_num_group
|
||||
self.router_topk_group = router_topk_group
|
||||
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
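Since most defaults above are -1 placeholders that a checkpoint's config.json normally fills in, constructing the config by hand needs explicit sizes. A small sketch, with purely illustrative values and an assumed import path:

from dllm.pipelines.llada.models.configuration_lladamoe import LLaDAMoEConfig

cfg = LLaDAMoEConfig(
    vocab_size=50304,
    hidden_size=2048,
    dense_intermediate_size=8192,
    expert_intermediate_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_experts=64,
    num_experts_per_tok=8,
)
assert cfg.num_key_value_heads == cfg.num_attention_heads  # defaulted to the attention head count when not given
assert cfg.use_cache is False and cfg.tie_word_embeddings is False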
1458  dllm/dllm/pipelines/llada/models/modeling_llada.py  Normal file
File diff suppressed because it is too large
1168  dllm/dllm/pipelines/llada/models/modeling_lladamoe.py  Normal file
File diff suppressed because it is too large
3  dllm/dllm/pipelines/llada/trainer.py  Normal file
@ -0,0 +1,3 @@
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
LLaDATrainer = MDLMTrainer
|
||||
7  dllm/dllm/pipelines/rnd/__init__.py  Normal file
@ -0,0 +1,7 @@
|
||||
# from dllm.pipelines.rnd import generate, trainer
|
||||
from . import models
|
||||
from .models import RND1LM, RND1Config, RND1GenerationConfig
|
||||
|
||||
# from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
|
||||
# from dllm.pipelines.rnd.models.configuration_rnd import RND1Config
|
||||
from .trainer import RNDTrainer
|
||||
53  dllm/dllm/pipelines/rnd/models/__init__.py  Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
|
||||
"""
|
||||
|
||||
from .configuration_rnd import RND1Config
|
||||
from .modeling_rnd import (
|
||||
RND1LM,
|
||||
RND1Model,
|
||||
RND1PreTrainedModel,
|
||||
RND1Attention,
|
||||
RND1DecoderLayer,
|
||||
RND1SparseMoeBlock,
|
||||
)
|
||||
from .generation_config import RND1GenerationConfig
|
||||
from .generation_utils import RND1GenerationMixin
|
||||
from .sampling import (
|
||||
diffusion_sample,
|
||||
apply_top_k_filtering,
|
||||
apply_top_p_filtering,
|
||||
)
|
||||
from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
"RND1Config",
|
||||
"RND1GenerationConfig",
|
||||
"RND1LM",
|
||||
"RND1Model",
|
||||
"RND1PreTrainedModel",
|
||||
"RND1Attention",
|
||||
"RND1DecoderLayer",
|
||||
"RND1SparseMoeBlock",
|
||||
"RND1GenerationMixin",
|
||||
"TerminalVisualizer",
|
||||
"SimpleProgressBar",
|
||||
]
|
||||
|
||||
# Register with HuggingFace Auto classes for local usage
|
||||
try:
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
|
||||
|
||||
AutoConfig.register("rnd1", RND1Config)
|
||||
AutoModel.register(RND1Config, RND1LM)
|
||||
AutoModelForMaskedLM.register(RND1Config, RND1LM)
|
||||
except ImportError:
|
||||
# transformers not available or Auto classes not imported
|
||||
pass
|
||||
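Sketch of what the registration above enables (placeholder paths; requires transformers with the Auto classes available):

from transformers import AutoConfig, AutoModelForMaskedLM
from dllm.pipelines.rnd.models import RND1Config

config = AutoConfig.for_model("rnd1")  # resolves to RND1Config via the registry above
assert isinstance(config, RND1Config)

# A local RND1 checkpoint then loads through the Auto API as well (path is a placeholder):
# model = AutoModelForMaskedLM.from_pretrained("/path/to/rnd1-checkpoint")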
124  dllm/dllm/pipelines/rnd/models/configuration_rnd.py  Normal file
@ -0,0 +1,124 @@
|
||||
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Model Configuration.
|
||||
|
||||
This module defines the configuration class for RND1 models.
|
||||
The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
|
||||
with RND1-specific parameters.
|
||||
"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
# Qwen3-30B-A3B / checkpoint defaults
|
||||
CONFIG_DEFAULTS = {
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"decoder_sparse_step": 1,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 48,
|
||||
"mlp_only_layers": [],
|
||||
"moe_intermediate_size": 768,
|
||||
"norm_topk_prob": True,
|
||||
"num_attention_heads": 32,
|
||||
"num_experts": 128,
|
||||
"num_experts_per_tok": 8,
|
||||
"num_hidden_layers": 48,
|
||||
"num_key_value_heads": 4,
|
||||
"output_router_logits": False,
|
||||
"pad_token_id": 151643,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": False,
|
||||
"rope_theta": 1000000.0,
|
||||
"router_aux_loss_coef": 0.001,
|
||||
"sliding_window": False,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"use_cache": False,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936,
|
||||
}
|
||||
|
||||
|
||||
class RND1Config(PretrainedConfig):
|
||||
"""
|
||||
Configuration class for RND1 models.
|
||||
|
||||
This configuration extends Qwen3MoeConfig with additional parameters
|
||||
specific to the RND1 (Radical Numerics Diffusion v1) architecture.
|
||||
|
||||
Args:
|
||||
moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer")
|
||||
num_diffusion_steps: Default number of diffusion steps for generation
|
||||
mask_token_id: Token ID used for masking (default: 151669 for Qwen)
|
||||
**kwargs: Additional arguments passed to Qwen3MoeConfig
|
||||
"""
|
||||
|
||||
model_type = "rnd1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_backend: str = "hf",
|
||||
num_diffusion_steps: int = 256,
|
||||
mask_token_id: int = 151669,
|
||||
**kwargs,
|
||||
):
|
||||
# Force non-causal and no caching for RND1
|
||||
kwargs["use_cache"] = False
|
||||
kwargs["is_causal"] = False
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set defaults after pretrained init to prevent overrides
|
||||
self.set_config_defaults()
|
||||
|
||||
# QoL: set attn impl directly from config
|
||||
if "attn_implementation" in kwargs:
|
||||
self._attn_implementation = kwargs["attn_implementation"]
|
||||
|
||||
# RND1-specific parameters
|
||||
self.moe_backend = moe_backend
|
||||
self.num_diffusion_steps = num_diffusion_steps
|
||||
self.mask_token_id = mask_token_id
|
||||
|
||||
# Ensure bidirectional attention and no caching
|
||||
self.is_causal = False
|
||||
self.use_cache = False
|
||||
|
||||
def set_config_defaults(self):
|
||||
"""
|
||||
Ensure model defaults are set according to final training checkpoint
|
||||
|
||||
Qwen3MoeConfig defaults don't match Qwen/Qwen3-30B-A3B settings from which
|
||||
RND1 is derived.
|
||||
"""
|
||||
for k, v in CONFIG_DEFAULTS.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes configuration to dictionary with auto_map for Hub.
|
||||
|
||||
The auto_map ensures that when users load from HuggingFace Hub,
|
||||
the correct custom classes are automatically resolved.
|
||||
"""
|
||||
data = super().to_dict()
|
||||
data.setdefault(
|
||||
"auto_map",
|
||||
{
|
||||
"AutoConfig": "configuration_rnd.RND1Config",
|
||||
"AutoModel": "modeling_rnd.RND1Model",
|
||||
"AutoModelForMaskedLM": "modeling_rnd.RND1LM",
|
||||
},
|
||||
)
|
||||
return data
|
||||
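A minimal sketch of how the constructor above behaves (import path assumed):

from dllm.pipelines.rnd.models.configuration_rnd import RND1Config

cfg = RND1Config(moe_backend="hf", num_diffusion_steps=128, use_cache=True, is_causal=True)

# Caching and causal masking are forced off regardless of what was passed in.
assert cfg.use_cache is False and cfg.is_causal is False
# Checkpoint-derived defaults are re-applied after PretrainedConfig init.
assert cfg.num_hidden_layers == 48 and cfg.num_experts == 128
# to_dict() injects the auto_map used when loading from the Hub.
assert "AutoModelForMaskedLM" in cfg.to_dict()["auto_map"]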
77  dllm/dllm/pipelines/rnd/models/generation_config.py  Normal file
@ -0,0 +1,77 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Generation Configuration.
|
||||
|
||||
This module defines the generation configuration for RND1 models,
|
||||
controlling the diffusion-based generation process.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
|
||||
|
||||
class RND1GenerationConfig(GenerationConfig):
|
||||
"""
|
||||
Configuration class for RND1 generation parameters.
|
||||
|
||||
This class extends the base GenerationConfig to include parameters
|
||||
specific to diffusion-based language generation.
|
||||
|
||||
Args:
|
||||
max_length: Maximum sequence length
|
||||
num_diffusion_steps: Number of denoising steps in the diffusion process
|
||||
mask_token_id: Token ID used for masking during diffusion
|
||||
temperature: Temperature for sampling (higher = more random)
|
||||
top_k: Optional top-k filtering
|
||||
top_p: Optional nucleus (top-p) filtering
|
||||
greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
|
||||
**kwargs: Additional arguments passed to GenerationConfig
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_length: int = 256,
|
||||
num_diffusion_steps: int = 256,
|
||||
mask_token_id: int = 151669,
|
||||
temperature: float = 0.1,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
greedy: bool = False,
|
||||
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Force no caching for RND generation
|
||||
# kwargs['use_cache'] = False
|
||||
kwargs.pop('use_cache', None)
|
||||
super().__init__(
|
||||
max_length=max_length,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=not greedy,
|
||||
use_cache=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# RND-specific parameters
|
||||
self.num_diffusion_steps = num_diffusion_steps
|
||||
self.mask_token_id = mask_token_id
|
||||
self.greedy = greedy
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert configuration to dictionary."""
|
||||
output = super().to_dict()
|
||||
output["num_diffusion_steps"] = self.num_diffusion_steps
|
||||
output["mask_token_id"] = self.mask_token_id
|
||||
output["greedy"] = self.greedy
|
||||
return output
|
||||
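A short construction sketch (import path assumed); greedy maps onto do_sample, and caching stays disabled:

from dllm.pipelines.rnd.models.generation_config import RND1GenerationConfig

gen_cfg = RND1GenerationConfig(
    max_length=512,
    num_diffusion_steps=128,
    temperature=0.7,
    top_p=0.9,
    greedy=False,  # stochastic sampling -> do_sample=True
)
assert gen_cfg.do_sample is True and gen_cfg.use_cache is False
d = gen_cfg.to_dict()
assert d["num_diffusion_steps"] == 128 and d["greedy"] is False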
187  dllm/dllm/pipelines/rnd/models/generation_utils.py  Normal file
@ -0,0 +1,187 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 Generation Utilities.
|
||||
|
||||
This module provides generation utilities and mixins for RND1 models,
|
||||
including the main GenerationMixin class that integrates with HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from transformers import GenerationMixin as HFGenerationMixin
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
|
||||
|
||||
|
||||
class RND1GenerationMixin(HFGenerationMixin):
|
||||
"""
|
||||
Generation mixin for RND1 models.
|
||||
|
||||
This mixin provides generation methods compatible with HuggingFace's
|
||||
generation API while using RND1's diffusion-based sampling internally.
|
||||
"""
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.LongTensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
# RND1-specific parameters
|
||||
prefix_ids: Optional[torch.LongTensor] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
**kwargs, # Accept all kwargs to be compatible with pipelines
|
||||
) -> Union[torch.LongTensor, Dict[str, Any]]:
|
||||
"""
|
||||
Generate text using RND1's diffusion-based sampling.
|
||||
|
||||
Follows HuggingFace's standard generate API, using diffusion sampling
|
||||
internally. Supports both standard generation and infilling.
|
||||
|
||||
Args:
|
||||
inputs: Input token IDs to use as prefix (standard HF parameter)
|
||||
generation_config: Generation configuration object
|
||||
prefix_ids: Alternative to inputs for infilling tasks
|
||||
suffix_ids: Optional suffix for infilling tasks
|
||||
infill_length: Length of infill region (for infilling)
|
||||
return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
|
||||
**kwargs: Additional arguments (accepted for compatibility)
|
||||
|
||||
Returns:
|
||||
Generated token IDs or GenerateDecoderOnlyOutput
|
||||
"""
|
||||
if generation_config is not None:
|
||||
gen_config = generation_config
|
||||
model_kwargs = kwargs.copy()
|
||||
else:
|
||||
# Only prepare config from kwargs if no config was provided
|
||||
gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
if inputs is not None:
|
||||
prefix_ids = inputs.to(device)
|
||||
elif prefix_ids is not None:
|
||||
prefix_ids = prefix_ids.to(device)
|
||||
else:
|
||||
prefix_ids = None
|
||||
|
||||
if suffix_ids is not None:
|
||||
suffix_ids = suffix_ids.to(device)
|
||||
|
||||
eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
|
||||
eos_token_id = None if eos_token_id == -1 else eos_token_id
|
||||
pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
|
||||
bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
|
||||
mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
|
||||
|
||||
if infill_length is not None and prefix_ids is not None:
|
||||
# Infilling mode: use specified infill_length
|
||||
prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
|
||||
suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
|
||||
seq_len = prefix_len + infill_length + suffix_len
|
||||
else:
|
||||
# Standard generation mode
|
||||
if prefix_ids is not None:
|
||||
prefix_len = prefix_ids.shape[1]
|
||||
if gen_config.max_new_tokens is not None:
|
||||
seq_len = prefix_len + gen_config.max_new_tokens
|
||||
else:
|
||||
seq_len = gen_config.max_length or self.config.max_position_embeddings
|
||||
else:
|
||||
seq_len = gen_config.max_length or self.config.max_position_embeddings
|
||||
|
||||
num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
|
||||
getattr(self.config, "num_diffusion_steps", 256))
|
||||
|
||||
temperature = float(getattr(gen_config, "temperature", 1.0))
|
||||
top_k = getattr(gen_config, "top_k", None)
|
||||
top_p = getattr(gen_config, "top_p", None)
|
||||
|
||||
greedy = getattr(gen_config, "greedy",
|
||||
not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
sequences = diffusion_sample(
|
||||
model=self,
|
||||
seq_len=seq_len,
|
||||
num_steps=num_diffusion_steps,
|
||||
mask_token_id=mask_token_id,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
greedy=greedy,
|
||||
prefix_ids=prefix_ids,
|
||||
suffix_ids=suffix_ids,
|
||||
infill_length=infill_length,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
device=device,
|
||||
visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
|
||||
)
|
||||
|
||||
if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
|
||||
from transformers.generation.utils import GenerateDecoderOnlyOutput
|
||||
return GenerateDecoderOnlyOutput(sequences=sequences)
|
||||
|
||||
return sequences
|
||||
|
||||
def generate_with_visualization(
|
||||
self,
|
||||
tokenizer,
|
||||
inputs: Optional[torch.LongTensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generate with live visualization (for demos).
|
||||
|
||||
This method requires a tokenizer to display the generation process.
|
||||
For production use, prefer `generate()`.
|
||||
|
||||
Args:
|
||||
tokenizer: Tokenizer for decoding tokens to text
|
||||
inputs: Input token IDs to use as prefix
|
||||
generation_config: Generation configuration object
|
||||
suffix_ids: Optional suffix token IDs
|
||||
infill_length: Length of infill region
|
||||
**kwargs: Additional arguments for backward compatibility
|
||||
|
||||
Returns:
|
||||
Generated token IDs as LongTensor
|
||||
"""
|
||||
from .terminal_visualizer import TerminalVisualizer
|
||||
visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
|
||||
|
||||
return self.generate(
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
suffix_ids=suffix_ids,
|
||||
infill_length=infill_length,
|
||||
visualizer=visualizer,
|
||||
return_dict_in_generate=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare inputs for generation (required by HuggingFace).
|
||||
|
||||
For RND1, we don't use the standard autoregressive generation,
|
||||
so this just returns the input_ids.
|
||||
"""
|
||||
return {"input_ids": input_ids}
|
||||
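An end-to-end usage sketch of the mixin (checkpoint path and device are placeholders; RND1LM is defined in modeling_rnd.py below):

import torch
from transformers import AutoTokenizer
from dllm.pipelines.rnd.models import RND1LM, RND1GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("/path/to/rnd1-checkpoint")
model = RND1LM.from_pretrained("/path/to/rnd1-checkpoint", torch_dtype=torch.bfloat16).cuda()

prompt = tokenizer("The capital of France is", return_tensors="pt").input_ids.cuda()
gen_cfg = RND1GenerationConfig(max_new_tokens=64, num_diffusion_steps=64, greedy=True)

sequences = model.generate(inputs=prompt, generation_config=gen_cfg)
print(tokenizer.decode(sequences[0], skip_special_tokens=True))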
653  dllm/dllm/pipelines/rnd/models/modeling_rnd.py  Normal file
@ -0,0 +1,653 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 model implementation.
|
||||
|
||||
This module implements the RND1 architecture with bidirectional attention for
|
||||
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
|
||||
with multiple backend options (HF, vLLM, SGLang, FlashInfer).
|
||||
|
||||
Based on the Qwen3Moe architecture:
|
||||
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple, List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import logging
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import (
|
||||
MoeModelOutputWithPast,
|
||||
MaskedLMOutput,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from .configuration_rnd import RND1Config
|
||||
from .generation_utils import RND1GenerationMixin
|
||||
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||
Qwen3MoeRMSNorm,
|
||||
Qwen3MoeRotaryEmbedding,
|
||||
Qwen3MoeMLP,
|
||||
apply_rotary_pos_emb
|
||||
)
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts as fused_experts_vllm, fused_topk as fused_topk_vllm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm
|
||||
except Exception:
|
||||
fused_experts_vllm = None
|
||||
fused_topk_vllm = None
|
||||
|
||||
try:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
|
||||
# from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm
|
||||
from sglang.srt.layers.moe.topk import StandardTopKOutput
|
||||
except Exception:
|
||||
sglang_fused_moe = None
|
||||
StandardTopKOutput = None
|
||||
|
||||
|
||||
try:
|
||||
import flashinfer.fused_moe as fused_moe
|
||||
## TODO: below needs flashinfer>=0.4.0, but has some bug atm
|
||||
# from flashinfer.norm import rmsnorm as flashinfer_rmsnorm
|
||||
# class FlashInferRMSNorm(Qwen3MoeRMSNorm):
|
||||
# """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface"""
|
||||
# def forward(self, hidden_states):
|
||||
# return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
|
||||
|
||||
except Exception:
|
||||
fused_moe = None
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""Expand key/value heads to match query heads for grouped-query attention."""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
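A quick shape check for the helper above, under the (batch, kv_heads, seq, head_dim) layout it assumes; the sizes are illustrative only:

# 4 KV heads repeated 8x to serve 32 query heads.
kv = torch.randn(2, 4, 128, 64)
expanded = repeat_kv(kv, n_rep=8)
assert expanded.shape == (2, 32, 128, 64)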
|
||||
|
||||
class RND1Attention(nn.Module):
|
||||
"""RND1 attention layer with bidirectional attention for diffusion modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
|
||||
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
self.sliding_window = getattr(config, "sliding_window", None)
|
||||
|
||||
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
dual_cache: Optional[bool] = False,
|
||||
replace_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
|
||||
|
||||
if use_sdpa:
|
||||
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
|
||||
if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
|
||||
attention_mask = attention_mask.to(dtype=query_states.dtype)
|
||||
|
||||
assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
attn_out = attn_out.transpose(1, 2).contiguous()
|
||||
attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_out = self.o_proj(attn_out)
|
||||
return attn_out, None
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
# TODO: modify this to boolean masks
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
|
||||
attn_out = torch.matmul(attn_weights, value_states)
|
||||
attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
|
||||
attn_out = self.o_proj(attn_out)
|
||||
|
||||
return attn_out, None
|
||||
|
||||
|
||||
class RND1DecoderLayer(nn.Module):
|
||||
"""RND1 decoder layer with bidirectional attention for diffusion language modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.self_attn = RND1Attention(config, layer_idx)
|
||||
self.mlp = RND1SparseMoeBlock(config)
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
replace_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_out, attn_weights = self.self_attn(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
position_embeddings=position_embeddings,
|
||||
replace_position=replace_position,
|
||||
)
|
||||
hidden_states = residual + attn_out
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
ff_out = self.mlp(hidden_states)
|
||||
if isinstance(ff_out, tuple):
|
||||
ff_out = ff_out[0]
|
||||
hidden_states = residual + ff_out
|
||||
|
||||
return hidden_states, attn_weights
|
||||
|
||||
|
||||
class RND1SparseMoeBlock(nn.Module):
|
||||
"""RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer)."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.backend = getattr(config, "moe_backend", "hf")
|
||||
self.num_experts = config.num_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
|
||||
|
||||
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
|
||||
self.experts = nn.ModuleList(
|
||||
[Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
|
||||
)
|
||||
|
||||
# Cached weight tensors for optimized backends
|
||||
self._w1 = None
|
||||
self._w2 = None
|
||||
if self.backend == "sglang":
|
||||
if sglang_fused_moe is None or StandardTopKOutput is None:
|
||||
raise RuntimeError("sglang is not available, cannot use sglang backend")
|
||||
elif self.backend == "flashinfer":
|
||||
if fused_moe is None:
|
||||
raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
|
||||
elif self.backend == "vllm":
|
||||
if fused_experts_vllm is None or fused_topk_vllm is None:
|
||||
raise RuntimeError("vllm is not available, cannot use vllm backend")
|
||||
|
||||
@torch.no_grad()
|
||||
def _initialize_weights(
|
||||
self,
|
||||
free_experts: bool = True,
|
||||
mode: str = "vllm",
|
||||
) -> None:
|
||||
logger.info(f"Initializing weights for {mode} backend")
|
||||
# Stack directly on device where weights already reside (loaded by HF)
|
||||
gate_list: List[torch.Tensor] = []
|
||||
up_list: List[torch.Tensor] = []
|
||||
down_list: List[torch.Tensor] = []
|
||||
|
||||
# Collect weight references without any device moves
|
||||
for expert in self.experts:
|
||||
gate_list.append(expert.gate_proj.weight.data)
|
||||
up_list.append(expert.up_proj.weight.data)
|
||||
down_list.append(expert.down_proj.weight.data)
|
||||
|
||||
gate_w_stacked = torch.stack(gate_list, dim=0).contiguous()
|
||||
up_w_stacked = torch.stack(up_list, dim=0).contiguous()
|
||||
down_w_stacked = torch.stack(down_list, dim=0).contiguous()
|
||||
|
||||
if mode == "flashinfer":
|
||||
w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering
|
||||
else:
|
||||
w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1)
|
||||
w2 = down_w_stacked
|
||||
self._w1 = w1
|
||||
self._w2 = w2
|
||||
|
||||
|
||||
if free_experts:
|
||||
# Free per-expert modules to reclaim memory
|
||||
logger.info(f"Freeing experts for {mode} backend")
|
||||
del self.experts
|
||||
self.experts = None
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass with expert routing and computation."""
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
x = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# Expert routing
|
||||
router_logits = self.gate(x)
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.backend == "vllm":
|
||||
routing_weights, selected_experts, *_ = fused_topk_vllm(
|
||||
hidden_states=x,
|
||||
gating_output=router_logits,
|
||||
topk=self.top_k,
|
||||
renormalize=self.norm_topk_prob,
|
||||
)
|
||||
else:
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
# if self.backend == "hf":
|
||||
# final_hidden_states = torch.zeros(
|
||||
# (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
||||
# )
|
||||
|
||||
# expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
# expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
# for expert_idx in expert_hit:
|
||||
# expert_layer = self.experts[expert_idx]
|
||||
# idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
||||
# current_state = x[top_x]
|
||||
# current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
# final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
# out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
# return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
if self.backend == "hf":
|
||||
# Accumulate buffer: [B*S, H]
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# expert_mask: [E, top_k, tokens]
|
||||
expert_mask = torch.nn.functional.one_hot(
|
||||
selected_experts, num_classes=self.num_experts
|
||||
).permute(2, 1, 0).contiguous()
|
||||
|
||||
            # Iterate over all experts in order; even when this rank has no hits for an
            # expert we still run its forward pass, to avoid ZeRO-3 control-flow divergence.
            for e in range(self.num_experts):
                expert_layer = self.experts[int(e)]

                # Gather the indices of the tokens routed to this expert
                idx, top_x = torch.where(expert_mask[e])  # idx in [0, top_k); shapes: [n_tok_e]
                current_state = x[top_x]  # [n_tok_e, H]; n_tok_e may be 0

                # Forward even on an empty batch; most Linear/MLP layers are a no-op on
                # 0-row inputs, but this keeps the ZeRO-3 parameter path aligned across ranks.
                expert_out = expert_layer(current_state)  # [n_tok_e, H]

                # Apply the routing weights
                w = routing_weights[top_x, idx]  # [n_tok_e]
                expert_out = expert_out * w.unsqueeze(-1)  # [n_tok_e, H]

                # Accumulate back into the global buffer; a legal no-op when n_tok_e == 0
                final_hidden_states.index_add_(0, top_x, expert_out.to(hidden_states.dtype))
|
||||
|
||||
out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "flashinfer":
|
||||
# if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
|
||||
# self._initialize_flashinfer_weights()
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights(mode="flashinfer")
|
||||
|
||||
result = fused_moe.cutlass_fused_moe(
|
||||
input=x,
|
||||
token_selected_experts=selected_experts.to(torch.int),
|
||||
token_final_scales=routing_weights.to(torch.float32),
|
||||
fc1_expert_weights=self._w1,
|
||||
fc2_expert_weights=self._w2,
|
||||
output_dtype=x.dtype,
|
||||
quant_scales=None,
|
||||
)
|
||||
if isinstance(result, (list, tuple)):
|
||||
out_flat = result[0]
|
||||
else:
|
||||
out_flat = result
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "sglang":
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights(mode="sglang")
|
||||
|
||||
topk_output = StandardTopKOutput(
|
||||
topk_weights=routing_weights,
|
||||
topk_ids=selected_experts,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
out_flat = sglang_fused_moe(
|
||||
hidden_states=x,
|
||||
w1=self._w1,
|
||||
w2=self._w2,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
elif self.backend == "vllm":
|
||||
if self._w1 is None or self._w2 is None:
|
||||
self._initialize_weights()
|
||||
|
||||
out_flat = fused_experts_vllm(
|
||||
x,
|
||||
self._w1,
|
||||
self._w2,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
)
|
||||
out = out_flat.view(batch_size, sequence_length, hidden_dim)
|
||||
return out, router_logits.view(batch_size, sequence_length, -1)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {self.backend}")
|
||||
|
||||
|
||||
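For reference, a shape sketch of what _initialize_weights packs for the fused backends, using the CONFIG_DEFAULTS sizes above (illustrative only):

# Per expert: gate_proj/up_proj weights are [moe_intermediate_size, hidden_size] = [768, 2048],
# down_proj weights are [hidden_size, moe_intermediate_size] = [2048, 768].
#
#   gate/up stacked over experts: [128, 768, 2048] each
#   w1 = cat([gate, up], dim=1)   -> [128, 1536, 2048]
#   w2 = down stacked over experts -> [128, 2048, 768]
#
# FlashInfer expects the concatenation order [up; gate]; vLLM and SGLang consume [gate; up].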
class RND1PreTrainedModel(PreTrainedModel):
|
||||
"""Base class for RND1 models with weight initialization and loading support."""
|
||||
config_class = RND1Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["RND1DecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize weights using normal distribution."""
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
*model_args,
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
use_safetensors: Optional[bool] = None,
|
||||
weights_only: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Load pretrained model with generation config."""
|
||||
_model = super().from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
*model_args,
|
||||
config=config,
|
||||
cache_dir=cache_dir,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
use_safetensors=use_safetensors,
|
||||
weights_only=weights_only,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
resume_download = kwargs.get("resume_download", None)
|
||||
proxies = kwargs.get("proxies", None)
|
||||
subfolder = kwargs.get("subfolder", "")
|
||||
from_auto_class = kwargs.get("_from_auto", False)
|
||||
from_pipeline = kwargs.get("_from_pipeline", None)
|
||||
|
||||
_model.generation_config = GenerationConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
)
|
||||
|
||||
# If configured to use a fused backend, pack fused tensors once after load
|
||||
try:
|
||||
cfg = getattr(_model, "config", None)
|
||||
backend = getattr(cfg, "moe_backend", "hf") if cfg is not None else "hf"
|
||||
if backend in ("sglang", "vllm"):
|
||||
# Walk decoder layers and initialize fused weights
|
||||
model_core = getattr(_model, "model", _model)
|
||||
layers = getattr(model_core, "layers", None)
|
||||
if isinstance(layers, nn.ModuleList):
|
||||
for layer in layers:
|
||||
mlp = getattr(layer, "mlp", None)
|
||||
if hasattr(mlp, "_initialize_weights"):
|
||||
mlp._initialize_weights(
|
||||
free_experts=True,
|
||||
mode=backend,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"Backend {backend} weight processing skipped: {_e}")
|
||||
|
||||
return _model
|
||||
|
||||
|
||||
class RND1Model(RND1PreTrainedModel):
|
||||
"""RND1 transformer model with bidirectional attention for diffusion language modeling."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__(config)
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
|
||||
if config.moe_backend == "vllm":
|
||||
RMSNormClass = VLLMRMSNorm
|
||||
else:
|
||||
RMSNormClass = Qwen3MoeRMSNorm
|
||||
self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> MoeModelOutputWithPast:
|
||||
"""Forward pass through the RND1 model."""
|
||||
|
||||
if (input_ids is None) == (inputs_embeds is None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
|
||||
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
# shape: (batch_size, 1, 1, seq_len)
|
||||
attention_mask = attention_mask.to(dtype=torch.float)[:, None, None, :]
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
|
||||
|
||||
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states, _ = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
router_logits=None,
|
||||
)
|
||||
|
||||
|
||||
class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
|
||||
"""Radical Numerics Diffusion Language Model with bidirectional attention."""
|
||||
|
||||
def __init__(self, config: RND1Config):
|
||||
super().__init__(config)
|
||||
self.model = RND1Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
"""Get the input embeddings layer."""
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
"""Set the input embeddings layer."""
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
"""Get the output embeddings layer (lm_head)."""
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
"""Set the output embeddings layer (lm_head)."""
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@classmethod
|
||||
def can_generate(cls) -> bool:
|
||||
"""Indicates this model can generate text."""
|
||||
return True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> MaskedLMOutput:
|
||||
"""Forward pass with optional loss computation."""
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
)
|
||||
260  dllm/dllm/pipelines/rnd/models/sampling.py  Normal file
@ -0,0 +1,260 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
RND1 sampling module for masked diffusion generation.
|
||||
|
||||
This module implements entropy-based token selection for iterative denoising
|
||||
in diffusion language models. Supports both greedy and stochastic sampling
|
||||
with optional prefix/suffix constraints and infilling.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k filtering to logits, setting all non-top-k values to -inf.
|
||||
"""
|
||||
top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
|
||||
filtered_logits = torch.full_like(logits, float('-inf'))
|
||||
filtered_logits.scatter_(-1, top_k_indices, top_k_values)
|
||||
return filtered_logits
|
||||
|
||||
|
||||
def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-p (nucleus) filtering to logits, setting tokens beyond the cumulative-probability threshold to -inf.
|
||||
"""
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above threshold
|
||||
sorted_indices_to_remove = cumulative_probs > p
|
||||
sorted_indices_to_remove[..., 0] = False # Keep at least one token
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
||||
return logits.masked_fill(indices_to_remove, float('-inf'))
|
||||
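Sketch of chaining the two filters the way diffusion_sample does below (top-k first, then top-p); the vocabulary size is illustrative:

logits = torch.randn(1, 8, 151936)      # (batch, seq_len, vocab)
logits = apply_top_k_filtering(logits, k=50)
logits = apply_top_p_filtering(logits, p=0.9)
probs = torch.softmax(logits, dim=-1)   # filtered entries get zero probability
assert torch.isfinite(probs).all()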
|
||||
|
||||
@torch.no_grad()
|
||||
def diffusion_sample(
|
||||
model: nn.Module,
|
||||
seq_len: int = 256,
|
||||
num_steps: int = 256,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: float = 1.0,
|
||||
greedy: bool = True,
|
||||
mask_token_id: int = 151669,
|
||||
prefix_ids: Optional[torch.LongTensor] = None,
|
||||
suffix_ids: Optional[torch.LongTensor] = None,
|
||||
infill_length: Optional[int] = None,
|
||||
eos_token_id: int = 151645,
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
visualizer: Optional[object] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Perform masked diffusion sampling with entropy-based token selection.
|
||||
|
||||
Args:
|
||||
model: The RND1 language model
|
||||
seq_len: Target sequence length
|
||||
num_steps: Number of denoising steps
|
||||
top_k: Optional top-k filtering for sampling (None = no filtering)
|
||||
top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
|
||||
When both top_k and top_p are set, top_k is applied first, then top_p
|
||||
temperature: Temperature for sampling (higher = more random, lower = more deterministic)
|
||||
Values close to 0 are clamped to 1e-8 to avoid division by zero
|
||||
greedy: Whether to use greedy sampling (True) or stochastic (False)
|
||||
mask_token_id: Token ID for masked positions (default: 151669)
|
||||
prefix_ids: Optional prefix token IDs to preserve
|
||||
suffix_ids: Optional suffix token IDs to preserve
|
||||
infill_length: Length of infill region between prefix/suffix
|
||||
eos_token_id: End of sequence token ID (default: 151645)
|
||||
pad_token_id: Padding token ID (default: None, uses 0 if needed)
|
||||
bos_token_id: Beginning of sequence token ID (default: None)
|
||||
device: Device for computation (None = infer from model)
|
||||
visualizer: Optional visualizer for live visualization
|
||||
|
||||
Returns:
|
||||
Generated token IDs as LongTensor
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
if device is None:
|
||||
device = next(model.parameters()).device
|
||||
else:
|
||||
device = torch.device(device)
|
||||
|
||||
if pad_token_id is None:
|
||||
pad_token_id = 0
|
||||
|
||||
# Build initial masked sequence
|
||||
# When prefix_ids is provided, we create a sequence of length seq_len where:
|
||||
# - The prefix occupies the first pre_len positions
|
||||
# - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
|
||||
if prefix_ids is not None or suffix_ids is not None:
|
||||
if prefix_ids is not None:
|
||||
prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
|
||||
pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
|
||||
else:
|
||||
pre_len = 0
|
||||
|
||||
if suffix_ids is not None:
|
||||
suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
|
||||
suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
|
||||
else:
|
||||
suf_len = 0
|
||||
|
||||
reserved = (1 if eos_token_id is not None else 0)
|
||||
used = pre_len + suf_len + reserved
|
||||
|
||||
if used > seq_len:
|
||||
raise ValueError(
|
||||
f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
|
||||
f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
|
||||
f"Please increase seq_len or reduce input lengths."
|
||||
)
|
||||
elif used == seq_len:
|
||||
raise ValueError(
|
||||
f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
|
||||
f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
|
||||
f"Need at least 1 position for generation."
|
||||
)
|
||||
|
||||
infill_length = min(infill_length or (seq_len - used), seq_len - used)
|
||||
|
||||
x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
|
||||
pos = 0
|
||||
# if bos_token_id is not None:
|
||||
# x[0, pos] = bos_token_id; pos += 1
|
||||
if eos_token_id is not None:
|
||||
x[0, -1] = eos_token_id
|
||||
if pre_len > 0:
|
||||
x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]
|
||||
pos += pre_len
|
||||
fill_start, fill_end = pos, pos + infill_length
|
||||
x[0, fill_start:fill_end] = mask_token_id
|
||||
# print(fill_start, fill_end, seq_len, used, x[0, -1])
|
||||
pos = fill_end
|
||||
if suf_len > 0:
|
||||
x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]
|
||||
pos += suf_len
|
||||
|
||||
init_maskable = torch.zeros_like(x, dtype=torch.bool)
|
||||
init_maskable[0, fill_start:fill_end] = True
|
||||
else:
|
||||
x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
|
||||
if bos_token_id is not None:
|
||||
x[0, 0] = bos_token_id
|
||||
if eos_token_id is not None:
|
||||
x[0, -1] = eos_token_id
|
||||
init_maskable = x.eq(mask_token_id)
|
||||
|
||||
if bos_token_id is not None:
|
||||
init_maskable[:, 0] = False
|
||||
if eos_token_id is not None:
|
||||
init_maskable &= x.ne(eos_token_id)
|
||||
init_maskable &= x.ne(pad_token_id)
|
||||
|
||||
maskable = init_maskable.clone()
|
||||
xt = x.clone()
|
||||
|
||||
if visualizer:
|
||||
visualizer.start_visualization(xt, maskable, num_steps)
|
||||
|
||||
def forward_scores(tokens):
|
||||
"""Compute predictions and entropy scores for next tokens."""
|
||||
# Try with input_ids parameter first (standard HF models)
|
||||
try:
|
||||
model_output = model(input_ids=tokens)
|
||||
except TypeError:
|
||||
# Fall back to positional argument
|
||||
model_output = model(tokens)
|
||||
|
||||
# Apply temperature scaling (with safety for near-zero temperature)
|
||||
safe_temperature = max(temperature, 1e-8) # Prevent division by zero
|
||||
logits = model_output.logits / safe_temperature
|
||||
|
||||
# Apply filtering strategies
|
||||
# Note: When both top_k and top_p are provided, they are applied sequentially:
|
||||
# First top_k filters to k tokens, then top_p filters from those k tokens
|
||||
if top_k is not None and top_k > 0:
|
||||
logits = apply_top_k_filtering(logits, top_k)
|
||||
|
||||
if top_p is not None and 0 < top_p < 1.0:
|
||||
logits = apply_top_p_filtering(logits, top_p)
|
||||
|
||||
# Convert to log probabilities
|
||||
logp = torch.log_softmax(logits, dim=-1)
|
||||
|
||||
# Greedy or stochastic sampling
|
||||
if greedy:
|
||||
pred_next = logp.argmax(-1)
|
||||
else:
|
||||
pred_next = torch.distributions.Categorical(logits=logp).sample()
|
||||
|
||||
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
|
||||
|
||||
p = logp.exp()
|
||||
ent_next = -(p * logp).sum(-1)
|
||||
|
||||
# Shift predictions: pos i predicts token i+1
|
||||
pred_i = tokens.clone()
|
||||
conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
|
||||
ent_i = torch.zeros_like(ent_next)
|
||||
|
||||
pred_i[:, 1:] = pred_next[:, :-1]
|
||||
conf_i[:, 1:] = conf_next[:, :-1]
|
||||
ent_i[:, 1:] = ent_next[:, :-1]
|
||||
|
||||
return pred_i, conf_i, ent_i
|
||||
|
||||
pred_i, conf_i, ent_i = forward_scores(xt)
|
||||
total_masked = init_maskable.sum(1, keepdim=True)
|
||||
finf = torch.finfo(conf_i.dtype)
|
||||
|
||||
for step in range(num_steps - 1, 0, -1):
|
||||
rate = step / num_steps
|
||||
cutoff_len = (total_masked * rate).long().clamp(min=0)
|
||||
|
||||
# Choose HIGH-entropy tokens to keep masked
|
||||
sel_scores = ent_i.masked_fill(~maskable, -finf.max)
|
||||
B, L = sel_scores.shape
|
||||
k_max = cutoff_len.max().item()
|
||||
if k_max > 0:
|
||||
sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
|
||||
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
|
||||
for b in range(B):
|
||||
k_b = int(cutoff_len[b].item())
|
||||
if k_b > 0:
|
||||
keep_mask[b, idx[b, :k_b]] = True
|
||||
else:
|
||||
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
|
||||
|
||||
to_unmask = maskable & ~keep_mask
|
||||
if to_unmask.any():
|
||||
xt[to_unmask] = pred_i[to_unmask]
|
||||
maskable[to_unmask] = False
|
||||
|
||||
if visualizer:
|
||||
visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
|
||||
|
||||
if maskable.any():
|
||||
pred_i, conf_i, ent_i = forward_scores(xt)
|
||||
|
||||
if maskable.any():
|
||||
xt[maskable] = pred_i[maskable]
|
||||
|
||||
if visualizer:
|
||||
visualizer.stop_visualization()
|
||||
|
||||
return xt
|
||||
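The per-row Python loop above builds keep_mask by keeping, for each row b, the cutoff_len[b] highest-entropy positions masked and revealing the rest. A minimal vectorized sketch of the same construction (the helper name is hypothetical; sel_scores is the (B, L) entropy tensor and cutoff_len the (B, 1) long tensor from the loop above):

import torch

def build_keep_mask(sel_scores: torch.Tensor, cutoff_len: torch.Tensor) -> torch.Tensor:
    # keep_mask[b, i] is True iff position i is among the cutoff_len[b] highest scores in row b
    B, L = sel_scores.shape
    keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
    k_max = int(cutoff_len.max().item())
    if k_max == 0:
        return keep_mask
    _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)      # (B, k_max)
    ranks = torch.arange(k_max, device=sel_scores.device).expand(B, k_max)
    keep_mask.scatter_(1, idx, ranks < cutoff_len)                    # keep only the first cutoff_len[b] ranks
    return keep_mask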
251 dllm/dllm/pipelines/rnd/models/terminal_visualizer.py Normal file
@ -0,0 +1,251 @@
|
||||
# Copyright 2025 Radical Numerics Inc.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0, found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Terminal visualization for RND1 generation.
|
||||
|
||||
This module provides real-time visualization of the diffusion denoising process,
|
||||
showing token evolution and generation progress in the terminal using rich
|
||||
formatting when available.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||
from rich.layout import Layout
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
class TerminalVisualizer:
|
||||
"""
|
||||
Rich-based visualization for diffusion process with live updates.
|
||||
|
||||
Provides real-time visualization of the token denoising process during
|
||||
diffusion-based language generation, with colored highlighting of masked
|
||||
positions and progress tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, show_visualization: bool = True):
|
||||
"""
|
||||
Initialize the terminal visualizer.
|
||||
|
||||
Args:
|
||||
tokenizer: The tokenizer for decoding tokens to text
|
||||
show_visualization: Whether to show visualization (requires rich)
|
||||
"""
|
||||
self.tokenizer = tokenizer
|
||||
self.show_visualization = show_visualization and RICH_AVAILABLE
|
||||
if not RICH_AVAILABLE and show_visualization:
|
||||
print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
|
||||
self.show_visualization = False
|
||||
|
||||
if self.show_visualization:
|
||||
self.console = Console()
|
||||
self.live = None
|
||||
self.progress = None
|
||||
self.layout = None
|
||||
else:
|
||||
self.pbar = None
|
||||
|
||||
self.current_tokens = None
|
||||
self.mask_positions = None
|
||||
self.total_steps = 0
|
||||
self.current_step = 0
|
||||
|
||||
def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
|
||||
"""
|
||||
Start the visualization.
|
||||
|
||||
Args:
|
||||
initial_tokens: Initial token IDs (possibly masked)
|
||||
mask_positions: Boolean mask indicating which positions are masked
|
||||
total_steps: Total number of diffusion steps
|
||||
"""
|
||||
if not self.show_visualization:
|
||||
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
||||
return
|
||||
|
||||
self.current_tokens = initial_tokens.clone()
|
||||
self.mask_positions = mask_positions
|
||||
self.total_steps = total_steps
|
||||
self.current_step = 0
|
||||
|
||||
self.layout = Layout()
|
||||
self.layout.split_column(
|
||||
Layout(name="header", size=3),
|
||||
Layout(name="text", ratio=1),
|
||||
Layout(name="progress", size=3)
|
||||
)
|
||||
|
||||
self.progress = Progress(
|
||||
TextColumn("[bold blue]Diffusion"),
|
||||
BarColumn(),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TextColumn("[cyan]Masks: {task.fields[masks]}"),
|
||||
TimeRemainingColumn(),
|
||||
)
|
||||
self.progress_task = self.progress.add_task(
|
||||
"Generating",
|
||||
total=total_steps,
|
||||
masks=mask_positions.sum().item()
|
||||
)
|
||||
|
||||
self.live = Live(self.layout, console=self.console, refresh_per_second=4)
|
||||
self.live.start()
|
||||
self._update_display()
|
||||
|
||||
def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
|
||||
entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
|
||||
"""
|
||||
Update visualization for current step.
|
||||
|
||||
Args:
|
||||
tokens: Current token IDs
|
||||
maskable: Boolean mask of remaining masked positions
|
||||
step: Current step number
|
||||
entropy: Optional entropy scores for each position
|
||||
confidence: Optional confidence scores for each position
|
||||
"""
|
||||
if not self.show_visualization:
|
||||
if self.pbar:
|
||||
self.pbar.update(1)
|
||||
masks = maskable.sum().item() if maskable is not None else 0
|
||||
self.pbar.set_postfix({'masks': masks})
|
||||
return
|
||||
|
||||
self.current_tokens = tokens.clone()
|
||||
self.mask_positions = maskable
|
||||
self.current_step = step
|
||||
|
||||
masks_remaining = maskable.sum().item() if maskable is not None else 0
|
||||
self.progress.update(
|
||||
self.progress_task,
|
||||
advance=1,
|
||||
masks=masks_remaining
|
||||
)
|
||||
|
||||
self._update_display()
|
||||
|
||||
def _update_display(self):
|
||||
"""Update the live display."""
|
||||
if not self.live:
|
||||
return
|
||||
|
||||
header = Text("RND1-Base Generation", style="bold magenta", justify="center")
|
||||
self.layout["header"].update(Panel(header, border_style="bright_blue"))
|
||||
|
||||
text_display = self._format_text_with_masks()
|
||||
self.layout["text"].update(
|
||||
Panel(
|
||||
text_display,
|
||||
title="[bold]Generated Text",
|
||||
subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
|
||||
border_style="cyan"
|
||||
)
|
||||
)
|
||||
|
||||
self.layout["progress"].update(Panel(self.progress))
|
||||
|
||||
def _format_text_with_masks(self) -> Text:
|
||||
"""
|
||||
Format text with colored masks.
|
||||
|
||||
Returns:
|
||||
Rich Text object with formatted tokens
|
||||
"""
|
||||
text = Text()
|
||||
|
||||
if self.current_tokens is None:
|
||||
return text
|
||||
|
||||
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
||||
mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
|
||||
# Alternate colors for visual effect
|
||||
text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
|
||||
else:
|
||||
try:
|
||||
token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
|
||||
# Skip special tokens in display
|
||||
if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
|
||||
# Color based on position
|
||||
text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return text
|
||||
|
||||
def stop_visualization(self):
|
||||
"""Stop the visualization and display final result."""
|
||||
if not self.show_visualization:
|
||||
if self.pbar:
|
||||
self.pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
return
|
||||
|
||||
if self.live:
|
||||
self.live.stop()
|
||||
|
||||
self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
|
||||
|
||||
# Display final text
|
||||
if self.current_tokens is not None:
|
||||
try:
|
||||
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
||||
final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
||||
|
||||
self.console.print(Panel(
|
||||
final_text,
|
||||
title="[bold]Final Generated Text",
|
||||
border_style="green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class SimpleProgressBar:
|
||||
"""
|
||||
Simple progress bar fallback when rich is not available.
|
||||
|
||||
Provides basic progress tracking using tqdm when the rich library
|
||||
is not installed.
|
||||
"""
|
||||
|
||||
def __init__(self, total_steps: int):
|
||||
"""
|
||||
Initialize simple progress bar.
|
||||
|
||||
Args:
|
||||
total_steps: Total number of steps
|
||||
"""
|
||||
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
||||
|
||||
def update(self, masks_remaining: int = 0):
|
||||
"""
|
||||
Update progress bar.
|
||||
|
||||
Args:
|
||||
masks_remaining: Number of masks still remaining
|
||||
"""
|
||||
self.pbar.update(1)
|
||||
self.pbar.set_postfix({'masks': masks_remaining})
|
||||
|
||||
def close(self):
|
||||
"""Close the progress bar."""
|
||||
self.pbar.close()
|
||||
print("\n✨ Generation complete!\n")
|
||||
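For reference, a hedged usage sketch of TerminalVisualizer inside a denoising loop (tokenizer, xt, maskable, ent_i and conf_i are assumed to come from the generation code earlier):

viz = TerminalVisualizer(tokenizer, show_visualization=True)
viz.start_visualization(xt, maskable, num_steps)
for step in range(num_steps):
    # ... unmask some positions and update xt / maskable ...
    viz.update_step(xt, maskable, step, entropy=ent_i, confidence=conf_i)
viz.stop_visualization()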
23 dllm/dllm/pipelines/rnd/trainer.py Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from dllm.core.trainers import MDLMTrainer
|
||||
|
||||
|
||||
class RNDTrainer(MDLMTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _preprocess_inputs(self, inputs):
|
||||
labels = inputs["labels"]
|
||||
assert (labels[:, 0] == -100).all()
|
||||
|
||||
def _postprocess_outputs(self, outputs):
|
||||
logits = outputs.logits
|
||||
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
|
||||
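The torch.cat in _postprocess_outputs shifts the logits one position to the right: position i (for i >= 1) ends up holding the scores produced at position i-1, while position 0 keeps its own. A tiny self-contained illustration:

import torch

logits = torch.arange(6, dtype=torch.float32).view(1, 3, 2)        # (B=1, T=3, V=2)
shifted = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
# logits[0]  -> [[0., 1.], [2., 3.], [4., 5.]]
# shifted[0] -> [[0., 1.], [0., 1.], [2., 3.]]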
242 dllm/dllm/tools/chat.py Normal file
@ -0,0 +1,242 @@
|
||||
import shutil
|
||||
from typing import List, Literal
|
||||
|
||||
import textwrap
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Utility helpers
|
||||
# ============================================================
|
||||
|
||||
try:
|
||||
L = shutil.get_terminal_size().columns
|
||||
if not isinstance(L, int) or L <= 0:
|
||||
L = 120
|
||||
except Exception:
|
||||
L = 120
|
||||
DIV = "=" * L
|
||||
SUB = "-" * L
|
||||
|
||||
|
||||
def banner_line(text: str, width: int = L, fill: str = "=") -> str:
|
||||
"""Return a centered banner line with given width and fill."""
|
||||
text = f" {text.strip()} "
|
||||
fill_len = width - len(text)
|
||||
if fill_len <= 0:
|
||||
return text
|
||||
left = fill_len // 2
|
||||
right = fill_len - left
|
||||
return f"{fill * left}{text}{fill * right}"
|
||||
|
||||
|
||||
def print_wrapped(text: str, width: int = L):
|
||||
"""Print text with automatic line wrapping."""
|
||||
wrapped = textwrap.fill(text, width=width)
|
||||
print(wrapped)
|
||||
|
||||
|
||||
def boxed(text: str, width: int = L, padding: int = 1):
|
||||
"""Render a centered box with the given text and width."""
|
||||
lines = text.splitlines()
|
||||
content_width = max(len(line) for line in lines)
|
||||
box_width = min(width, content_width + padding * 2 + 2)
|
||||
|
||||
# compute left margin for centering
|
||||
terminal_width = width
|
||||
left_margin = max((terminal_width - box_width) // 2, 0)
|
||||
margin = " " * left_margin
|
||||
|
||||
top = margin + "┌" + "─" * (box_width - 2) + "┐"
|
||||
bottom = margin + "└" + "─" * (box_width - 2) + "┘"
|
||||
|
||||
print(top)
|
||||
for line in lines:
|
||||
inner = line.center(content_width)
|
||||
print(margin + "│" + " " * padding + inner + " " * padding + "│")
|
||||
print(bottom)
|
||||
|
||||
|
||||
def decode_trim(tokenizer, seq_ids_list, input_ids_list) -> str:
|
||||
"""
|
||||
Return only the generated text, truncated at the first EOS **after** the prompt.
|
||||
|
||||
Args:
|
||||
tokenizer: HF tokenizer with eos_token_id / pad_token_id.
|
||||
seq_ids_list: Per-sample full sequence token ids from the model (prompt + generation).
input_ids_list: Per-sample prompt token ids that were fed into the model.
|
||||
|
||||
Behavior:
|
||||
- Finds the first eos_token_id that occurs at or after len(input_ids).
|
||||
- Slices generation up to (but not including) that EOS.
|
||||
- Decodes only the generation span, skipping special/pad tokens.
|
||||
"""
|
||||
# Make sure we can index these
|
||||
sequences = []
|
||||
for seq_ids, input_ids in zip(seq_ids_list, input_ids_list):
|
||||
full = list(seq_ids)
|
||||
prompt = list(input_ids)
|
||||
|
||||
# Skip left padding tokens (necessary for dream)
|
||||
pad_id = getattr(tokenizer, "pad_token_id", None)
|
||||
if pad_id is not None:
|
||||
while full and full[0] == pad_id:
|
||||
full.pop(0)
|
||||
|
||||
start = len(prompt)
|
||||
end = len(full)
|
||||
|
||||
eos_id = getattr(tokenizer, "eos_token_id", None)
|
||||
eot_id = getattr(tokenizer, "eot_token_id", None)
|
||||
if eos_id is not None:
|
||||
for i in range(start, len(full)):
|
||||
if full[i] in (eos_id, eot_id):
|
||||
end = i
|
||||
break
|
||||
|
||||
gen_ids = full[start:end]
|
||||
text = tokenizer.decode(gen_ids, skip_special_tokens=True)
|
||||
# in case eos_id / eot_id are unavailable, fall back to splitting on the token strings
|
||||
eos = getattr(tokenizer, "eos_token", None)
|
||||
eot = getattr(tokenizer, "eot_token", None)
|
||||
if eos:
|
||||
text = text.split(eos)[0]
|
||||
if eot:
|
||||
text = text.split(eot)[0]
|
||||
# return text.strip()
|
||||
sequences.append(text)
|
||||
return sequences
|
||||
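# Small illustration of the trimming above (token ids are made up): with
# eos_token_id == 2, prompt [11, 12] and full sequence [11, 12, 31, 32, 2, 0],
# only [31, 32] is decoded, i.e. the span between the prompt and the first EOS
# found at or after the prompt length.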
|
||||
|
||||
def render_menu(round_idx: int):
|
||||
"""Render a boxed menu of possible actions."""
|
||||
if round_idx == 0:
|
||||
text = (
|
||||
"Possible next actions:\n"
|
||||
"[1] Continue this chat\n"
|
||||
"[2] End this chat and start a new one\n"
|
||||
"[3] Exit"
|
||||
)
|
||||
else:
|
||||
text = (
|
||||
f"(Round {round_idx})\n"
|
||||
"Possible next actions:\n"
|
||||
"[1] Continue this chat\n"
|
||||
"[2] End this chat and start a new one\n"
|
||||
"[3] Exit"
|
||||
)
|
||||
|
||||
print() # spacing
|
||||
boxed(text)
|
||||
|
||||
|
||||
def prompt_choice() -> Literal["1", "2", "3"]:
|
||||
while True:
|
||||
print("Select action [1/2/3]: ")
|
||||
choice = input().strip()
|
||||
if choice in ("1", "2", "3"):
|
||||
return choice
|
||||
print(banner_line("<Invalid choice. Please type 1, 2, or 3.>", fill=" "))
|
||||
|
||||
|
||||
def build_chat_inputs(tokenizer, messages: List[dict], add_generation_prompt: bool):
|
||||
"""Tokenize chat messages into inputs tensor."""
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
)
|
||||
|
||||
|
||||
def visualize_histories(tokenizer, histories):
|
||||
try:
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
terminal_visualizer.visualize(histories, rich=True)
|
||||
except Exception as e:
|
||||
print(f"(Visualization skipped: {e})")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Modes
|
||||
# ============================================================
|
||||
def single_turn_generate(generator, gen_config, visualize: bool):
|
||||
print()
|
||||
print(banner_line("continuation mode"))
|
||||
model, tokenizer = generator.model, generator.tokenizer
|
||||
|
||||
while True:
|
||||
print(banner_line("<Type your prompt below. Press Ctrl+C to exit.>", fill=" "))
|
||||
try:
|
||||
# user_text = input("Prompt > ").strip()
|
||||
print("[Prompt] > ")
|
||||
user_text = input().strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n" + banner_line("Exiting. Bye!", width=len(DIV)))
|
||||
return
|
||||
|
||||
# if not user_text:
|
||||
# print("(Empty input, skipped)\n")
|
||||
# continue
|
||||
|
||||
inputs = tokenizer([user_text], add_special_tokens=False)["input_ids"]
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
text = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
|
||||
|
||||
print(banner_line("Output"))
|
||||
print_wrapped(text if text else "<empty>")
|
||||
print(DIV + "\n")
|
||||
|
||||
if visualize:
|
||||
visualize_histories(tokenizer, outputs.histories)
|
||||
|
||||
|
||||
def multi_turn_chat(generator, gen_config, visualize: bool):
|
||||
# """Chat mode with chat template & message history."""
|
||||
print()
|
||||
print(banner_line("multi-turn chat mode"))
|
||||
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
|
||||
model, tokenizer = generator.model, generator.tokenizer
|
||||
|
||||
messages: List[dict] = []
|
||||
round_idx = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
print("[You]:")
|
||||
user_msg = input().strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nExiting. Bye!")
|
||||
return
|
||||
|
||||
messages.append({"role": "user", "content": user_msg})
|
||||
inputs = build_chat_inputs(tokenizer, [messages], add_generation_prompt=True)
|
||||
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
reply = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
|
||||
|
||||
print(DIV)
|
||||
print_wrapped("[Assistant]: " + reply if reply else "<empty>")
|
||||
print(DIV + "\n")
|
||||
|
||||
messages.append({"role": "assistant", "content": reply})
|
||||
|
||||
if visualize:
|
||||
visualize_histories(tokenizer, outputs.histories)
|
||||
|
||||
render_menu(round_idx)
|
||||
choice = prompt_choice()
|
||||
if choice == "1":
|
||||
print(banner_line("<Type your message.>", fill=" "))
|
||||
round_idx += 1
|
||||
continue
|
||||
elif choice == "2":
|
||||
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
|
||||
messages = []
|
||||
round_idx = 0
|
||||
continue
|
||||
else:
|
||||
print("\nExiting. Bye!")
|
||||
return
|
||||
30 dllm/dllm/tools/download_hf_dataset.py Normal file
@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tyro
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
dataset_id: str = "Anthropic/hh-rlhf"
|
||||
allow_patterns: str = None
|
||||
|
||||
|
||||
script_args = tyro.cli(ScriptArguments)
|
||||
|
||||
# Replace with the dataset repo you want, e.g. "wikitext"
|
||||
dataset_id = script_args.dataset_id
|
||||
|
||||
# Replace with your desired local directory
|
||||
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/datasets/huggingface/{dataset_id}"
|
||||
|
||||
# Download the dataset snapshot
|
||||
snapshot_download(
|
||||
repo_id=dataset_id,
|
||||
repo_type="dataset", # 👈 tell HF it's a dataset
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=False, # ensures real files, not symlinks
|
||||
allow_patterns=script_args.allow_patterns,
|
||||
)
|
||||
|
||||
print(f"Dataset downloaded to: {local_dir}")
|
||||
27 dllm/dllm/tools/download_hf_model.py Normal file
@ -0,0 +1,27 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tyro
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_id: str = "GSAI-ML/LLaDA-8B-Instruct"
|
||||
|
||||
|
||||
script_args = tyro.cli(ScriptArguments)
|
||||
|
||||
# Replace with the model repo you want, e.g. "bert-base-uncased"
|
||||
model_id = script_args.model_id
|
||||
|
||||
# Replace with your desired local directory
|
||||
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/models/huggingface/{model_id}"
|
||||
|
||||
# Download the model snapshot
|
||||
snapshot_download(
|
||||
repo_id=model_id,
|
||||
local_dir=local_dir,
|
||||
local_dir_use_symlinks=False, # ensures real files, not symlinks
|
||||
)
|
||||
|
||||
print(f"Model downloaded to: {local_dir}")
|
||||
1 dllm/dllm/tools/generate.py Normal file
@ -0,0 +1 @@
|
||||
# TODO
|
||||
80 dllm/dllm/tools/merge_peft_adapter.py Normal file
@ -0,0 +1,80 @@
|
||||
"""
|
||||
Merge a PEFT/LoRA adapter into its base model (auto-detected from adapter_config.json).
|
||||
|
||||
Usage:
|
||||
python dllm/tools/merge_peft_adapter.py \
|
||||
--adapter_model_name_or_path your-org/your-lora \
|
||||
--output_model_name_or_path ./merged-model \
|
||||
--dtype bf16
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModel, AutoTokenizer, HfArgumentParser
|
||||
|
||||
import dllm  # registers the custom model classes so trust_remote_code is not needed
|
||||
|
||||
DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
adapter_model_name_or_path: str | None = field(
|
||||
default=None, metadata={"help": "Adapter repo or local path"}
|
||||
)
|
||||
output_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to save the merged model (folder or repo id)"},
|
||||
)
|
||||
dtype: str | None = field(default="fp16", metadata={"help": "fp16|bf16|fp32"})
|
||||
push_to_hub: bool | None = field(
|
||||
default=False, metadata={"help": "Push merged weights to the Hub"}
|
||||
)
|
||||
# Optional override if adapter config lacks base info:
|
||||
base_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Override base model if adapter config lacks it"},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
assert args.adapter_model_name_or_path, "please provide the adapter (repo or path)"
|
||||
assert args.output_model_name_or_path, "please provide output_model_name_or_path"
|
||||
assert args.dtype in DTYPE_MAP, f"dtype must be one of {list(DTYPE_MAP.keys())}"
|
||||
|
||||
# Read base path from adapter_config.json
|
||||
peft_cfg = PeftConfig.from_pretrained(args.adapter_model_name_or_path)
|
||||
base_id = args.base_model_name_or_path or getattr(
|
||||
peft_cfg, "base_model_name_or_path", None
|
||||
)
|
||||
assert base_id, (
|
||||
"adapter_config.json does not include base_model_name_or_path; "
|
||||
"pass --base_model_name_or_path to override."
|
||||
)
|
||||
|
||||
# Load base model and tokenizer
|
||||
model = AutoModel.from_pretrained(
|
||||
base_id, return_dict=True, dtype=DTYPE_MAP[args.dtype]
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_id)
|
||||
|
||||
# Attach adapter, merge, and unload PEFT layers
|
||||
model = PeftModel.from_pretrained(model, args.adapter_model_name_or_path)
|
||||
model.eval()
|
||||
model = model.merge_and_unload() # plain transformers model
|
||||
|
||||
# Save locally
|
||||
model.save_pretrained(args.output_model_name_or_path)
|
||||
tokenizer.save_pretrained(args.output_model_name_or_path)
|
||||
|
||||
print(f"✓ merged model saved to: {args.output_model_name_or_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109 dllm/dllm/tools/preprocess_pt_dataset.py Normal file
@ -0,0 +1,109 @@
|
||||
"""
|
||||
python dllm/tools/preprocess_pt_dataset.py
|
||||
"""
|
||||
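# A fuller invocation sketch, following the flag style of the sibling
# preprocess_sft_dataset.py docstring (flags mirror the ScriptArguments fields
# below, so treat the exact spelling as an assumption):
#
#   python dllm/tools/preprocess_pt_dataset.py \
#       --model_name_or_path "answerdotai/ModernBERT-large" \
#       --dataset_args "OpenCoder-LLM/opc-annealing-corpus[lang:python]" \
#       --output_dir "data/pt/modernbert/opc-annealing-corpus[lang:python]" \
#       --max_length 1024 --num_proc 32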
|
||||
import os
|
||||
from dataclasses import dataclass, asdict
|
||||
from functools import partial
|
||||
|
||||
import datasets
|
||||
import tyro
|
||||
import transformers
|
||||
from pprint import pprint
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""Preprocess PT dataset"""
|
||||
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
dataset_args: str = "OpenCoder-LLM/opc-annealing-corpus[lang:python]" # required
|
||||
output_dir: str = "data/pt/modernbert/opc-annealing-corpus[lang:python]" # required
|
||||
text_field: str = "text"
|
||||
max_length: int = 1024
|
||||
insert_eos: bool = True
|
||||
drop_tail: bool = True
|
||||
remove_columns: bool = False
|
||||
num_proc: int = 32
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
def preprocess_pt_dataset(
|
||||
dataset: datasets.DatasetDict,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
output_dir: str,
|
||||
text_field: str = "text",
|
||||
max_length: int = 1024,
|
||||
insert_eos: bool = True,
|
||||
drop_tail: bool = True,
|
||||
remove_columns: bool = False,
|
||||
num_proc: int = 32,
|
||||
):
|
||||
processed = dataset.map(
|
||||
partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=text_field,
|
||||
seq_length=max_length,
|
||||
insert_eos=insert_eos,
|
||||
drop_tail=drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
)
|
||||
|
||||
# Keep only the required columns to save space.
|
||||
if remove_columns:
|
||||
keep = {"input_ids", "labels"}
|
||||
|
||||
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
||||
drop = [c for c in ds.column_names if c not in keep]
|
||||
return ds.remove_columns(drop) if drop else ds
|
||||
|
||||
if isinstance(processed, datasets.DatasetDict):
|
||||
for split in list(processed.keys()):
|
||||
processed[split] = strip_cols(processed[split])
|
||||
else:
|
||||
processed = strip_cols(processed)
|
||||
|
||||
output_dir = os.path.join(
|
||||
output_dir,
|
||||
f"max_length-{max_length}-insert_eos-{insert_eos}-drop_tail-{drop_tail}",
|
||||
)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed.save_to_disk(output_dir)
|
||||
print(f"[OK] Saved to: {output_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse with tyro
|
||||
args = tyro.cli(ScriptArguments)
|
||||
dllm.utils.print_args(args)
|
||||
|
||||
tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# Load your raw dataset (must contain a "messages" field per example).
|
||||
dataset = dllm.data.load_pt_dataset(args.dataset_args, streaming=False)
|
||||
|
||||
preprocess_pt_dataset(
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
output_dir=args.output_dir,
|
||||
text_field=args.text_field,
|
||||
max_length=args.max_length,
|
||||
insert_eos=args.insert_eos,
|
||||
drop_tail=args.drop_tail,
|
||||
remove_columns=args.remove_columns,
|
||||
num_proc=args.num_proc,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
117 dllm/dllm/tools/preprocess_sft_dataset.py Normal file
@ -0,0 +1,117 @@
|
||||
"""
|
||||
Example:
|
||||
|
||||
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--num_proc 64
|
||||
"""
|
||||
|
||||
import os
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
|
||||
import datasets
|
||||
import tyro
|
||||
|
||||
import dllm
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""Preprocess SFT dataset"""
|
||||
|
||||
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
|
||||
sft_map_fn_path: str = "dllm.utils.default_sft_map_fn"
|
||||
dataset_args: str = "HuggingFaceTB/smoltalk" # required
|
||||
output_dir: str = "data/sft/llada/smoltalk" # required
|
||||
mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
|
||||
num_proc: int = 32
|
||||
remove_columns: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
def preprocess_sft_dataset(
|
||||
dataset: datasets.DatasetDict,
|
||||
map_fn: callable,
|
||||
output_dir: str,
|
||||
remove_columns: bool = False,
|
||||
num_proc: int = 32,
|
||||
):
|
||||
processed = dataset.map(
|
||||
map_fn,
|
||||
batched=False,
|
||||
num_proc=num_proc,
|
||||
load_from_cache_file=True,
|
||||
writer_batch_size=512,
|
||||
desc="offline preprocessing",
|
||||
)
|
||||
|
||||
# Keep only the required columns to save space.
|
||||
if remove_columns:
|
||||
keep = {"input_ids", "labels", "prompt_len", "attention_mask"}
|
||||
|
||||
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
|
||||
drop = [c for c in ds.column_names if c not in keep]
|
||||
return ds.remove_columns(drop) if drop else ds
|
||||
|
||||
if isinstance(processed, datasets.DatasetDict):
|
||||
for split in list(processed.keys()):
|
||||
processed[split] = strip_cols(processed[split])
|
||||
else:
|
||||
processed = strip_cols(processed)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed.save_to_disk(output_dir)
|
||||
print(f"[OK] Saved to: {output_dir}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse with tyro
|
||||
args = tyro.cli(ScriptArguments)
|
||||
dllm.utils.print_args(args)
|
||||
|
||||
tokenizer = dllm.utils.get_tokenizer(args)
|
||||
|
||||
# Load your raw dataset (must contain a "messages" field per example).
|
||||
dataset = dllm.data.load_sft_dataset(args.dataset_args)
|
||||
|
||||
# Dynamically import the SFT map function from its dotted path
|
||||
try:
|
||||
# Split the path into module and function name
|
||||
module_path, function_name = args.sft_map_fn_path.rsplit(".", 1)
|
||||
|
||||
# Import the module
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
# Get the function from the module
|
||||
sft_map_fn = getattr(module, function_name)
|
||||
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
print(f"Error: Could not import '{args.sft_map_fn_path}'.")
|
||||
print(f"Details: {e}")
|
||||
return
|
||||
|
||||
map_fn = partial(
|
||||
sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=args.mask_prompt_loss,
|
||||
)
|
||||
preprocess_sft_dataset(
|
||||
dataset=dataset,
|
||||
map_fn=map_fn,
|
||||
output_dir=args.output_dir,
|
||||
remove_columns=args.remove_columns,
|
||||
num_proc=args.num_proc,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6 dllm/dllm/utils/__init__.py Normal file
@ -0,0 +1,6 @@
|
||||
from . import configs, data_utils, generation_utils, model_utils, utils
|
||||
from .configs import *
|
||||
from .generation_utils import *
|
||||
from .data_utils import *
|
||||
from .model_utils import *
|
||||
from .utils import *
|
||||
77 dllm/dllm/utils/configs.py Normal file
@ -0,0 +1,77 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
|
||||
from dllm.utils.utils import resolve_with_base_env, get_default_logger
|
||||
|
||||
logger = get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = None # overwrite this
|
||||
dtype: str = "bfloat16"
|
||||
load_in_4bit: bool = False
|
||||
attn_implementation: str = None
|
||||
# --- fold PEFT args here ---
|
||||
lora: bool = False
|
||||
target_modules: str = "all-linear"
|
||||
r: int = 32
|
||||
lora_alpha: int = 64
|
||||
lora_dropout: float = 0.05
|
||||
bias: str = "none"
|
||||
modules_to_save: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
dataset_args: str = None # overwrite this
|
||||
num_proc: int = 8
|
||||
disable_caching: bool = False
|
||||
max_length: int = 1024
|
||||
truncation: str = field(
|
||||
default="right",
|
||||
metadata={
|
||||
"help": (
|
||||
'The truncation strategy to use ("filter" or "right"). '
|
||||
'"filter" only keeps sequences that are shorter than max_length; '
|
||||
'"right" only keeps the rightmost max_length tokens for each sequence.'
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(transformers.TrainingArguments):
|
||||
output_dir: str = None # overwrite this
|
||||
report_to: str = "wandb"
|
||||
overwrite_output_dir: bool = True
|
||||
seed: int = 42
|
||||
per_device_train_batch_size: int = 4
|
||||
per_device_eval_batch_size: int = 4
|
||||
gradient_accumulation_steps: int = 1
|
||||
learning_rate: float = 2e-5
|
||||
lr_scheduler_type: str = "cosine"
|
||||
warmup_ratio: float = 0.1
|
||||
bf16: bool = True
|
||||
num_train_epochs: float = 4
|
||||
logging_steps: float = 10
|
||||
eval_on_start: bool = False
|
||||
eval_strategy: str = "steps"
|
||||
eval_steps: float = 0.25
|
||||
save_steps: float = 0.25
|
||||
save_only_model: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.group_by_length:
|
||||
logger.info(
|
||||
"training_args.group_by_length=True: preprocessing "
|
||||
"may take some time after `trainer.train()` starts."
|
||||
)
|
||||
222 dllm/dllm/utils/data_utils.py Normal file
@ -0,0 +1,222 @@
|
||||
import random
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import datasets
|
||||
import transformers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
|
||||
|
||||
|
||||
def tokenize_and_group(
|
||||
examples,
|
||||
tokenizer,
|
||||
text_field: str = "text",
|
||||
seq_length: int = 1024,
|
||||
insert_eos: bool = False,
|
||||
drop_tail: bool = True,
|
||||
add_special_tokens: bool = False,
|
||||
):
|
||||
# 1) Tokenize (batched input)
|
||||
tokenized = tokenizer(examples[text_field], add_special_tokens=add_special_tokens)
|
||||
ids = tokenized["input_ids"]
|
||||
|
||||
# --- optionally append EOS to each sample ---
|
||||
if insert_eos:
|
||||
eos_id = getattr(tokenizer, "eos_token_id")
|
||||
assert eos_id
|
||||
# append EOS only if the sample doesn't already end with it
|
||||
ids = [seq + ([] if (seq and seq[-1] == eos_id) else [eos_id]) for seq in ids]
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
# 2) Flatten and concatenate all token lists
|
||||
concatenated = list(chain.from_iterable(ids))
|
||||
if not concatenated:
|
||||
return {"input_ids": [], "labels": []} # Safe return for empty batch
|
||||
|
||||
# 3) Calculate the total length based on drop_tail
|
||||
if drop_tail:
|
||||
total_len = (len(concatenated) // seq_length) * seq_length
|
||||
concatenated = concatenated[:total_len] # Truncate the last incomplete chunk
|
||||
else:
|
||||
total_len = len(concatenated)
|
||||
|
||||
# Split into fixed-length chunks
|
||||
chunks = [concatenated[i : i + seq_length] for i in range(0, total_len, seq_length)]
|
||||
|
||||
return {
|
||||
"input_ids": chunks,
|
||||
"labels": [c[:] for c in chunks], # Labels are the same as input_ids
|
||||
}
|
||||
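# Packing illustration for tokenize_and_group (made-up token ids): with
# seq_length=4, insert_eos=True (eos_id=0) and drop_tail=True, samples that
# tokenize to [5, 6] and [7, 8, 9] are concatenated to [5, 6, 0, 7, 8, 9, 0]
# and a single chunk [5, 6, 0, 7] is emitted; the incomplete tail is dropped.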
|
||||
|
||||
def clip_row(row: dict, max_length: int, truncation: str = "right") -> dict:
|
||||
for key in ("input_ids", "labels", "attention_mask"):
|
||||
if key in row:
|
||||
if truncation == "right":
|
||||
row[key] = row[key][:max_length]
|
||||
elif truncation == "left":
|
||||
row[key] = row[key][-max_length:]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return row
|
||||
|
||||
|
||||
def post_process_dataset(
|
||||
dataset: datasets.DatasetDict, data_args: "DataArguments"
|
||||
) -> datasets.DatasetDict:
|
||||
if data_args.truncation == "filter":
|
||||
return dataset.filter(
|
||||
lambda row: len(row["input_ids"]) <= data_args.max_length,
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Filtering samples with length <= {data_args.max_length}",
|
||||
)
|
||||
elif data_args.truncation == "right":
|
||||
# do this only if dataset has "prompt_len"
|
||||
if "prompt_len" in dataset.column_names["train"]:
|
||||
dataset = dataset.filter(
|
||||
lambda row: row["prompt_len"] <= data_args.max_length,
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Filtering samples with `prompt_len` <= {data_args.max_length}",
|
||||
)
|
||||
return dataset.map(
|
||||
lambda row: clip_row(row, data_args.max_length, truncation="right"),
|
||||
num_proc=data_args.num_proc,
|
||||
desc=f"Right-truncating samples to max_length={data_args.max_length}",
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def clip_row_streaming(row: dict, max_length: int, truncation: str = "right") -> dict:
|
||||
"""Clip whole sequence OR (if prompt_len present) preserve prompt and clip only the response."""
|
||||
if truncation not in {"right", "left"}:
|
||||
raise NotImplementedError(f"Unknown truncation: {truncation}")
|
||||
|
||||
def clip(seq):
|
||||
return seq[:max_length] if truncation == "right" else seq[-max_length:]
|
||||
|
||||
def clip_preserve_prompt(seq, prompt_len: int):
|
||||
prompt = seq[:prompt_len]
|
||||
resp = seq[prompt_len:]
|
||||
budget = max(0, max_length - len(prompt))
|
||||
resp = resp[:budget] if truncation == "right" else resp[-budget:]
|
||||
return prompt + resp
|
||||
|
||||
prompt_len = row.get("prompt_len", None)
|
||||
for k in ("input_ids", "labels", "attention_mask"):
|
||||
if k in row and isinstance(row[k], list):
|
||||
row[k] = (
|
||||
clip_preserve_prompt(row[k], prompt_len)
|
||||
if isinstance(prompt_len, int) and prompt_len >= 0
|
||||
else clip(row[k])
|
||||
)
|
||||
return row
|
||||
|
||||
|
||||
def post_process_dataset_streaming(
|
||||
dataset: datasets.IterableDatasetDict,
|
||||
data_args: "DataArguments",
|
||||
) -> datasets.IterableDatasetDict:
|
||||
|
||||
def _train_has_prompt_len_streaming(dataset: datasets.IterableDatasetDict) -> bool:
|
||||
"""Replicates: 'if \"prompt_len\" in dataset.column_names[\"train\"]' for streaming."""
|
||||
it = dataset["train"].take(1)
|
||||
try:
|
||||
ex = next(iter(it))
|
||||
except StopIteration:
|
||||
return False
|
||||
return "prompt_len" in ex
|
||||
|
||||
mode = data_args.truncation
|
||||
max_len = data_args.max_length
|
||||
|
||||
if mode == "filter":
|
||||
# Keep rows with len(input_ids) <= max_len (emulate .filter with generator map)
|
||||
def keep_if_short(row):
|
||||
if (
|
||||
"input_ids" in row
|
||||
and isinstance(row["input_ids"], list)
|
||||
and len(row["input_ids"]) <= max_len
|
||||
):
|
||||
yield row # keep
|
||||
# else: drop (yield nothing)
|
||||
|
||||
return datasets.IterableDatasetDict(
|
||||
{name: ds.map(keep_if_short) for name, ds in dataset.items()}
|
||||
)
|
||||
|
||||
elif mode == "right":
|
||||
ds_out = dataset
|
||||
|
||||
# Do this only if the TRAIN split has "prompt_len" (same condition as the non-streaming code)
|
||||
if _train_has_prompt_len_streaming(ds_out):
|
||||
|
||||
def keep_if_prompt_fits(row):
|
||||
pl = row.get("prompt_len", None)
|
||||
if isinstance(pl, int) and pl <= max_len:
|
||||
yield row # keep
|
||||
elif pl is None:
|
||||
# If a row lacks prompt_len but train had it, the non-streaming code would try to access it and fail.
|
||||
# Here we conservatively drop such rows to mirror "requires prompt_len <= max_len".
|
||||
return
|
||||
# else: drop
|
||||
|
||||
ds_out = datasets.IterableDatasetDict(
|
||||
{name: ds.map(keep_if_prompt_fits) for name, ds in ds_out.items()}
|
||||
)
|
||||
|
||||
# Then clip right (same clipping as clip_row)
|
||||
def clip_right(row):
|
||||
return clip_row(row, max_len, truncation="right")
|
||||
|
||||
return datasets.IterableDatasetDict(
|
||||
{name: ds.map(clip_right) for name, ds in ds_out.items()}
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoAttentionMaskCollator(transformers.DataCollatorForSeq2Seq):
|
||||
def __call__(self, features, return_tensors=None):
|
||||
outputs = super().__call__(features, return_tensors)
|
||||
# we finetune on padding <eos_token> tokens; they should not be masked out
|
||||
outputs.pop("attention_mask")
|
||||
return outputs
|
||||
|
||||
|
||||
def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
|
||||
"""
|
||||
Build input_ids and labels for SFT.
|
||||
|
||||
Args:
|
||||
row: a dataset row with `messages`
|
||||
tokenizer: a HF tokenizer
|
||||
mask_prompt_loss: whether to mask prompt tokens (set their labels to -100)
|
||||
|
||||
Returns:
|
||||
dict with keys: input_ids, labels, and optionally prompt_len
|
||||
"""
|
||||
prompt_response_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"], tokenize=True, add_generation_prompt=False
|
||||
)
|
||||
labels = prompt_response_tokens.copy()
|
||||
|
||||
if mask_prompt_loss:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
row["messages"][:-1], tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
|
||||
return {
|
||||
"input_ids": prompt_response_tokens,
|
||||
"labels": labels,
|
||||
"prompt_len": len(prompt_tokens),
|
||||
}
|
||||
|
||||
return {"input_ids": prompt_response_tokens, "labels": labels}
|
||||
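# Label-masking illustration for default_sft_map_fn: if the chat template
# renders the prompt-only part to 5 tokens and the full conversation to 8,
# then labels == [-100] * 5 + input_ids[5:8] and prompt_len == 5, so the loss
# is computed only on the assistant reply.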
53 dllm/dllm/utils/generation_utils.py Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
|
||||
from dllm.core.schedulers import BaseAlphaScheduler
|
||||
|
||||
|
||||
def get_num_transfer_tokens(
|
||||
mask_index: torch.Tensor,
|
||||
steps: int,
|
||||
scheduler: BaseAlphaScheduler,
|
||||
stochastic: bool = False,
|
||||
) -> torch.Tensor:
|
||||
mask_num = mask_index.sum(dim=1, keepdim=True)
|
||||
num_transfer_tokens = torch.zeros(
|
||||
mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
|
||||
)
|
||||
for i in range(mask_num.size(0)):
|
||||
for t, s, j in zip(range(steps, 0, -1), range(steps - 1, -1, -1), range(steps)):
|
||||
s /= steps
|
||||
t /= steps
|
||||
reverse_transfer_prob = 1 - scheduler.reverse_mask_prob(s=s, t=t)
|
||||
if not stochastic:
|
||||
x = mask_num[i, 0].to(torch.float64) * reverse_transfer_prob
|
||||
num_transfer_tokens[i, j] = torch.round(x).to(torch.int64)
|
||||
else:
|
||||
n = mask_num[i, 0].to(torch.float64)
|
||||
num_transfer_tokens[i, j] = (
|
||||
torch.distributions.Binomial(n, reverse_transfer_prob)
|
||||
.sample()
|
||||
.to(torch.int64)
|
||||
)
|
||||
num_transfer_tokens[i, j] = torch.minimum(
|
||||
num_transfer_tokens[i, j], mask_num[i, 0]
|
||||
)
|
||||
mask_num[i, 0] -= num_transfer_tokens[i, j]
|
||||
if mask_num[i, 0].item() == 0:
|
||||
break
|
||||
# Note: because llada is not conditioned on time, this allows us to skip steps with no unmasking (i.e. transfer).
|
||||
# Clear all zeros per row (compact) and right-pad with zeros
|
||||
# Remove zeros per row, then pad only up to the max length across rows
|
||||
rows = []
|
||||
max_len = 0
|
||||
for i in range(num_transfer_tokens.size(0)):
|
||||
nonzero = num_transfer_tokens[i][num_transfer_tokens[i] > 0]
|
||||
rows.append(nonzero)
|
||||
max_len = max(max_len, nonzero.numel())
|
||||
# Pad each row to max_len
|
||||
padded_rows = []
|
||||
for r in rows:
|
||||
if r.numel() < max_len:
|
||||
pad = torch.zeros(max_len - r.numel(), dtype=r.dtype, device=r.device)
|
||||
r = torch.cat([r, pad])
|
||||
padded_rows.append(r)
|
||||
return torch.stack(padded_rows, dim=0)
|
||||
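A hedged usage sketch for get_num_transfer_tokens (x, mask_token_id and scheduler are assumed to come from the caller; any BaseAlphaScheduler implementation works):

mask_index = (x == mask_token_id)                  # (B, L) bool, currently-masked positions
schedule = get_num_transfer_tokens(mask_index, steps=8, scheduler=scheduler, stochastic=False)
# schedule[i, j] is how many masked tokens row i unmasks at step j; rows are
# compacted so steps that would transfer zero tokens can be skipped entirely.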
180 dllm/dllm/utils/model_utils.py Normal file
@ -0,0 +1,180 @@
|
||||
import torch
|
||||
import accelerate
|
||||
import transformers
|
||||
from peft import prepare_model_for_kbit_training
|
||||
|
||||
from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
|
||||
from dllm.utils.configs import ModelArguments, TrainingArguments
|
||||
|
||||
|
||||
def get_model(
|
||||
model_args,
|
||||
config: transformers.PretrainedConfig | None = None,
|
||||
) -> transformers.PreTrainedModel:
|
||||
"""
|
||||
Load a model with flexible input sources.
|
||||
|
||||
Args:
|
||||
model_args: Dataclass or namespace providing model_name_or_path, dtype,
    load_in_4bit, attn_implementation, and the PEFT/LoRA options.
config: Optional transformers.PretrainedConfig forwarded to from_pretrained.
|
||||
|
||||
Returns:
|
||||
transformers.PreTrainedModel
|
||||
"""
|
||||
model_name_or_path = getattr(model_args, "model_name_or_path")
|
||||
dtype = getattr(model_args, "dtype", "bfloat16")
|
||||
load_in_4bit = getattr(model_args, "load_in_4bit", False)
|
||||
attn_implementation = getattr(model_args, "attn_implementation", None)
|
||||
|
||||
# Device map: skip when ZeRO-3
|
||||
device_map = (
|
||||
{"": accelerate.PartialState().local_process_index}
|
||||
if not transformers.modeling_utils.is_deepspeed_zero3_enabled()
|
||||
and torch.cuda.is_available()
|
||||
else None
|
||||
)
|
||||
|
||||
quant_config = None
|
||||
if load_in_4bit and transformers.utils.is_bitsandbytes_available():
|
||||
quant_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
params = {
|
||||
"dtype": dtype,
|
||||
"device_map": device_map,
|
||||
"quantization_config": quant_config,
|
||||
"attn_implementation": attn_implementation,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
try:
|
||||
model = transformers.AutoModelForMaskedLM.from_pretrained(
|
||||
model_name_or_path, **params
|
||||
)
|
||||
except Exception:
|
||||
model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)
|
||||
|
||||
# --- if quantized, prepare for LoRA / QLoRA training ---
|
||||
if load_in_4bit and quant_config is not None:
|
||||
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
|
||||
|
||||
# Optionally train with lora
|
||||
model = load_peft(model, model_args)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
|
||||
"""
|
||||
Load a tokenizer with flexible input sources.
|
||||
|
||||
Args:
|
||||
model_args: Dataclass or namespace providing model_name_or_path; the model class
    inferred from its config determines any tokenizer customization below.
|
||||
|
||||
Returns:
|
||||
transformers.PreTrainedTokenizer
|
||||
"""
|
||||
# Lazy imports to avoid circular dependencies
|
||||
from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
|
||||
from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
|
||||
from dllm.pipelines.dream.models.modeling_dream import DreamModel
|
||||
from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
|
||||
from transformers import (
|
||||
BertPreTrainedModel,
|
||||
RobertaPreTrainedModel,
|
||||
ModernBertPreTrainedModel,
|
||||
)
|
||||
|
||||
model_name_or_path = getattr(model_args, "model_name_or_path")
|
||||
|
||||
# ---------------- Tokenizer loading ----------------
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
assert tokenizer.eos_token is not None or tokenizer.pad_token is not None
|
||||
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if not tokenizer.eos_token:
|
||||
tokenizer.eos_token = tokenizer.pad_token
|
||||
if not tokenizer.bos_token:
|
||||
tokenizer.bos_token = tokenizer.pad_token
|
||||
|
||||
# Infer the model class from the config so we can apply model-specific customization
|
||||
model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path)
|
||||
model_cls = transformers.AutoModel._model_mapping[type(model_cfg)]
|
||||
|
||||
# ---------------- Model-specific customization ----------------
|
||||
if issubclass(model_cls, LLaDAModelLM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"})
|
||||
tokenizer.eot_token = "<|eot_id|>"
|
||||
# tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) # can not do this for llada base directly
|
||||
# TODO: for llada base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348}
|
||||
# fix bugs in chat template
|
||||
tokenizer.chat_template = """\
|
||||
{% set loop_messages = messages %}
|
||||
{% for message in loop_messages %}
|
||||
{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}
|
||||
<|start_header_id|>{{ message['role'] }}<|end_header_id|>
|
||||
|
||||
{{ message['content'] | trim }}<|eot_id|>
|
||||
{%- endfor %}
|
||||
{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
|
||||
<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{% endif %}
|
||||
"""
|
||||
elif issubclass(model_cls, LLaDAMoEModelLM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
|
||||
tokenizer.eot_token = "<|role_end|>"
|
||||
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
|
||||
elif issubclass(model_cls, DreamModel):
|
||||
tokenizer.eot_token = "<|im_end|>"
|
||||
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
|
||||
elif issubclass(model_cls, RND1LM):
|
||||
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
|
||||
elif issubclass(
|
||||
model_cls,
|
||||
(BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel),
|
||||
):
|
||||
tokenizer.eot_token = "[/Answer]"
|
||||
tokenizer.chat_template = """\
|
||||
{% if messages[0]['role'] == 'system' %}
|
||||
[SYS]
|
||||
{{ messages[0]['content'] | trim }}
|
||||
[/SYS]
|
||||
|
||||
{% set loop_messages = messages[1:] %}
|
||||
{% else %}
|
||||
{% set loop_messages = messages %}
|
||||
{% endif -%}
|
||||
{%- for message in loop_messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
[Question]
|
||||
{{ message['content'] | trim }}
|
||||
[/Question]
|
||||
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
[Answer]
|
||||
{{ message['content'] | trim }}
|
||||
[/Answer]
|
||||
|
||||
{% endif %}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
|
||||
[Answer]
|
||||
{% endif %}
|
||||
"""
|
||||
else:
|
||||
print_main("no tokenizer customization for model class:", model_cls)
|
||||
return tokenizer
|
||||
284 dllm/dllm/utils/utils.py Normal file
@ -0,0 +1,284 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
|
||||
|
||||
import pprint
|
||||
import torch
|
||||
import peft
|
||||
import accelerate
|
||||
import transformers
|
||||
|
||||
|
||||
def resolve_with_base_env(path: str, env_name: str) -> str:
|
||||
"""
|
||||
If `env_name` is set and `path` is NOT absolute, NOT a URL/scheme,
|
||||
and does not already exist locally, prepend the `env_name` directory.
|
||||
|
||||
If the joined candidate does not exist either, a FileNotFoundError is raised.
Otherwise return `path` unchanged.
|
||||
"""
|
||||
base = os.getenv(env_name, "").strip()
|
||||
if not base:
|
||||
return path
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
|
||||
candidate = os.path.join(base.rstrip("/"), path.lstrip("/"))
|
||||
if os.path.exists(candidate):
|
||||
return candidate
|
||||
else:
|
||||
raise FileNotFoundError
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_device_context_manager(device: str | torch.device | None = None):
|
||||
"""
|
||||
Temporarily set the torch default device so that tensors created inside the
context are allocated on `device`. The default device is reset to "cpu" on exit;
this is a no-op under DeepSpeed ZeRO-3.
|
||||
"""
|
||||
if transformers.integrations.is_deepspeed_zero3_enabled():
|
||||
yield
|
||||
return
|
||||
|
||||
# Resolve device
|
||||
if device is None:
|
||||
try:
|
||||
from accelerate import PartialState
|
||||
|
||||
idx = PartialState().local_process_index
|
||||
except Exception:
|
||||
idx = 0
|
||||
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
|
||||
elif isinstance(device, int):
|
||||
device = f"cuda:{device}"
|
||||
|
||||
try:
|
||||
torch.set_default_device(device)
|
||||
yield
|
||||
finally:
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
|
||||
def print_main(*args, **kwargs):
|
||||
"""
|
||||
Print only from the global main process (rank 0 across all nodes).
|
||||
Usage: print_main("Hello from main process!")
|
||||
"""
|
||||
if accelerate.PartialState().is_main_process:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def pprint_main(*args, **kwargs):
|
||||
"""
|
||||
Print (with pprint) only from the global main process (rank 0 across all nodes).
|
||||
Usage: pprint_main({"key": "value"})
|
||||
"""
|
||||
if accelerate.PartialState().is_main_process:
|
||||
pprint.pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def load_peft(
|
||||
model: transformers.PreTrainedModel, model_args: "ModelArguments"
|
||||
) -> transformers.PreTrainedModel:
|
||||
"""
|
||||
e.g.,
|
||||
--modules_to_save "lm_head" --target_modules "q_proj,k_proj,v_proj,o_proj,up_proj,down_proj,gate_proj"
|
||||
--target_modules "all-linear"
|
||||
"""
|
||||
if not getattr(model_args, "lora", False):
|
||||
return model
|
||||
target_modules = (
|
||||
model_args.target_modules.split(",") if model_args.target_modules else None
|
||||
)
|
||||
# if it’s a single 'all-linear', drop the list and use the string directly
|
||||
if (
|
||||
target_modules
|
||||
and len(target_modules) == 1
|
||||
and target_modules[0].strip() == "all-linear"
|
||||
):
|
||||
target_modules = target_modules[0]
|
||||
modules_to_save = (
|
||||
model_args.modules_to_save.split(",") if model_args.modules_to_save else None
|
||||
)
|
||||
peft_config = peft.LoraConfig(
|
||||
r=model_args.r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=model_args.lora_alpha,
|
||||
lora_dropout=model_args.lora_dropout,
|
||||
bias=model_args.bias,
|
||||
modules_to_save=modules_to_save,
|
||||
)
|
||||
model = peft.get_peft_model(model, peft_config)
|
||||
if accelerate.PartialState().is_main_process:
|
||||
print(model)
|
||||
model.print_trainable_parameters()
|
||||
return model
|
||||
|
||||
|
||||
def print_args_main(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments",
|
||||
):
|
||||
print_main("\n===== Parsed arguments =====")
|
||||
for name, args in [
|
||||
("model_args", model_args),
|
||||
("data_args", data_args),
|
||||
("training_args", training_args),
|
||||
]:
|
||||
d = asdict(args)
|
||||
# show all entries; slice list(d)[:n] here to print only the first n
short = {k: d[k] for k in list(d)}
|
||||
print_main(f"{name}:")
|
||||
pprint_main(short, width=100, compact=True, sort_dicts=False)
|
||||
print_main("============================\n")
|
||||
|
||||
|
||||
def print_args(args):
|
||||
print_main("\n===== Parsed arguments =====")
|
||||
d = asdict(args)
|
||||
# show all entries; slice list(d)[:n] here to print only the first n
short = {k: d[k] for k in list(d)}
|
||||
pprint_main(short, width=100, compact=True, sort_dicts=False)
|
||||
print_main("============================\n")
|
||||
|
||||
|
||||
def disable_caching_allocator_warmup():
|
||||
try:
|
||||
from transformers import modeling_utils as _mu
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
return
|
||||
|
||||
_mu.caching_allocator_warmup = _noop
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def disable_dataset_progress_bar_except_main():
|
||||
# state = accelerate.PartialState() # figures out your rank/world automatically
|
||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
if accelerate.PartialState().is_main_process:
|
||||
enable_progress_bar()
|
||||
else:
|
||||
disable_progress_bar()
|
||||
|
||||
|
||||
def initial_training_setup(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments",
|
||||
):
|
||||
transformers.set_seed(training_args.seed)
|
||||
disable_caching_allocator_warmup()
|
||||
disable_dataset_progress_bar_except_main()
|
||||
if getattr(data_args, "disable_caching", False):
|
||||
disable_dataset_caching()
|
||||
|
||||
|
||||
def disable_dataset_caching():
|
||||
from datasets import disable_caching
|
||||
|
||||
disable_caching()
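# Route any remaining HF datasets cache/temp writes to a per-rank /tmp directory so
# concurrent ranks do not collide on a shared cache (descriptive comment added for clarity).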
|
||||
tmp_root = f"/tmp/hf_cache_rank{accelerate.PartialState().process_index}"
|
||||
os.environ["HF_DATASETS_CACHE"] = tmp_root
|
||||
os.environ["HF_DATASETS_TEMP_DIR"] = tmp_root
|
||||
os.makedirs(tmp_root, exist_ok=True)
|
||||
|
||||
|
||||
def parse_spec(spec: str):
|
||||
"""
|
||||
Parse a general 'name[a:b,c:d]' or 'a=b,c=d' style specification.
|
||||
|
||||
Supports:
|
||||
- Bare name, e.g. "foo/bar"
|
||||
- Optional bracket suffix with comma-separated entries:
|
||||
key:value or key:int_value (underscores allowed)
|
||||
- Optional "key=value" pairs outside the bracket.
|
||||
|
||||
Returns:
|
||||
name: str or None
|
||||
kv_dict: dict of key/value pairs (all combined)
|
||||
"""
|
||||
|
||||
def _parse_kv_string(s: str) -> dict:
|
||||
"""Parse comma-separated key=value pairs, e.g. 'a=1,b=2'."""
|
||||
return dict(part.split("=", 1) for part in s.split(",") if "=" in part)
|
||||
|
||||
s = spec.strip()
|
||||
|
||||
# Extract bracket content if present
|
||||
m = re.search(r"\[(.*?)\]$", s)
|
||||
bracket_kvs = {}
|
||||
numeric_kvs = {}
|
||||
if m:
|
||||
bracket = m.group(1).strip()
|
||||
if bracket:
|
||||
for part in bracket.split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if ":" not in part:
|
||||
raise ValueError(
|
||||
f"Invalid entry '{part}' in '{spec}' (expected key:value)."
|
||||
)
|
||||
key, value = part.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Integers (with optional underscores)
|
||||
if re.fullmatch(r"\d(?:_?\d)*", value):
|
||||
numeric_kvs[key] = int(value.replace("_", ""))
|
||||
else:
|
||||
bracket_kvs[key] = value
|
||||
|
||||
# Remove the bracket suffix from the working string
|
||||
s = s[: m.start()].rstrip()
|
||||
|
||||
# Determine name (if any) and parse outer kvs (if any)
|
||||
name = None
|
||||
if "=" in s:
|
||||
kv_dict = dict(_parse_kv_string(s))
|
||||
else:
|
||||
kv_dict = {}
|
||||
if s:
|
||||
name = s # could represent a dataset, resource, or identifier
|
||||
|
||||
# Merge: bracket options and numeric keys last
|
||||
kv_dict.update(bracket_kvs)
|
||||
kv_dict.update(numeric_kvs)
|
||||
|
||||
return name, kv_dict
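# Illustrative behavior of parse_spec (a sketch based on the rules documented in the docstring above):
#   parse_spec("allenai/tulu-3-sft-mixture[split:train,max_samples:10_000]")
#     -> ("allenai/tulu-3-sft-mixture", {"split": "train", "max_samples": 10000})
#   parse_spec("lr=3e-4,warmup=100")
#     -> (None, {"lr": "3e-4", "warmup": "100"})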
|
||||
|
||||
|
||||
def get_default_logger(name):
|
||||
logger = logging.getLogger(name)
|
||||
if accelerate.PartialState().is_main_process:
|
||||
logger.setLevel(logging.INFO)
|
||||
else:
|
||||
logger.setLevel(logging.WARNING)
|
||||
handler = logging.StreamHandler(sys.stdout) # print to terminal
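# The formatter below uses ANSI 256-color escapes (\x1b[38;5;<n>m) to tint the timestamp,
# level, and logger name, and \x1b[0m to reset back to the default terminal color.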
|
||||
formatter = logging.Formatter(
|
||||
fmt=(
|
||||
"\x1b[38;5;110m[%(asctime)s "
|
||||
"\x1b[38;5;174m%(levelname)s "
|
||||
"\x1b[38;5;109m%(name)s"
|
||||
"/%(lineno)d-%(processName)s\x1b[38;5;110m] "
|
||||
"\x1b[0m%(message)s"
|
||||
),
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
return logger
|
||||
190
dllm/examples/bert/README.md
Normal file
@ -0,0 +1,190 @@
|
||||
# Generative BERT
|
||||
|
||||
[Hugging Face Collection](https://huggingface.co/collections/dllm-collection/bert-chat)
|
||||
[W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg)
|
||||
|
||||
This directory provides two key sets of resources:
|
||||
|
||||
1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
|
||||
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as Chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
|
||||
|
||||
<p align="center" style="margin-top: 15px;">
|
||||
<img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
|
||||
</p>
|
||||
<p align="center">
|
||||
<em>
|
||||
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
|
||||
</em>
|
||||
</p>
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# example entry points for training / inference / evaluation
|
||||
examples/bert
|
||||
├── chat.py # Interactive inference example
|
||||
├── eval.sh # Automatic evaluation script
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
└── sft.py # Supervised finetuning example
|
||||
```
|
||||
|
||||
## Warmup
|
||||
|
||||
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
|
||||
You can use any other BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
|
||||
|
||||
### Pretrain
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/pt.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "Trelis/tiny-shakespeare" \
|
||||
--text_field "Text" \
|
||||
--insert_eos False \
|
||||
--max_length 128 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/tiny-shakespeare"
|
||||
```
|
||||
|
||||
To run inference with the model:
|
||||
```shell
|
||||
# just press enter (empty prompt) if you want the model to generate text from scratch
|
||||
python -u examples/bert/chat.py \
|
||||
--model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
|
||||
--chat False --remasking "random" --steps 128 --max_new_tokens 128
|
||||
```
|
||||
|
||||
### SFT
|
||||
|
||||
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "tatsu-lab/alpaca" \
|
||||
--max_length 512 \
|
||||
--num_train_epochs 20 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--per_device_eval_batch_size 64 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/alpaca"
|
||||
```
|
||||
|
||||
To chat with the model:
|
||||
```shell
|
||||
python -u examples/bert/chat.py \
|
||||
--model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
|
||||
```
|
||||
|
||||
## BERT Chat
|
||||
Here we show the exact commands we use to train and interact with the BERT Chat models:
|
||||
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
|
||||
For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
|
||||
|
||||
### Training
|
||||
|
||||
To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-base" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
```
|
||||
|
||||
To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
|
||||
```shell
|
||||
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
|
||||
examples/bert/sft.py \
|
||||
--model_name_or_path "answerdotai/ModernBERT-large" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 10 \
|
||||
--per_device_train_batch_size 48 \
|
||||
--per_device_eval_batch_size 48 \
|
||||
--save_steps 0.1 \
|
||||
--output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
To chat with the model:
|
||||
```shell
|
||||
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
|
||||
```shell
|
||||
# Use model_args to adjust the generation arguments for evaluation.
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/bert/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "bert" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
|
||||
```
|
||||
|
||||
To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
|
||||
```shell
|
||||
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
|
||||
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
|
||||
```
|
||||
|
||||
### Evaluation results
|
||||
|
||||
<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
|
||||
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->
|
||||
|
||||
<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | – | | | | | | | |
|
||||
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
|
||||
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | – | | | | | | | |
|
||||
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
|
||||
<!-- <div align="center" style="min-width:1500px;"> -->
|
||||
|
||||
| | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|
||||
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
||||
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
|
||||
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
|
||||
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
|
||||
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
|
||||
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
|
||||
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 |53.1 | 39.4 | 0.3 |
|
||||
|
||||
|
||||
<p align="left" style="color: #808080; font-size: 0.9em;">
|
||||
Table 1. Evaluation results of
|
||||
<a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;">
|
||||
<code>ModernBERT-base-chat-v0</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;">
|
||||
<code>ModernBERT-large-chat-v0</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;">
|
||||
<code>Qwen1.5-0.5B</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;">
|
||||
<code>Qwen1.5-0.5B-Chat</code>
|
||||
</a>,
|
||||
<a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;">
|
||||
<code>gpt2</code>
|
||||
</a>, and
|
||||
<a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;">
|
||||
<code>gpt2-medium</code>
|
||||
</a>.
|
||||
<ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
|
||||
</p>
|
||||
BIN
dllm/examples/bert/assets/chat.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.1 MiB |
71
dllm/examples/bert/chat.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Interactive chat / generation script for Bert models.
|
||||
|
||||
Examples
|
||||
--------
|
||||
# Raw multi-turn generation (default)
|
||||
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
|
||||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.pipelines import llada
|
||||
from dllm.tools.chat import multi_turn_chat, single_turn_generate
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
|
||||
seed: int = 42
|
||||
chat: bool = True
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# same base-path resolution logic as in generate.py
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 32
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
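# Rough field semantics for LLaDA-style block diffusion generation (hedged summary,
# inferred from the generator defaults rather than stated in this file):
#   steps          - number of denoising/unmasking steps
#   max_new_tokens - total length of the generated completion
#   block_length   - tokens generated per semi-autoregressive block
#   temperature    - sampling temperature (0.0 = deterministic/argmax)
#   remasking      - which tokens get re-masked each step ("low_confidence" or "random")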
|
||||
|
||||
|
||||
def main():
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
|
||||
if script_args.chat:
|
||||
multi_turn_chat(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
else:
|
||||
print("\nSingle-turn generation (no chat template).")
|
||||
single_turn_generate(
|
||||
generator=generator,
|
||||
gen_config=gen_config,
|
||||
visualize=script_args.visualize,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted. Bye!")
|
||||
sys.exit(0)
|
||||
50
dllm/examples/bert/eval.sh
Normal file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env bash
|
||||
# ===== Mandatory for proper import and evaluation =====
|
||||
export PYTHONPATH=.:$PYTHONPATH
|
||||
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
|
||||
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
|
||||
|
||||
# ===== Optional but recommended for stability and debugging =====
|
||||
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
|
||||
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
|
||||
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
|
||||
|
||||
# ===== Basic Settings =====
|
||||
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
|
||||
num_gpu=4
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--model_name_or_path)
|
||||
model_name_or_path="$2"; shift 2 ;;
|
||||
--num_gpu)
|
||||
num_gpu="$2"; shift 2 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ===== Common arguments =====
|
||||
common_args="--model bert --apply_chat_template" # BERT model is default to use chat template
|
||||
|
||||
# =======================
|
||||
# BERT Instruct (Chat) Tasks
|
||||
# =======================
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks hellaswag_gen --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks mmlu_generative --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks mmlu_pro --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
|
||||
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
|
||||
--tasks winogrande --num_fewshot 0 ${common_args} \
|
||||
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
|
||||
73
dllm/examples/bert/generate.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import transformers
|
||||
|
||||
import dllm
|
||||
from dllm.tools.chat import decode_trim
|
||||
from dllm.pipelines import llada
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
|
||||
seed: int = 42
|
||||
visualize: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_name_or_path = dllm.utils.resolve_with_base_env(
|
||||
self.model_name_or_path, "BASE_MODELS_DIR"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorConfig(llada.LLaDAGeneratorConfig):
|
||||
steps: int = 128
|
||||
max_new_tokens: int = 128
|
||||
block_length: int = 64
|
||||
temperature: float = 0.0
|
||||
remasking: str = "low_confidence"
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
|
||||
script_args, gen_config = parser.parse_args_into_dataclasses()
|
||||
transformers.set_seed(script_args.seed)
|
||||
|
||||
# Load model & tokenizer
|
||||
model = dllm.utils.get_model(model_args=script_args).eval()
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
|
||||
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
|
||||
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
# --- Example 1: Batch generation ---
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST: bert.generate()".center(80))
|
||||
print("=" * 80)
|
||||
|
||||
messages = [
|
||||
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
|
||||
[{"role": "user", "content": "Please write an educational python function."}],
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
)
|
||||
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
|
||||
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
|
||||
|
||||
for i, s in enumerate(sequences):  # avoid shadowing the built-in iter()
|
||||
print("\n" + "-" * 80)
|
||||
print(f"[Case {iter}]")
|
||||
print("-" * 80)
|
||||
print(s.strip() if s.strip() else "<empty>")
|
||||
print("\n" + "=" * 80 + "\n")
|
||||
|
||||
if script_args.visualize:
|
||||
terminal_visualizer.visualize(outputs.histories, rich=True)
|
||||
127
dllm/examples/bert/pt.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/pt.py
|
||||
|
||||
- 8 GPUs (DDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml \
|
||||
examples/bert/pt.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 8 GPUs (DDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/pt.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "Trelis/tiny-shakespeare"
|
||||
text_field: str = "Text"
|
||||
max_length: int = 128
|
||||
streaming: bool = False
|
||||
drop_tail: bool = True
|
||||
insert_eos: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "False when adjacent samples from the datasets are semantically coherent."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
|
||||
num_train_epochs: int = 20
|
||||
learning_rate: float = 1e-4
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_pt_dataset(
|
||||
data_args.dataset_args,
|
||||
streaming=data_args.streaming,
|
||||
)
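# tokenize_and_group (judging by its arguments) tokenizes `text_field`, packs the tokens
# into fixed blocks of `seq_length`, optionally inserts an EOS between documents
# (insert_eos), and drops the final partial block (drop_tail).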
|
||||
dataset = dataset.map(
|
||||
functools.partial(
|
||||
dllm.utils.tokenize_and_group,
|
||||
tokenizer=tokenizer,
|
||||
text_field=data_args.text_field,
|
||||
seq_length=data_args.max_length,
|
||||
insert_eos=data_args.insert_eos,
|
||||
drop_tail=data_args.drop_tail,
|
||||
),
|
||||
batched=True,
|
||||
remove_columns=dataset["train"].column_names,
|
||||
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
|
||||
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(seed=training_args.seed)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
127
dllm/examples/bert/sft.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Local users
|
||||
------------
|
||||
- 1 GPU:
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
|
||||
examples/bert/sft.py
|
||||
|
||||
- 8 GPUs (DDP):
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/ddp.yaml \
|
||||
examples/bert/sft.py
|
||||
|
||||
Slurm users
|
||||
# Note: run `mkdir logs` before running sbatch; and adjust
|
||||
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
|
||||
------------
|
||||
- 1 Node, 8 GPUs (DDP):
|
||||
sbatch --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/sft.py"
|
||||
|
||||
- 2 Nodes, 16 GPUs (DDP):
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "ddp" \
|
||||
--script_path "examples/bert/sft.py"
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
|
||||
import transformers
|
||||
import accelerate
|
||||
|
||||
import dllm
|
||||
|
||||
logger = dllm.utils.get_default_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments(dllm.utils.ModelArguments):
|
||||
model_name_or_path: str = "answerdotai/ModernBERT-large"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments(dllm.utils.DataArguments):
|
||||
dataset_args: str = "tatsu-lab/alpaca"
|
||||
max_length: int = 512
|
||||
load_preprocessed_data: bool = False
|
||||
mask_prompt_loss: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to mask the loss on the prompt tokens"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(dllm.utils.TrainingArguments):
|
||||
output_dir: str = "models/ModernBERT-large/alpaca"
|
||||
group_by_length: bool = True
|
||||
learning_rate: float = 1e-4
|
||||
num_train_epochs: int = 20
|
||||
per_device_train_batch_size: int = 64
|
||||
per_device_eval_batch_size: int = 64
|
||||
eval_steps: float = 0.1
|
||||
save_steps: float = 0.1
|
||||
|
||||
|
||||
def train():
|
||||
# ----- Argument parsing -------------------------------------------------------
|
||||
parser = transformers.HfArgumentParser(
|
||||
(ModelArguments, DataArguments, TrainingArguments)
|
||||
)
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
dllm.utils.print_args_main(model_args, data_args, training_args)
|
||||
dllm.utils.initial_training_setup(model_args, data_args, training_args)
|
||||
|
||||
# ----- Model ------------------------------------------------------------------
|
||||
model = dllm.utils.get_model(model_args=model_args)
|
||||
# ----- Tokenizer --------------------------------------------------------------
|
||||
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
|
||||
|
||||
# ----- Dataset ----------------------------------------------------------------
|
||||
with accelerate.PartialState().local_main_process_first():
|
||||
dataset = dllm.data.load_sft_dataset(
|
||||
data_args.dataset_args,
|
||||
load_preprocessed_data=data_args.load_preprocessed_data,
|
||||
)
|
||||
if not data_args.load_preprocessed_data:
|
||||
map_fn = partial(
|
||||
dllm.utils.default_sft_map_fn,
|
||||
tokenizer=tokenizer,
|
||||
mask_prompt_loss=data_args.mask_prompt_loss,
|
||||
)
|
||||
dataset = dataset.map(
|
||||
map_fn,
|
||||
num_proc=data_args.num_proc,
|
||||
desc="Mapping dataset to SFT format",
|
||||
)
|
||||
# truncate / filter long sequences if needed
|
||||
dataset = dllm.utils.post_process_dataset(dataset, data_args)
|
||||
|
||||
# ----- Training --------------------------------------------------------------
|
||||
accelerate.PartialState().wait_for_everyone()
|
||||
logger.info("Start training...")
|
||||
trainer = dllm.core.trainers.MDLMTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset["train"],
|
||||
eval_dataset=dataset.get("test", None),
|
||||
args=training_args,
|
||||
data_collator=dllm.utils.NoAttentionMaskCollator(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
|
||||
trainer.processing_class.save_pretrained(
|
||||
os.path.join(training_args.output_dir, "checkpoint-final")
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
187
dllm/examples/dream/README.md
Normal file
@ -0,0 +1,187 @@
|
||||
# Dream
|
||||
|
||||
> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) | 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)
|
||||
|
||||
Resources and examples for training (finetuning & pretraining) and evaluating the **Dream** diffusion language models.
|
||||
|
||||
## Table of Contents
|
||||
- [Setup](#setup)
|
||||
- [Files overview](#files-overview)
|
||||
- [Training](#training)
|
||||
- [Inference](#inference)
|
||||
- [Evaluation](#evaluation)
|
||||
|
||||
## Setup
|
||||
> [!IMPORTANT]
|
||||
> **Slurm users:** Update `scripts/train.slurm.sh` and run `mkdir logs`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
|
||||
>
|
||||
|
||||
|
||||
## Files overview
|
||||
```
|
||||
# tools relevant with Dream
|
||||
dllm/pipelines/dream
|
||||
├── __init__.py # Package initialization
|
||||
├── models/
|
||||
│ ├── configuration_dream.py # Dream model configuration
|
||||
│ ├── generation_utils.py # Diffusion-based generation logic
|
||||
│ ├── modeling_dream.py # Core Dream model architecture
|
||||
│ └── tokenization_dream.py # Tokenizer implementation for Dream
|
||||
├── generator.py # Inference logic
|
||||
├── trainer.py # Training logic (pretraining and SFT)
|
||||
└── utils.py # Auxiliary utilities and helper functions
|
||||
|
||||
# example entry points for training / inference / evaluation
|
||||
examples/dream
|
||||
├── chat.py # Interactive inference example
|
||||
├── eval.sh # Automatic evaluation script
|
||||
├── generate.py # Inference example
|
||||
├── pt.py # Pretraining example
|
||||
├── README.md # Documentation (you are here)
|
||||
└── sft.py # Supervised finetuning example
|
||||
```
|
||||
<!-- > [!NOTE]
|
||||
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
|
||||
>
|
||||
> We fixed bugs in `chat_template` and standardized `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, remember to set `chat_template` and `mask_token` appropriately yourself. -->
|
||||
|
||||
## Training
|
||||
|
||||
### Finetuning
|
||||
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
|
||||
```shell
|
||||
accelerate launch \
|
||||
--config_file scripts/accelerate_configs/fsdp.yaml \
|
||||
examples/dream/sft.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
If you are using Slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
|
||||
```shell
|
||||
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
|
||||
--max_length 1024 \
|
||||
--num_train_epochs 4 \
|
||||
--learning_rate 2e-5
|
||||
```
|
||||
|
||||
<!-- **Reproducing [Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Base-7B)**. We tried our best to reproduce Dream-v0-Instruct-7B by finetuning Dream-v0-Base-7B using our training pipeline on the public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
|
||||
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
|
||||
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
|
||||
|
||||
```shell
|
||||
# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
|
||||
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
|
||||
--dataset_args "allenai/tulu-3-sft-mixture" \
|
||||
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--num_proc 64
|
||||
|
||||
# train on 24*8=192 A100s with FSDP; takes about 8 hours
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/sft.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "data/sft/dream/tulu-3-sft-mixture" \
|
||||
--load_preprocessed_data True \
|
||||
--output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
|
||||
--max_length 2048 \
|
||||
--truncation "right" \
|
||||
--group_by_length True \
|
||||
--num_train_epochs 5 \
|
||||
--learning_rate 1e-5 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--eval_on_start False \
|
||||
--eval_steps 0.1 \
|
||||
--save_steps 0.05
|
||||
```
|
||||
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
|
||||
|
||||
### Pretraining
|
||||
|
||||
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
|
||||
```shell
|
||||
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
|
||||
--accelerate_config "fsdp" \
|
||||
--script_path "examples/dream/pt.py" \
|
||||
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
|
||||
--dataset_args "mlfoundations/dclm-baseline-1.0" \
|
||||
--output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
|
||||
--max_length 1024 \
|
||||
--max_steps 2000 \
|
||||
--learning_rate 3e-4
|
||||
```
|
||||
|
||||
## Inference
|
||||
We support batch inference for standard generation and infilling:
|
||||
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
|
||||
```shell
|
||||
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
|
||||
```
|
||||
We also support interactive multi-turn dialogue with visualization:
|
||||
```shell
|
||||
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
|
||||
|
||||
For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
|
||||
```shell
|
||||
# Use model_args to adjust the generation arguments for evaluation.
|
||||
accelerate launch --num_processes 4 \
|
||||
dllm/pipelines/dream/eval.py \
|
||||
--tasks "mmlu_pro" \
|
||||
--model "dream" \
|
||||
--apply_chat_template \
|
||||
--num_fewshot 0 \
|
||||
--model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
|
||||
```
|
||||
|
||||
To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
|
||||
```shell
|
||||
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
|
||||
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
|
||||
```
|
||||
|
||||
### Evaluation results
|
||||
|
||||
> Results marked (evaluated) are obtained with our framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.
|
||||
|
||||
| | MMLU | BBH | ARC‑C | ARC‑E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip planning |
|
||||
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
|
||||
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
|
||||
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | – | – | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | – | 35.5 | 45.8 | – | 43.0 | – | – | – |
|
||||
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 1. Evaluation results of
|
||||
<a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;">
|
||||
<code>Dream-v0-Base-7B</code>
|
||||
</a>.
|
||||
</p>
|
||||
|
||||
| | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|
||||
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
|
||||
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
|
||||
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(evaluated) | – | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | – | 62.3 |
|
||||
|
||||
<p align="center" style="color: #808080; font-size: 0.9em;">
|
||||
Table 2. Evaluation results of
|
||||
<a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;">
|
||||
<code>Dream-v0-Instruct-7B</code>
|
||||
</a>.
|
||||
</p>
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.