1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@ -1,8 +1,18 @@
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import functional as F
from torch import Tensor
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad
class MLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout):
@ -49,6 +59,64 @@ class extendedMLP(nn.Module):
x = layer(x)
return x
class extendedMLP(nn.Module):
def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
super().__init__()
self.input_size = in_size
self.layers = nn.ModuleList()
if num_layers == 1:
# Only one layer
self.layers.append(nn.Linear(in_size, out_size))
return
elif num_layers > 1:
# First layer
self.layers.append(nn.Linear(in_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Intermediate layers
if num_layers > 2:
for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
self.layers.append(nn.Linear(hidden_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Last layer
self.layers.append(nn.Linear(hidden_size, out_size))
else:
raise ValueError("num_layers should be a positive integer")
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class SwiGLUFFN(nn.Module):
def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
super().__init__()
self.input_size = in_size
if num_layers == 1:
# Single-layer case: direct linear projection
self.ffn = nn.Linear(in_size, out_size)
elif num_layers == 2:
# Two-layer case: use SwiGLU
self.w1 = nn.Linear(in_size, 2 * hidden_size) # first half is the main branch, second half is the gate branch
self.w2 = nn.Linear(hidden_size, out_size)
self.dropout = nn.Dropout(dropout)
else:
raise ValueError("SwiGLU FFN 仅支持 num_layers=1 或 2")
def forward(self, x):
if hasattr(self, "ffn"):
return self.ffn(x)
else:
x_proj = self.w1(x)
x_main, x_gate = x_proj.chunk(2, dim=-1) # split into two halves
x = F.silu(x_main) * x_gate # SwiGLU: silu(a) * b
x = self.dropout(x)
x = self.w2(x)
return x
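# Usage sketch (illustrative, not part of this commit; sizes are assumed): with num_layers=2,
# w1 projects to 2*hidden_size, chunk() splits that into a main branch and a gate branch, and
# the output is w2(dropout(silu(main) * gate)).
#   ffn = SwiGLUFFN(in_size=512, out_size=512, num_layers=2, hidden_size=2048, dropout=0.1)
#   y = ffn(torch.randn(2, 16, 512))   # -> (2, 16, 512)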
class multiMLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
super().__init__()
@ -209,6 +277,28 @@ class TransformerLayer(nn.Module):
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
class TransformerLayerV2(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.self_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.cross_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4*dim, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
# self attention
attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')
input_dict['input_seq'] = attn_output
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
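# Usage sketch (illustrative; ResidualLayerNormModule is defined elsewhere in this module and the
# shapes are assumed): the layer consumes and returns the same dict, so several layers can be
# chained with nn.Sequential, which is how DiffusionDecoderV2 stacks them.
#   layer = TransformerLayerV2(dim=768, num_heads=12, dropout=0.1)
#   out = layer({'input_seq': x, 'memory': memory, 'memory_mask': ca_mask})
#   x_next = out['input_seq']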
class FeatureEnricher(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
@ -225,4 +315,483 @@ class FeatureEnricher(nn.Module):
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
return output_dict
return output_dict
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information from different representation subspaces.
.. note::
See `this tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
for an in depth discussion of the performant building blocks PyTorch offers for building your own
transformer layers.
Method described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Multi-Head Attention is defined as:
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O
where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
``nn.MultiheadAttention`` will use the optimized implementations of
``scaled_dot_product_attention()`` when possible.
In addition to support for the new ``scaled_dot_product_attention()``
function, for speeding up Inference, MHA will use
fastpath inference with support for Nested Tensors, iff:
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
- inputs are batched (3D) with ``batch_first==True``
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
- training is disabled (using ``.eval()``)
- ``add_bias_kv`` is ``False``
- ``add_zero_attn`` is ``False``
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
nor ``attn_mask`` is passed
- autocast is disabled
If the optimized inference fastpath implementation is in use, a
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
``query``/``key``/``value`` to represent padding more efficiently than using a
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
will be returned, and an additional speedup proportional to the fraction of the input
that is padding can be expected.
Args:
embed_dim: Total dimension of the model.
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Examples::
>>> # xdoctest: +SKIP
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
https://arxiv.org/abs/2205.14135
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
) -> None:
if embed_dim <= 0 or num_heads <= 0:
raise ValueError(
f"embed_dim and num_heads must be greater than 0,"
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.0)
constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if "_qkv_same_embed_dim" not in state:
state["_qkv_same_embed_dim"] = True
super().__setstate__(state)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> tuple[Tensor, Optional[Tensor]]:
r"""Compute attention outputs using query, key, and value embeddings.
Supports optional parameters for padding, masks and attention weights.
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and float masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
and achieve the best performance for MHA.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
If both attn_mask and key_padding_mask are supplied, their types should match.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
is_causal: If specified, applies a causal mask as attention mask.
Default: ``False``.
Warning:
``is_causal`` provides a hint that ``attn_mask`` is the
causal mask. Providing incorrect hints can result in
incorrect execution, including forward and backward
compatibility.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
""" # noqa: B950
why_not_fast_path = ""
if (
(attn_mask is not None and torch.is_floating_point(attn_mask))
or (key_padding_mask is not None)
and torch.is_floating_point(key_padding_mask)
):
why_not_fast_path = "floating-point masks are not supported for fast path."
is_batched = query.dim() == 3
key_padding_mask = F._canonical_mask(
mask=key_padding_mask,
mask_name="key_padding_mask",
other_type=F._none_or_dtype(attn_mask),
other_name="attn_mask",
target_type=query.dtype,
)
attn_mask = F._canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=query.dtype,
check_other=False,
)
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
if not is_fastpath_enabled:
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
elif not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif self.in_proj_weight is None:
why_not_fast_path = "in_proj_weight was None"
elif query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
elif self.training:
why_not_fast_path = "training is enabled"
elif (self.num_heads % 2) != 0:
why_not_fast_path = "self.num_heads is not even"
elif not self.batch_first:
why_not_fast_path = "batch_first was not True"
elif self.bias_k is not None:
why_not_fast_path = "self.bias_k was not None"
elif self.bias_v is not None:
why_not_fast_path = "self.bias_v was not None"
elif self.add_zero_attn:
why_not_fast_path = "add_zero_attn was enabled"
elif not self._qkv_same_embed_dim:
why_not_fast_path = "_qkv_same_embed_dim was not True"
elif query.is_nested and (
key_padding_mask is not None or attn_mask is not None
):
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
is not supported with NestedTensor input"
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"
if not why_not_fast_path:
tensor_args = (
query,
key,
value,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
)
# We have to use list comprehensions below because TorchScript does not support
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif _is_make_fx_tracing():
why_not_fast_path = "we are running make_fx tracing"
elif not all(_check_arg_device(x) for x in tensor_args):
why_not_fast_path = (
"some Tensor argument's device is neither one of "
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
)
elif torch.is_grad_enabled() and any(
_arg_requires_grad(x) for x in tensor_args
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
if not why_not_fast_path:
merged_mask, mask_type = self.merge_masks(
attn_mask, key_padding_mask, query
)
if self.in_proj_bias is not None and self.in_proj_weight is not None:
return torch._native_multi_head_attention(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj.weight,
self.out_proj.bias,
merged_mask,
need_weights,
average_attn_weights,
mask_type,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
assert not any_nested, (
"MultiheadAttention does not support NestedTensor outside of its fast path. "
+ f"The fast path was not hit because {why_not_fast_path}"
)
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = (x.transpose(1, 0) for x in (query, key))
value = key
else:
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
average_attn_weights=average_attn_weights,
is_causal=is_causal,
)
else:
attn_output, attn_output_weights = F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout,
self.out_proj.weight,
self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
average_attn_weights=average_attn_weights,
is_causal=is_causal,
)
if self.batch_first and is_batched:
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
def merge_masks(
self,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
query: Tensor,
) -> tuple[Optional[Tensor], Optional[int]]:
r"""Determine mask type and combine masks if necessary.
If only one mask is provided, that mask
and the corresponding mask type will be returned. If both masks are provided, they will be both
expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
and mask type 2 will be returned
Args:
attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
Returns:
merged_mask: merged mask
mask_type: merged mask type (0, 1, or 2)
"""
mask_type: Optional[int] = None
merged_mask: Optional[Tensor] = None
if key_padding_mask is not None:
mask_type = 1
merged_mask = key_padding_mask
if attn_mask is not None:
# In this branch query can't be a nested tensor, so it has a shape
batch_size, seq_len, _ = query.shape
mask_type = 2
# Always expands attn_mask to 4D
if attn_mask.dim() == 3:
attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
else: # attn_mask.dim() == 2:
attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
batch_size, self.num_heads, -1, -1
)
merged_mask = attn_mask_expanded
if key_padding_mask is not None:
key_padding_mask_expanded = key_padding_mask.view(
batch_size, 1, 1, seq_len
).expand(-1, self.num_heads, -1, -1)
merged_mask = attn_mask_expanded + key_padding_mask_expanded
# no attn_mask and no key_padding_mask, returns None, None
return merged_mask, mask_type
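# Usage sketch for merge_masks (illustrative; sizes are assumed): with a 2D float attn_mask and a
# float key_padding_mask, both are expanded and broadcast-added into a (batch, num_heads, L, S)
# mask, and mask_type 2 is returned.
#   mha = MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
#   q = torch.randn(2, 16, 512)
#   attn_mask = torch.zeros(16, 16)          # float mask, 0.0 = no change
#   key_padding_mask = torch.zeros(2, 16)    # float mask
#   merged, mask_type = mha.merge_masks(attn_mask, key_padding_mask, q)
#   # merged: (2, 8, 16, 16), mask_type: 2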

View File

@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass):
causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
self.register_buffer('causal_mask', causal_mask)
# self.transformer_decoder = Decoder(
# dim = dim,
# depth = sub_decoder_depth,
# heads = heads,
# attn_dropout = dropout,
# ff_dropout = dropout,
# attn_flash = True)
self.transformer_decoder = Decoder(
dim = dim,
dim = dim,
depth = sub_decoder_depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
attn_flash = True,
use_rmsnorm=True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass):
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
class DiffusionDecoder(SubDecoderClass):
class DiffusionDecoderV2(SubDecoderClass):
def __init__(
self,
prediction_order:list,
@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass):
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
@ -1138,14 +1149,21 @@ class DiffusionDecoder(SubDecoderClass):
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
vocab=vocab,
sub_decoder_depth=1,
dim=dim,
heads=heads,
dropout=dropout,
sub_decoder_enricher_use=False)
# simplified version of the forward process in diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass):
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
copy_input_dict = input_dict.copy()
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass):
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
if aux_ar: # inference with auxiliary auto-regressive decoder
aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
# print("aux_ar_logits", aux_ar_logits)
return aux_ar_logits, sampled_token_dict
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass):
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
# get aux ar decoder logits
aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
# return logits_dict, (masked_indices, p_mask)

View File

@ -3,6 +3,7 @@ import random
from pathlib import Path
from collections import OrderedDict
from typing import Union, List, Tuple, Dict
import torch
import numpy as np
import matplotlib.pyplot as plt
@ -157,6 +158,8 @@ class IterTuneCompiler(IterableDataset):
segment = self.augmentor(segment)
# use input_ids in place of tune_name
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
# print(segment.shape, mask.shape, tune_name.shape)
# segment = segment[torch.randperm(segment.size(0))]
yield segment, mask, tune_name, encoded_caption
def __len__(self):

View File

@ -1,7 +1,9 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
- nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: oct8_embSum_diff_t2m_300M_pretrainingv3
# - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
- nn_params: oct8_embSum_har_t2m_600M_pretrainingv3
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
@ -15,7 +17,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -23,28 +25,28 @@ captions_path: dataset/midicaps/train_set.json
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True: use diffusion in the sub-decoder
use_diff: False # True: use diffusion in the sub-decoder
diff_steps: 8 # number of diffusion steps
use_dispLoss: True
use_dispLoss: False
lambda_weight: 0.5
tau: 0.5
train_params:
device: cuda
batch_size: 10
batch_size: 9
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 3000 # number of iterations for validation process
input_length: 3072 # input sequence length
input_length: 2048 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.00001
initial_lr: 0.0003
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps

View File

@ -5,7 +5,7 @@ model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1080
num_layer: 13
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoderV2
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1080
num_layer: 13
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: SelfAttention
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1080
num_layer: 13
num_head: 12
sub_decoder:
decout_window_size: 3 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: oct
num_features: 8
vocab_name: MusicTokenVocabOct
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: SelfAttention
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1272
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

454
Amadeus/toy_train.py Normal file
View File

@ -0,0 +1,454 @@
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x: Tensor) -> Tensor:
dtype = x.dtype
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
x = self.weight * x
return x.to(dtype)
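# Note added for clarity: this is y = weight * x / sqrt(mean(x**2, dim=-1) + eps), i.e. RMS
# normalization without mean subtraction, computed in float32 inside the disabled autocast region
# and cast back to the input dtype at the end.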
class FeedForward(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: Optional[int] = None,
dropout: float = 0.,
hidden_act: Literal['silu', 'gelu'] = 'silu'):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = get_activation(hidden_act)
self.dropout = nn.Dropout(dropout)
def forward(self, x: Tensor) -> Tensor:
gate, hidden = self.up_proj(x).chunk(2, dim=-1)
gate = self.act_fn(gate)
return self.down_proj(self.dropout(gate * hidden))
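# Usage sketch (illustrative, assumed sizes): up_proj emits gate and hidden halves, the SwiGLU
# activation is silu(gate) * hidden, and down_proj maps back to hidden_size.
#   ffn = FeedForward(hidden_size=256)   # intermediate_size defaults to int(256 * 8 / 3)
#   y = ffn(torch.randn(2, 10, 256))     # -> (2, 10, 256)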
class MultiHeadAttention(nn.Module):
def __init__(self,
hidden_size: int,
head_dim: int = 64,
num_attention_heads: Optional[int] = None,
attention_dropout: float = 0.,
bias: bool = False,
rms_norm_eps: float = 1e-6,
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
):
super().__init__()
self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
self.head_dim = head_dim
self.scaling = head_dim ** -0.5
self.attention_dropout = attention_dropout
self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.o_proj = nn.Linear(
self.num_attention_heads * self.head_dim, hidden_size,
bias=bias)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
# To be compatible with huggingface
self.config = Qwen2Config()
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
key_states: Optional[Tensor] = None,
value_states: Optional[Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
q = self.q_proj(hidden_states)
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)
if(key_states is None or value_states is None):
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
value_states = v.view(hidden_shape).transpose(1, 2)
attn_output, attn_weights = self.attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
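# Usage sketch (illustrative, assumed sizes): with key_states/value_states omitted this is plain
# self-attention; for cross-attention the caller pre-computes key/value projections from the
# memory, as DecoderLayerWithCrossAttention does below.
#   mha = MultiHeadAttention(hidden_size=256, head_dim=64, num_attention_heads=4)
#   out, attn_w = mha(torch.randn(2, 10, 256), attention_mask=None)   # out: (2, 10, 256)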
class DecoderLayerWithCrossAttention(nn.Module):
def __init__(self,
hidden_size: int,
attn_head_dim: int = 64,
num_attention_heads: Optional[int] = None,
attention_dropout: float = 0.,
attn_bias: bool = False,
mlp_intermediate_size: Optional[int] = None,
mlp_dropout: float = 0.,
mlp_act: Literal['gelu', 'silu'] = 'silu',
rms_norm_eps: float = 1e-6,
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
):
super().__init__()
self.hidden_size = hidden_size
self.self_attn = MultiHeadAttention(
hidden_size,
attn_head_dim,
num_attention_heads,
attention_dropout,
attn_bias,
rms_norm_eps,
attn_implementation)
self.cross_attn = MultiHeadAttention(
hidden_size,
attn_head_dim,
num_attention_heads,
attention_dropout,
attn_bias,
rms_norm_eps,
attn_implementation)
self.mlp = FeedForward(
hidden_size,
mlp_intermediate_size,
mlp_dropout,
mlp_act)
self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_dict,
):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
hidden_states = input_dict['input_seq']
cross_attn_states = input_dict['memory']
cross_attn_mask = input_dict['memory_mask']
attention_mask = input_dict['memory_mask']
output_attentions = False
# Self Attention
residual = hidden_states
hidden_states = self.self_attn_layernorm(hidden_states)
hidden_states, dec_attn_weight = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
# Cross Attention
if(cross_attn_states is not None):
residual = hidden_states
hidden_states = self.cross_attn_layernorm(hidden_states)
cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
key_states = self.cross_attn.k_proj(cross_attn_states)
value_states = self.cross_attn.v_proj(cross_attn_states)
key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
value_states = value_states.view(cross_attn_shape).transpose(1, 2)
hidden_states, cross_attn_weight = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attn_mask,
key_states=key_states,
value_states=value_states
)
hidden_states = residual + hidden_states
# Feed Forward
residual = hidden_states
hidden_states = self.ffn_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
if(output_attentions):
if(cross_attn_states is not None):
# return (hidden_states, (dec_attn_weight, cross_attn_weight))
output_dict['dec_attn_weight'] = dec_attn_weight
output_dict['cross_attn_weight'] = cross_attn_weight
return output_dict
else:
# return (hidden_states, dec_attn_weight)
output_dict['dec_attn_weight'] = dec_attn_weight
return output_dict
else:
return output_dict
def sinusoidal_positional_embedding(
x: Tensor,
n: float = 10000.0) -> Tensor:
device, dtype = x.device, x.dtype
T = x.shape[-2]
D = x.shape[-1]
positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
embeddings = torch.zeros(T, D, device=device, dtype=dtype)
denominators = torch.pow(n, 2*torch.arange(0, D//2, device=device, dtype=dtype)/D) # 10000^(2i/d_model), i is the index of embedding
embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model))
embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model))
return x + embeddings
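# Usage sketch (illustrative): for x of shape (..., T, D) with even D, this adds the standard
# table PE[t, 2i] = sin(t / n^(2i/D)), PE[t, 2i+1] = cos(t / n^(2i/D)) to x.
#   x_pe = sinusoidal_positional_embedding(torch.zeros(4, 128))   # row t is the embedding of position t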
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Build a simple dataset to test convergence
class SimpleSeq2SeqDataset(Dataset):
def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
self.num_samples = num_samples
self.seq_len = seq_len
self.vocab_size = vocab_size
self.hidden_size = hidden_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Simulate encoder output (memory)
memory = torch.randn(self.seq_len, self.hidden_size)
# Simulate decoder input and target
decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
target = torch.randint(0, self.vocab_size, (self.seq_len,))
# Build a simple attention mask
memory_mask = torch.ones(self.seq_len, self.seq_len)
return {
'memory': memory,
'decoder_input': decoder_input,
'target': target,
'memory_mask': memory_mask
}
# Build a simple model
class SimpleDecoderModel(nn.Module):
def __init__(self, vocab_size, hidden_size=256, num_layers=3):
super(SimpleDecoderModel, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
# Embedding layer
self.embedding = nn.Embedding(vocab_size, hidden_size)
# Positional encoding
self.pos_encoding = sinusoidal_positional_embedding
# Multiple decoder layers
self.layers = nn.ModuleList([
DecoderLayerWithCrossAttention(
hidden_size=hidden_size,
attn_head_dim=64,
num_attention_heads=2,
attention_dropout=0.1,
attn_bias=False,
mlp_intermediate_size=512,
mlp_dropout=0.1,
mlp_act='silu',
rms_norm_eps=1e-6,
attn_implementation="sdpa"
) for _ in range(num_layers)
])
# Output layer
self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
self.output_proj = nn.Linear(hidden_size, vocab_size)
def forward(self, decoder_input, memory, memory_mask):
# Embedding + positional encoding
x = self.embedding(decoder_input)
x = self.pos_encoding(x)
# Pass through the decoder layers
input_dict = {
'input_seq': x,
'memory': memory,
'memory_mask': memory_mask
}
for layer in self.layers:
output_dict = layer(input_dict)
input_dict['input_seq'] = output_dict['input_seq']
# Final output
hidden_states = self.output_norm(output_dict['input_seq'])
logits = self.output_proj(hidden_states)
return logits
# Training function
def train_decoder_model():
# Hyperparameters
vocab_size = 100
hidden_size = 256
seq_len = 10
batch_size = 32
num_epochs = 10
learning_rate = 0.001
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Build the dataset and dataloader
dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
vocab_size=vocab_size, hidden_size=hidden_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize the model
model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
model.to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training loss history
train_losses = []
print("Starting training...")
model.train()
for epoch in range(num_epochs):
epoch_loss = 0.0
num_batches = 0
for batch in dataloader:
# Move to device
memory = batch['memory'].to(device)
decoder_input = batch['decoder_input'].long().to(device)
target = batch['target'].long().to(device)
memory_mask = batch['memory_mask'].to(device)
# Forward pass
optimizer.zero_grad()
logits = model(decoder_input, memory, memory_mask)
# Compute the loss - reshape logits to (batch_size * seq_len, vocab_size) and target to (batch_size * seq_len)
loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))
# Backward pass
loss.backward()
optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_loss = epoch_loss / num_batches
train_losses.append(avg_loss)
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
# Early-stopping check: if the loss drops clearly, the model is converging
if epoch >= 2 and avg_loss < 4.0: # a reasonable threshold for cross-entropy loss
print("✅ The model is converging normally!")
return True
# Check the final loss
final_loss = train_losses[-1]
if final_loss < 5.0:
print("✅ Training finished; the model converged normally!")
return True
else:
print("❌ The loss did not drop noticeably; the model or hyperparameters may need adjusting")
return False
# Run training
if __name__ == "__main__":
# First test a single decoder layer's forward pass
print("Testing DecoderLayerWithCrossAttention forward pass...")
# Build test data
batch_size, seq_len, hidden_size = 2, 64, 256
decoder_input = torch.randn(seq_len, hidden_size)
memory = torch.randn(seq_len, hidden_size)
memory_mask = torch.ones(seq_len, seq_len)
# Test a single layer
decoder_layer = DecoderLayerWithCrossAttention(
hidden_size=hidden_size,
attn_head_dim=64,
num_attention_heads=4
)
input_dict = {
'input_seq': decoder_input,
'memory': memory,
'memory_mask': memory_mask
}
try:
output = decoder_layer(input_dict)
print("✅ DecoderLayer前向传播测试通过")
print(f"输入形状: {decoder_input.shape}")
print(f"输出形状: {output['input_seq'].shape}")
# 运行完整训练
success = train_decoder_model()
if success:
print("\n🎉 测试完成DecoderLayerWithCrossAttention能够正常训练和收敛")
else:
print("\n⚠️ 训练过程中可能存在问题,建议检查模型结构或超参数")
except Exception as e:
print(f"❌ 前向传播测试失败: {e}")

View File

@ -228,19 +228,39 @@ class DiffusionLoss4CompoundToken():
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
def get_aux_ar_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
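# Shape note added for clarity (not in the original): logits is (B, T, V) for one feature, target
# and mask are (B, T); the focal-style weight (1 - pt)**self.gamma reduces to plain NLL when
# self.gamma == 0, and the loss is averaged over masked (valid) positions only.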
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
aux_ar_logits = None
# print(len(logits_dict))
if len(logits_dict) == 2: # has aux ar loss
logits_dict, aux_ar_logits = logits_dict
if input_dict is not None:
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
feat = hidden_vec.mean(dim=1) #bs,dim
disp_loss = dispersive_loss(feat, tau=tau) # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
if aux_ar_logits is not None:
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
train_loss_list.append(training_loss)
if valid:
if key == 'type' or key == 'timesig':

View File

@ -74,8 +74,8 @@ class LanguageModelTrainer:
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
# model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
# model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt", # Path to a pre-trained model checkpoint (optional)
model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
):
# Save model, optimizer, and other configurations
self.model = model
@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
segment, mask, caption,encoded_caption = batch
input_seq, target = segment[:, :-1], segment[:, 1:]
total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
try:
aux_ar_logits, logits_dict = logits_dict
except:
logits_dict = logits_dict
probs_dict = {key:torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
num_nonmask_tokens = torch.sum(mask)
input_seq = input_seq.to(self.device)

View File

@ -729,11 +729,6 @@ class XtransformerNewPretrainingDecoder(nn.Module):
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,

View File

@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
from itertools import groupby, chain
from random import shuffle, seed
lock = RLock()
def shuffled(seq):
shuffle(seq)
return seq
def permute_inside_and_across_tracks(seq):
seq_sorted = sorted(seq, key=lambda x: x[5])
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
def convert_event_dicts(dict_list):
"""
Convert the list of event vocab dicts, in order, into a structured output
@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
return result
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
try:
score = Score(midi_path, ttype="tick")
@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
# Tokenize
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
if whether_shuffle:
token_ids = permute_inside_and_across_tracks(token_ids)
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
print(f" × 处理文件时出错:{midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
# === 1. Initialize the tokenizer and save the vocabulary ===
print("Initializing the Octuple tokenizer...")
config = TokenizerConfig(
@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
os.makedirs(output_dir, exist_ok=True)
# === 3. Collect MIDI files ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
# midi_paths = list(midi_paths)
midi_paths = []
for root, _, files in os.walk(midi_dir):
for file in files:
if file.lower().endswith(('.mid', '.midi')):
midi_paths.append(os.path.join(root, file))
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
# === 4. 并行处理 ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
if __name__ == "__main__":
import argparse
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录")
parser.add_argument("--shuffle", action="store_true", help="是否在处理前打乱文件顺序")
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)
args = parser.parse_args()
preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)

View File

@ -0,0 +1,39 @@
from itertools import groupby, chain
from random import shuffle, seed
def shuffled(seq):
shuffle(seq)
return seq
def permute_inside_and_across_tracks(seq):
seq_sorted = sorted(seq, key=lambda x: x[5])
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
# (PitchDrum, Position, Bar, Velocity, Duration, Program, Tempo, TimeSignature)
seq = [
# Program 0
(60, 0, 0, 90, 96, 0, 120, 16),
(64, 48,0, 88, 96, 0, 120, 16),
(67, 96,0, 92, 96, 0, 120, 16),
# Program 32
(40, 0, 0, 80, 192, 32, 120, 16),
(43, 0, 1, 78, 192, 32, 120, 16),
# Program 40
(72, 24,0, 85, 72, 40, 120, 16),
(74, 72,0, 83, 72, 40, 120, 16),
(76, 24,1, 86, 72, 40, 120, 16),
]
# seed(42)
inside_track_permuted_and_track_permuted = permute_inside_and_across_tracks(seq)
print("原始 seq")
for e in seq:
print(e)
print("\n打乱后的 seq")
for e in inside_track_permuted_and_track_permuted:
print(e)

View File

@ -0,0 +1,634 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script based on token-distribution distance.
Reads octuple_token_analysis_report.json and computes each sample's distance to the overall distribution.
Sampling is weighted by that distance: the farther a sample is from the overall distribution, the more likely it is to be sampled.
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from scipy.stats import entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Column-name definitions for Octuple
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
def load_distribution_from_json(json_path):
"""
Load the overall token distribution from a JSON file
Args:
json_path: path to the JSON file
Returns:
dict: {column_name: {token: probability}}
"""
print(f"Reading distribution file: {json_path}")
with open(json_path, 'r', encoding='utf-8') as f:
report = json.load(f)
distributions = {}
columns = report.get('columns', {})
for col_name in COLUMN_NAMES:
if col_name not in columns:
print(f"警告: 列 {col_name} 不在报告中")
distributions[col_name] = {}
continue
col_data = columns[col_name]
token_counts = col_data.get('token_counts', {})
total_tokens = col_data.get('total_tokens', 1)
# Convert counts to a probability distribution
distribution = {}
for token_str, count in token_counts.items():
token = int(token_str)
distribution[token] = count / total_tokens
distributions[col_name] = distribution
print(f"{col_name}: {len(distribution)} 个唯一token, 总token数: {total_tokens:,}")
return distributions
def compute_data_distribution(data, col_idx):
"""
Compute the token distribution of a single sample for the given column
Args:
data: numpy array (num_tokens, num_columns)
col_idx: column index
Returns:
dict: {token: probability}
"""
if data.size == 0:
return {}
tokens = data[:, col_idx]
unique, counts = np.unique(tokens, return_counts=True)
total = len(tokens)
distribution = {}
for token, count in zip(unique, counts):
distribution[int(token)] = count / total
return distribution
def compute_emd_distance(dist1, dist2):
"""
Compute the distance between two distributions using Earth Mover's Distance (EMD, the 1-Wasserstein distance)
Args:
dist1: distribution 1, dict {token: probability}, already normalized
dist2: distribution 2, dict {token: probability}, already normalized
Returns:
float: EMD distance
"""
# Collect the union of all tokens and sort them
all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
if not all_tokens:
return 0.0
# Build the probability vectors and the token-value vector
p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
token_values = np.array(all_tokens, dtype=float)
# Normalize (guard against numerical error)
p_sum = p_weights.sum()
q_sum = q_weights.sum()
if p_sum < 1e-10 or q_sum < 1e-10:
return 0.0
p_weights = p_weights / p_sum
q_weights = q_weights / q_sum
# Use the Wasserstein distance (the 1-Wasserstein distance is the Earth Mover's Distance)
# wasserstein_distance takes the value positions and weights of the two distributions
# For discrete distributions we use the token values as the positions
emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
return emd
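# Worked example (illustrative): for dist1 = {60: 0.5, 64: 0.5} and dist2 = {60: 0.25, 64: 0.75},
# the CDFs differ by 0.25 over the interval [60, 64), so the EMD is 0.25 * 4 = 1.0, which is what
# wasserstein_distance([60, 64], [60, 64], [0.5, 0.5], [0.25, 0.75]) returns.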
def compute_distribution_distance(dist1, dist2, method='emd'):
"""
Compute the distance between two distributions
Args:
dist1: distribution 1, dict {token: probability}
dist2: distribution 2, dict {token: probability}
method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)
Returns:
float: distribution distance
"""
if method == 'emd':
return compute_emd_distance(dist1, dist2)
# Collect the union of all tokens
all_tokens = set(dist1.keys()) | set(dist2.keys())
if not all_tokens:
return 0.0
# Build the probability vectors
p = np.array([dist1.get(token, 0.0) for token in all_tokens])
q = np.array([dist2.get(token, 0.0) for token in all_tokens])
# Normalize (guard against numerical error)
p = p / (p.sum() + 1e-10)
q = q / (q.sum() + 1e-10)
if method == 'js':
# Jensen-Shannon divergence: symmetric, range [0, 1]
return jensenshannon(p, q)
elif method == 'kl':
# KL divergence: asymmetric, zero values need handling
# Add a small smoothing term to avoid log(0)
p = p + 1e-10
q = q + 1e-10
p = p / p.sum()
q = q / q.sum()
return entropy(p, q)
else:
raise ValueError(f"未知的距离方法: {method}")
def extract_subdistribution(global_dist, data_tokens):
"""
Extract from the global distribution the sub-distribution over the tokens that occur in the sample, and normalize it.
Args:
global_dist: global distribution, dict {token: probability}
data_tokens: set (or list) of tokens that occur in the sample
Returns:
dict: sub-distribution {token: probability}, normalized
"""
if not data_tokens or not global_dist:
return {}
# 提取子分布
sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
# 归一化
total = sum(sub_dist.values())
if total < 1e-10:
return {}
normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
return normalized_sub_dist
def compute_data_distance(data, global_distributions, method='emd'):
"""
Compute the distance between a single sample and the overall distribution.
For each song, the dataset-wide distribution is restricted to the tokens that appear in the song,
both distributions are normalized, and the Earth Mover's Distance is computed per column.
Args:
data: numpy array (num_tokens, num_columns), or a file path (when loading lazily)
global_distributions: overall distribution, dict {column_name: {token: probability}}
method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)
Returns:
float: mean distance across all columns
"""
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
try:
data = np.load(data)['arr_0']
except Exception as e:
# 不打印错误,让调用者处理
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
distances = []
for col_idx, col_name in enumerate(COLUMN_NAMES):
# 计算该数据在该列的分布
data_dist = compute_data_distribution(data, col_idx)
# 获取整体分布
global_dist = global_distributions.get(col_name, {})
if not data_dist or not global_dist:
continue
# 从全局分布中提取只包含数据中出现的token的子分布
data_tokens = set(data_dist.keys())
sub_global_dist = extract_subdistribution(global_dist, data_tokens)
if not sub_global_dist:
continue
# 归一化数据分布
data_dist_sum = sum(data_dist.values())
if data_dist_sum < 1e-10:
continue
normalized_data_dist = {token: prob / data_dist_sum
for token, prob in data_dist.items()}
# 计算距离(两个分布都已归一化)
dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
distances.append(dist)
# 返回平均距离
return np.mean(distances) if distances else 0.0
def _load_single_file(npz_file):
"""
加载单个npz文件的辅助函数用于多线程
Args:
npz_file: npz文件路径
Returns:
tuple: (data, file_path) 或 None如果加载失败
"""
try:
data = np.load(npz_file)['arr_0']
if data.ndim == 2:
return (data, npz_file)
elif data.ndim == 1:
print(f"警告: {npz_file} 是一维数组,跳过")
return None
except Exception as e:
print(f"错误: 加载 {npz_file} 时出错: {e}")
return None
def get_data_file_paths(data_dir):
"""
获取所有数据文件路径(不加载数据)
Args:
data_dir: 数据目录路径
Returns:
list: 文件路径列表
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return []
print(f"找到 {len(npz_files)} 个.npz文件")
return npz_files
def load_data_with_paths(data_dir, num_threads=None, lazy=False):
"""
加载所有数据并返回数据路径列表(多线程版本)
Args:
data_dir: 数据目录路径
num_threads: 线程数None表示使用CPU核心数
lazy: 如果为True只返回文件路径不加载数据
Returns:
tuple: (data_list, file_paths_list) 或 (None, file_paths_list) 如果lazy=True
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return [], []
if lazy:
print(f"找到 {len(npz_files)} 个.npz文件延迟加载模式")
return None, npz_files
print(f"找到 {len(npz_files)} 个.npz文件开始加载...")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(npz_files))
all_data = []
file_paths = []
# 使用多线程加载文件
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(_load_single_file, npz_file): npz_file
for npz_file in npz_files}
for future in tqdm(as_completed(futures), total=len(futures), desc="加载数据"):
result = future.result()
if result is not None:
data, file_path = result
all_data.append(data)
file_paths.append(file_path)
# 保持原始顺序
if file_paths:
sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
file_paths, all_data = zip(*sorted_pairs)
file_paths = list(file_paths)
all_data = list(all_data)
return all_data, file_paths
def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
"""
根据距离进行加权重采样
Args:
file_paths: 文件路径列表
distances: 距离列表
sample_ratio: 采样比例
method: 距离计算方法(用于确定权重方向)
lazy: 如果为True返回文件路径而不是数据
Returns:
tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
"""
n_samples = int(len(file_paths) * sample_ratio)
print(f"\n准备采样 {n_samples} 个数据 (占总数的 {sample_ratio*100:.1f}%)")
# Convert distances to sampling weights: the farther from the overall distribution, the larger the weight
distances = np.array(distances)
# Guard against zero distances and outliers
min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
distances = np.maximum(distances, min_dist * 0.1)
# Normalize distances to [0, 1], then map them to weights
# An exponential amplifies the differences so that distant samples are picked more often
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
weights = np.exp(normalized_distances * 3) # exponential amplification
# Normalize weights into a probability vector
weights = weights / weights.sum()
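# For the exponent of 3 used above, the most distant sample receives about exp(3) ≈ 20x the
# pre-normalization weight of the closest one.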
# Weighted random sampling (without replacement)
indices = np.arange(len(file_paths))
sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
sampled_paths = [file_paths[i] for i in sampled_indices]
# 如果lazy=True返回路径否则加载数据
if lazy:
sampled_data = sampled_paths # 返回路径,延迟加载
else:
# 加载采样后的数据
sampled_data = []
for path in tqdm(sampled_paths, desc="加载采样数据"):
try:
data = np.load(path)['arr_0']
sampled_data.append(data)
except Exception as e:
print(f"错误: 加载 {path} 时出错: {e}")
sampled_data.append(None)
print(f"采样完成:")
print(f" 采样数据数量: {len(sampled_paths)}")
print(f" 平均距离: {distances[sampled_indices].mean():.6f}")
print(f" 最小距离: {distances[sampled_indices].min():.6f}")
print(f" 最大距离: {distances[sampled_indices].max():.6f}")
return sampled_data, sampled_paths, sampled_indices
def _save_single_file(args_tuple):
"""
保存单个文件的辅助函数(用于多线程)
支持延迟加载如果data是路径则从文件加载
Args:
args_tuple: (data, original_path, output_dir)
Returns:
tuple: (success, original_path) 或 (False, original_path, error_msg)
"""
data, original_path, output_dir = args_tuple
try:
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
data = np.load(data)['arr_0']
# Preserve the relative path structure under the output directory
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
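# Note: parents[len(parts) - 3] keeps only the first two parts of original_path (assumes the path has
# at least three parts), so relative_path drops those two leading components and the remaining
# subpath is recreated under output_dir.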
output_path = output_dir / relative_path
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, data)
return (True, original_path)
except Exception as e:
error_msg = str(e)
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
"""
保存采样后的数据(多线程版本)
Args:
sampled_data: 采样后的数据列表
sampled_paths: 采样后的文件路径列表
output_dir: 输出目录
num_threads: 线程数None表示使用CPU核心数
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n保存采样数据到: {output_dir}")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(sampled_data))
# 准备参数
save_args = [(data, original_path, output_dir)
for data, original_path in zip(sampled_data, sampled_paths)]
# 使用多线程保存文件
success_count = 0
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# 提交所有任务
futures = [executor.submit(_save_single_file, args)
for args in save_args]
# 收集结果
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
try:
result = future.result(timeout=300) # 设置超时避免卡死
if isinstance(result, tuple) and len(result) >= 2:
success = result[0]
if success:
success_count += 1
except Exception as e:
print(f"错误: 获取保存结果时出错: {e}")
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description="Resampling based on token-distribution distance")
parser.add_argument("--json_path", type=str,
default="octuple_token_analysis_report.json",
help="Path to the token-analysis report JSON file")
parser.add_argument("--data_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
help="Path to the data directory")
parser.add_argument("--output_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
help="Path to the output directory")
parser.add_argument("--sample_ratio", type=float, default=0.3,
help="Sampling ratio (default: 0.3)")
parser.add_argument("--distance_method", type=str, default="emd",
choices=["emd", "js", "kl"],
help="Distance metric: 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--num_threads", type=int, default=None,
help="Number of threads; None means use the CPU core count (default: None)")
args = parser.parse_args()
# 设置随机种子
np.random.seed(args.seed)
# 1. 加载整体分布
global_distributions = load_distribution_from_json(args.json_path)
# 2. 获取所有数据文件路径(延迟加载模式,避免一次性加载所有数据)
_, file_paths = load_data_with_paths(args.data_dir, lazy=True)
if not file_paths:
print("错误: 未找到任何数据文件")
return
print(f"\n共找到 {len(file_paths)} 个数据文件")
# 3. 计算每个数据与整体分布的距离(多线程版本,延迟加载)
print("\n计算每个数据与整体分布的距离(延迟加载模式)...")
def _compute_distance_wrapper(args_tuple):
"""计算距离的包装函数(用于多线程,支持延迟加载)"""
idx, file_path, global_dists, method = args_tuple
try:
distance = compute_data_distance(file_path, global_dists, method=method)
return (idx, distance, None)
except Exception as e:
return (idx, 0.0, str(e))
if args.num_threads is None:
num_threads = min(mp.cpu_count(), len(file_paths))
else:
num_threads = args.num_threads
# 准备参数(使用文件路径而不是数据,包含索引)
distance_args = [(i, file_path, global_distributions, args.distance_method)
for i, file_path in enumerate(file_paths)]
# 使用多线程计算距离(按需加载数据)
# 初始化结果列表,保持顺序
distances = [0.0] * len(file_paths)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# 提交所有任务
futures = [executor.submit(_compute_distance_wrapper, args)
for args in distance_args]
# 收集结果,使用 tqdm 显示进度
for future in tqdm(as_completed(futures), total=len(futures), desc="计算距离"):
try:
idx, distance, error = future.result(timeout=300) # 设置超时避免卡死
distances[idx] = distance
if error:
print(f"警告: 计算距离时出错 (索引 {idx}): {error}")
except Exception as e:
print(f"错误: 获取结果时出错: {e}")
# 如果无法获取结果,保持默认值 0.0
distances = np.array(distances)
print(f"\n距离统计:")
print(f" 平均距离: {distances.mean():.6f}")
print(f" 最小距离: {distances.min():.6f}")
print(f" 最大距离: {distances.max():.6f}")
print(f" 标准差: {distances.std():.6f}")
# 4. 根据距离进行加权采样(延迟加载模式)
sampled_data, sampled_paths, sampled_indices = weighted_resample(
file_paths, distances,
sample_ratio=args.sample_ratio,
method=args.distance_method,
lazy=True # 使用延迟加载,避免重复加载数据
)
# 5. 保存采样结果(多线程,延迟加载)
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
# 6. 保存采样索引(可选,用于后续分析)
indices_file = Path(args.output_dir) / "sampled_indices.npy"
np.save(indices_file, sampled_indices)
print(f"\n采样索引已保存到: {indices_file}")
# 保存采样信息
info = {
"total_samples": len(file_paths),
"sampled_samples": len(sampled_data),
"sample_ratio": args.sample_ratio,
"distance_method": args.distance_method,
"distance_stats": {
"mean": float(distances.mean()),
"min": float(distances.min()),
"max": float(distances.max()),
"std": float(distances.std())
},
"sampled_distance_stats": {
"mean": float(distances[sampled_indices].mean()),
"min": float(distances[sampled_indices].min()),
"max": float(distances[sampled_indices].max()),
"std": float(distances[sampled_indices].std())
}
}
info_file = Path(args.output_dir) / "resample_info.json"
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(info, f, indent=2, ensure_ascii=False)
print(f"采样信息已保存到: {info_file}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,472 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script V2, based on position and duration tokens.
For each song:
1. If its position or duration tokens do not overlap the dataset-wide top-3 tokens, sample it with a high probability (not_contain_sample_ratio).
2. If they do, sample it at a fixed, lower ratio (contain_sample_ratio).
3. Meeting either condition is enough for the song to be kept.
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Octuple的列名定义
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
def load_top_tokens_from_json(json_path, column_name, top_k=3):
"""
从JSON文件中加载指定列的前top_k个最常见的token
Args:
json_path: JSON文件路径
column_name: 列名(如'position''duration'
top_k: 返回前k个最常见的token
Returns:
set: 前top_k个最常见的token集合
"""
print(f"读取分布文件: {json_path}")
with open(json_path, 'r', encoding='utf-8') as f:
report = json.load(f)
columns = report.get('columns', {})
if column_name not in columns:
print(f"警告: 列 {column_name} 不在报告中")
return set()
col_data = columns[column_name]
token_counts = col_data.get('token_counts', {})
# 按出现次数排序获取前top_k个
sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True)
top_tokens = {int(token_str) for token_str, _ in sorted_tokens[:top_k]}
print(f"{column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}")
return top_tokens
def get_data_tokens(data, col_idx):
"""
获取单个数据在指定列的所有唯一token
Args:
data: numpy数组 (num_tokens, num_columns) 或文件路径
col_idx: 列索引
Returns:
set: 唯一token集合
"""
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
try:
data = np.load(data)['arr_0']
except Exception as e:
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
if data.size == 0:
return set()
tokens = data[:, col_idx]
unique_tokens = set(int(token) for token in np.unique(tokens))
return unique_tokens
def should_sample_song(data, top_position_tokens, top_duration_tokens,
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
"""
判断一首歌是否应该被采样
Args:
data: numpy数组 (num_tokens, num_columns) 或文件路径
top_position_tokens: position列的前3个最常见token集合
top_duration_tokens: duration列的前3个最常见token集合
contain_sample_ratio: 对于包含前3个token的歌曲采样比例
not_contain_sample_ratio: 对于不包含前3个token的歌曲采样比例更高概率
rng: 随机数生成器如果为None则使用全局的np.random
Returns:
tuple: (是否应该采样, 是否在前3个) - 在前3个指position和duration都在前3个
"""
# 获取position和duration列的唯一token
position_idx = COLUMN_NAMES.index("position")
duration_idx = COLUMN_NAMES.index("duration")
position_tokens = get_data_tokens(data, position_idx)
duration_tokens = get_data_tokens(data, duration_idx)
# Check whether the song overlaps the top-3 token sets
position_in_top3 = bool(position_tokens & top_position_tokens)
duration_in_top3 = bool(duration_tokens & top_duration_tokens)
in_top3 = position_in_top3 and duration_in_top3
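# Note: "in top 3" here means the song contains at least one of the top-3 tokens for both
# position and duration, not that all of its tokens are among the top 3.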
if rng is None:
rng = np.random
# Condition 1: if either position or duration has no top-3 token, sample with the higher probability
if not position_in_top3 or not duration_in_top3:
should_sample = rng.random() < not_contain_sample_ratio
return should_sample, False
# Condition 2: otherwise (contains top-3 tokens), sample at the fixed, lower ratio
should_sample = rng.random() < contain_sample_ratio
return should_sample, True
def _load_single_file(npz_file):
"""
加载单个npz文件的辅助函数用于多线程
Args:
npz_file: npz文件路径
Returns:
tuple: (data, file_path) 或 None如果加载失败
"""
try:
data = np.load(npz_file)['arr_0']
if data.ndim == 2:
return (data, npz_file)
elif data.ndim == 1:
print(f"警告: {npz_file} 是一维数组,跳过")
return None
except Exception as e:
print(f"错误: 加载 {npz_file} 时出错: {e}")
return None
def get_data_file_paths(data_dir):
"""
获取所有数据文件路径(不加载数据)
Args:
data_dir: 数据目录路径
Returns:
list: 文件路径列表
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return []
print(f"找到 {len(npz_files)} 个.npz文件")
return npz_files
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
num_threads=None, seed=42):
"""
根据新逻辑进行重采样
Args:
file_paths: 文件路径列表
top_position_tokens: position列的前3个最常见token集合
top_duration_tokens: duration列的前3个最常见token集合
contain_sample_ratio: 对于包含前3个token的歌曲采样比例
not_contain_sample_ratio: 对于不包含前3个token的歌曲采样比例更高概率
num_threads: 线程数None表示使用CPU核心数
seed: 随机种子
Returns:
tuple: (sampled_paths, sampled_indices, stats)
"""
import threading
# 为每个线程创建独立的随机数生成器
thread_local = threading.local()
def get_thread_rng():
"""获取当前线程的随机数生成器"""
if not hasattr(thread_local, 'rng'):
# 使用线程ID和种子创建独立的随机数生成器
thread_id = threading.current_thread().ident
thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
return thread_local.rng
if num_threads is None:
num_threads = min(mp.cpu_count(), len(file_paths))
print(f"\n开始重采样,使用 {num_threads} 个线程...")
print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%")
print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%")
def _should_sample_wrapper(args_tuple):
"""判断是否采样的包装函数(用于多线程)"""
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
try:
# 使用线程本地的随机数生成器
thread_rng = get_thread_rng()
should_sample, in_top3 = should_sample_song(
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
)
return (file_path, should_sample, in_top3, None)
except Exception as e:
return (file_path, False, False, str(e))
# 准备参数
sample_args = [(file_path, top_position_tokens, top_duration_tokens,
contain_sample_ratio, not_contain_sample_ratio)
for file_path in file_paths]
# 使用多线程判断每首歌是否应该采样
sampled_paths = []
sampled_indices = []
stats = {
'not_in_top3_count': 0, # 不在前3个的歌曲数量
'not_in_top3_sampled': 0, # 不在前3个且被采样的歌曲数量
'in_top3_count': 0, # 在前3个的歌曲数量
'in_top3_sampled': 0 # 在前3个且被采样的歌曲数量
}
# 限制并发任务数量,避免一次性提交过多任务
batch_size = min(1000, len(file_paths))
results = {}
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# 分批提交任务
for batch_start in range(0, len(sample_args), batch_size):
batch_end = min(batch_start + batch_size, len(sample_args))
batch_args = sample_args[batch_start:batch_end]
futures = {executor.submit(_should_sample_wrapper, args): args[0]
for args in batch_args}
# 收集结果
for future in tqdm(as_completed(futures), total=len(futures),
desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]",
leave=False):
try:
file_path, should_sample, in_top3, error = future.result(timeout=60)
results[file_path] = (should_sample, in_top3, error)
if error:
print(f"警告: 处理 {file_path} 时出错: {error}")
except Exception as e:
print(f"错误: 获取结果时出错: {e}")
# 按原始顺序处理结果,并统计
for idx, file_path in enumerate(file_paths):
if file_path not in results:
continue
should_sample, in_top3, error = results[file_path]
if error:
continue
# 统计信息
if in_top3:
stats['in_top3_count'] += 1
if should_sample:
stats['in_top3_sampled'] += 1
else:
stats['not_in_top3_count'] += 1
if should_sample:
stats['not_in_top3_sampled'] += 1
if should_sample:
sampled_paths.append(file_path)
sampled_indices.append(idx)
print(f"\n采样完成:")
print(f" 总歌曲数: {len(file_paths)}")
print(f" 采样歌曲数: {len(sampled_paths)}")
print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%")
print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}")
print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}")
if stats['not_in_top3_count'] > 0:
print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
print(f" 在前3个的歌曲数: {stats['in_top3_count']}")
print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}")
if stats['in_top3_count'] > 0:
print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")
return sampled_paths, sampled_indices, stats
def _save_single_file(args_tuple):
"""
保存单个文件的辅助函数(用于多线程)
支持延迟加载如果data是路径则从文件加载
Args:
args_tuple: (data, original_path, output_dir)
Returns:
tuple: (success, original_path) 或 (False, original_path, error_msg)
"""
data, original_path, output_dir = args_tuple
try:
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
data = np.load(data)['arr_0']
# Preserve the relative path structure under the output directory
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
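# Note: parents[len(parts) - 3] keeps only the first two parts of original_path (assumes the path has
# at least three parts), so relative_path drops those two leading components and the remaining
# subpath is recreated under output_dir.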
output_path = output_dir / relative_path
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, data)
return (True, original_path)
except Exception as e:
error_msg = str(e)
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
"""
保存采样后的数据(多线程版本)
Args:
sampled_data: 采样后的数据列表
sampled_paths: 采样后的文件路径列表
output_dir: 输出目录
num_threads: 线程数None表示使用CPU核心数
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n保存采样数据到: {output_dir}")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(sampled_data))
# 准备参数
save_args = [(data, original_path, output_dir)
for data, original_path in zip(sampled_data, sampled_paths)]
# 使用多线程保存文件
success_count = 0
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# 提交所有任务
futures = [executor.submit(_save_single_file, args)
for args in save_args]
# 收集结果
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
try:
result = future.result(timeout=300) # 设置超时避免卡死
if isinstance(result, tuple) and len(result) >= 2:
success = result[0]
if success:
success_count += 1
except Exception as e:
print(f"错误: 获取保存结果时出错: {e}")
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2")
parser.add_argument("--json_path", type=str,
default="octuple_token_analysis_report.json",
help="token分析报告JSON文件路径")
parser.add_argument("--data_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
help="数据目录路径")
parser.add_argument("--output_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
help="输出目录路径")
parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
help="对于包含前3个token的歌曲采样比例 (默认: 0.1)")
parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
help="对于不包含前3个token的歌曲采样比例 (默认: 0.9)")
parser.add_argument("--top_k", type=int, default=3,
help="使用前k个最常见的token (默认: 3)")
parser.add_argument("--seed", type=int, default=42,
help="随机种子")
parser.add_argument("--num_threads", type=int, default=None,
help="线程数None表示使用CPU核心数 (默认: None)")
args = parser.parse_args()
# 1. 加载position和duration的前top_k个最常见token
top_position_tokens = load_top_tokens_from_json(
args.json_path, "position", top_k=args.top_k
)
top_duration_tokens = load_top_tokens_from_json(
args.json_path, "duration", top_k=args.top_k
)
if not top_position_tokens or not top_duration_tokens:
print("错误: 无法加载前top_k个token")
return
# 2. 获取所有数据文件路径
file_paths = get_data_file_paths(args.data_dir)
if not file_paths:
print("错误: 未找到任何数据文件")
return
print(f"\n共找到 {len(file_paths)} 个数据文件")
# 3. 根据新逻辑进行重采样
sampled_paths, sampled_indices, stats = resample_songs(
file_paths,
top_position_tokens,
top_duration_tokens,
contain_sample_ratio=args.contain_sample_ratio,
not_contain_sample_ratio=args.not_contain_sample_ratio,
num_threads=args.num_threads,
seed=args.seed
)
# 4. 保存采样结果(延迟加载)
sampled_data = sampled_paths # 使用路径,延迟加载
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
# 5. 保存采样索引(可选,用于后续分析)
indices_file = Path(args.output_dir) / "sampled_indices.npy"
np.save(indices_file, np.array(sampled_indices))
print(f"\n采样索引已保存到: {indices_file}")
# 6. 保存采样信息
info = {
"total_samples": len(file_paths),
"sampled_samples": len(sampled_paths),
"contain_sample_ratio": args.contain_sample_ratio,
"not_contain_sample_ratio": args.not_contain_sample_ratio,
"top_k": args.top_k,
"top_position_tokens": sorted(list(top_position_tokens)),
"top_duration_tokens": sorted(list(top_duration_tokens)),
"stats": stats
}
info_file = Path(args.output_dir) / "resample_info.json"
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(info, f, indent=2, ensure_ascii=False)
print(f"采样信息已保存到: {info_file}")
if __name__ == "__main__":
main()

View File

@ -1,14 +1,282 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Count the occurrences of every token in each column of the octuple tokenization output and generate an analysis report.
"""
import os
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import json
# Read a sample npz file (module-level sanity check)
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
# Octuple的列名定义
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
# Inspect the saved keys
print(data.files)
# Access the stored array (saved under the default key 'arr_0')
sequence = data["arr_0"]
def load_octuple_data(data_dir):
"""
加载所有octuple分词后的.npz文件
Args:
data_dir: 数据目录路径,可以是单个目录或包含多个子目录的根目录
Returns:
list: 所有加载的numpy数组列表
"""
data_dir = Path(data_dir)
npz_files = []
# 如果目录存在,查找所有.npz文件
if data_dir.exists():
npz_files = list(data_dir.rglob("*.npz"))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return []
print(f"找到 {len(npz_files)} 个.npz文件开始加载...")
all_data = []
for npz_file in tqdm(npz_files, desc="加载数据"):
try:
data = np.load(npz_file)['arr_0']
# 确保数据是二维数组 (num_tokens, num_columns)
if data.ndim == 2:
all_data.append(data)
elif data.ndim == 1:
# 如果是一维可能需要reshape但octuple应该是二维的
print(f"警告: {npz_file} 是一维数组,跳过")
except Exception as e:
print(f"错误: 加载 {npz_file} 时出错: {e}")
continue
return all_data
def count_tokens_by_column(all_data):
"""
Count how many times each token occurs in every column.
Args:
all_data: list of arrays, each a numpy array of shape (num_tokens, num_columns)
Returns:
dict: {column_index: Counter({token_value: count})}
"""
column_counters = defaultdict(Counter)
print("统计token出现次数...")
for data in tqdm(all_data, desc="处理数据"):
if data.size == 0:
continue
num_columns = data.shape[1] if data.ndim == 2 else 1
for col_idx in range(num_columns):
if data.ndim == 2:
tokens = data[:, col_idx]
else:
tokens = data
# 统计该列中每个token的出现次数
unique, counts = np.unique(tokens, return_counts=True)
for token, count in zip(unique, counts):
column_counters[col_idx][int(token)] += int(count)
return dict(column_counters)
def generate_report(column_counters, output_file=None):
"""
生成分析报告
Args:
column_counters: 统计结果字典
output_file: 输出文件路径(可选)
"""
report_lines = []
report_lines.append("=" * 80)
report_lines.append("OCTUPLE分词结果统计分析报告")
report_lines.append("=" * 80)
report_lines.append("")
# 总体统计
total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
report_lines.append(f"总token数: {total_tokens:,}")
report_lines.append(f"分析的列数: {len(column_counters)}")
report_lines.append("")
# 每一列的详细统计
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
report_lines.append("-" * 80)
report_lines.append(f"{col_idx}: {col_name}")
report_lines.append("-" * 80)
total_in_column = sum(counter.values())
unique_tokens = len(counter)
min_token = min(counter.keys()) if counter else 0
max_token = max(counter.keys()) if counter else 0
report_lines.append(f" 总token数: {total_in_column:,}")
report_lines.append(f" 唯一token数: {unique_tokens:,}")
report_lines.append(f" Token值范围: [{min_token}, {max_token}]")
report_lines.append(f" 平均出现次数: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else " 平均出现次数: N/A")
report_lines.append("")
# Top 20 最常见的token
report_lines.append(f" Top 20 最常见的token:")
top_tokens = counter.most_common(20)
for rank, (token, count) in enumerate(top_tokens, 1):
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
report_lines.append("")
# Top 20 最不常见的token出现次数>0的
report_lines.append(f" Top 20 最不常见的token (出现次数>0):")
bottom_tokens = counter.most_common()[-20:]
bottom_tokens.reverse()
for rank, (token, count) in enumerate(bottom_tokens, 1):
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
report_lines.append("")
# 分布统计
counts_list = list(counter.values())
if counts_list:
report_lines.append(f" 分布统计:")
report_lines.append(f" 最小出现次数: {min(counts_list):,}")
report_lines.append(f" 最大出现次数: {max(counts_list):,}")
report_lines.append(f" 中位数出现次数: {np.median(counts_list):,.0f}")
report_lines.append(f" 标准差: {np.std(counts_list):,.2f}")
report_lines.append("")
# 跨列分析
report_lines.append("=" * 80)
report_lines.append("跨列分析")
report_lines.append("=" * 80)
report_lines.append("")
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
total_in_column = sum(counter.values())
percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
report_lines.append(f" {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
report_lines.append("")
report_lines.append("=" * 80)
report_lines.append("报告生成完成")
report_lines.append("=" * 80)
# 输出报告
report_text = "\n".join(report_lines)
print("\n" + report_text)
# 保存到文件
if output_file:
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(report_text)
print(f"\n报告已保存到: {output_path}")
# 同时保存JSON格式的详细数据
if output_file:
json_output = output_path.with_suffix('.json')
json_data = {
'summary': {
'total_tokens': total_tokens,
'num_columns': len(column_counters)
},
'columns': {}
}
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
json_data['columns'][col_name] = {
'total_tokens': sum(counter.values()),
'unique_tokens': len(counter),
'token_counts': dict(counter),
'top_20': dict(counter.most_common(20)),
'bottom_20': dict(counter.most_common()[-20:])
}
with open(json_output, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
print(f"详细数据已保存到: {json_output}")
def main():
"""主函数"""
# 默认数据目录 - 可以根据需要修改
default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"
# 可以指定具体的数据目录,例如:
data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
# 或者使用默认目录扫描所有oct8目录
# import sys
# if len(sys.argv) > 1:
# data_dir = sys.argv[1]
# else:
# # 自动查找所有oct8目录
# base_dir = Path(default_data_dir)
# oct8_dirs = list(base_dir.rglob("oct8"))
# if oct8_dirs:
# print(f"找到以下oct8目录:")
# for i, d in enumerate(oct8_dirs, 1):
# print(f" {i}. {d}")
# if len(oct8_dirs) == 1:
# data_dir = str(oct8_dirs[0])
# print(f"\n使用目录: {data_dir}")
# else:
# # 使用第一个找到的目录,或者合并所有目录
# print(f"\n使用第一个目录: {oct8_dirs[0]}")
# print("如需分析其他目录,请指定路径作为参数")
# data_dir = str(oct8_dirs[0])
# else:
# data_dir = default_data_dir
# print(f"未找到oct8目录使用默认目录: {data_dir}")
# 加载数据
all_data = load_octuple_data(data_dir)
if not all_data:
print("错误: 未加载到任何数据")
return
# 检查数据格式
if all_data:
sample = all_data[0]
print(f"\n数据格式检查:")
print(f" 样本形状: {sample.shape}")
print(f" 样本数据类型: {sample.dtype}")
print(f" 列数: {sample.shape[1] if sample.ndim == 2 else 1}")
print()
# 统计token出现次数
column_counters = count_tokens_by_column(all_data)
# 生成报告
output_file = "octuple_token_analysis_report_part.txt"
generate_report(column_counters, output_file)
if __name__ == "__main__":
main()
print("token 序列长度:", len(sequence))
print("前 20 个 token", sequence[:20])

139
dllm/.gitignore vendored Normal file
View File

@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
applications/DeepSpeed-Chat/data
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Others
/.vscode/
/tmp/
/data/
/wandb/
/logs/
/models*/

4
dllm/.gitmodules vendored Normal file
View File

@ -0,0 +1,4 @@
[submodule "lm-evaluation-harness"]
path = lm-evaluation-harness
url = https://github.com/ZHZisZZ/lm-evaluation-harness
branch = dllm

21
dllm/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Zhanhui Zhou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

283
dllm/README.md Normal file
View File

@ -0,0 +1,283 @@
<h1 align="center">dLLM</h1>
<p align="center">
Simple Diffusion Language Modeling
</p>
<p align="center">
<img
src="assets/logo.gif"
alt="dLLM logo">
</p>
## Overview
**dLLM** is a library that unifies the training and evaluation of **diffusion language models**, bringing transparency and reproducibility to the entire development pipeline:
<!-- and [RND1](https://www.radicalnumerics.ai/assets/rnd1_report.pdf) -->
- dLLM provides scalable training pipelines (inspired by [`transformers`](https://github.com/huggingface/transformers/blob/main/src/transformers) [Trainer](https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py)), with support for [LoRA](https://github.com/huggingface/peft), [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), [FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/), and more.
- dLLM provides unified evaluation pipelines (inspired by [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness)) that abstract away inference details and make customization simple.
- Built on these components, dLLM provides minimal **pretraining / finetuning / evaluation** recipes for open-weight models (e.g., [LLaDA](https://arxiv.org/abs/2502.09992) and [Dream](https://arxiv.org/abs/2508.15487)), as well as implementations of training algorithms (e.g., [Edit Flows](https://arxiv.org/abs/2506.09018)).
<!-- > [!NOTE]
> This repository is primarily for educational purposes and does not aim for 100% exact reproduction of official models (which is impossible). We hope it serves as a helpful reference for the community — contributions and improvements are always welcome! -->
## News
**[2025/11]** We released a collection of BERTs finetuned for instruction-following: [`ModernBERT-{large,base}-chat-v0`](https://huggingface.co/collections/dllm-collection/bert-chat). This proof-of-concept shows that BERT's internal knowledge can be leveraged for generative tasks via masked instruction tuning. See [![blog](https://img.shields.io/badge/W&B-white?logo=weightsandbiases) BERT Chat Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg) for detailed recipes, experimental results and lessons learned; see [`examples/bert`](/examples/bert) for training / inference / evaluation instructions.
## Table of Contents
- [Features](#features)
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
- [Citation](#citation)
## Features
<!-- - [`examples/rnd`](/examples/rnd): (WIP) Finetuning open-weight RND1 [RND1-Base](https://www.radicalnumerics.ai/assets/rnd1_report.pdf). -->
- [`examples/llada`](/examples/llada): Pretraining, finetuning and evaluating [LLaDA](https://arxiv.org/abs/2502.09992) / [LLaDA-MoE](https://arxiv.org/abs/2509.24389).
- [`examples/dream`](/examples/dream): Pretraining, finetuning and evaluating [Dream](https://arxiv.org/abs/2508.15487).
- [`examples/bert`](/examples/bert): Finetuning any [BERT](https://arxiv.org/abs/1810.04805) into a lightweight chatbot.
<details>
<summary>🎬 Click to show BERT Chat Demo</summary>
<p align="center">
<img src="/examples/bert/assets/chat.gif" alt="chat" width="80%">
</p>
<p align="center">
<em>
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
</em>
</p>
</details>
- [`examples/editflow`](/examples/editflow): Educational reference for training [EditFlow](https://arxiv.org/abs/2506.09018) models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT Chat) with *edit operations*—insertion, deletion, and substitution—and how to pretrain or finetune EditFlow models from scratch on public data.
<details>
<summary>🎬 Click to show EditFlow Demo</summary>
<p align="center">
<img src="/examples/editflow/assets/all.gif" alt="EditFlow demo" width="100%">
</p>
<p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p>
</details>
- More upcoming.
## Setup
### Installation
```bash
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
--index-url https://download.pytorch.org/whl/cu124
# install dllm package
pip install -e .
```
### (optional) Evaluation setup
```bash
# initialize `lm-evaluation-harness` submodule
git submodule update --init --recursive
# install submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"
```
### (optional) Slurm setup
For [Slurm](https://slurm.schedmd.com/) users, update [`scripts/train.slurm.sh`](/scripts/train.slurm.sh) for your cluster:
```diff
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE
```
Next, create a directory for your job logs:
```shell
mkdir logs
```
This folder will store the log files generated by your sbatch jobs.
## Files overview
```
# modules for training / sampling
dllm
├── core # Core reusable modules shared across `dllm/pipelines`
│ ├── generation
│ ├── schedulers
│ └── trainers
├── data
├── pipelines # Application-specific training & inference pipelines
| ├── bert
│ ├── dream
│ ├── editflow
│ └── llada
│ ├── models # Model architecture and configs
│ ├── generator.py # Generation utilities
│ ├── trainer.py # Core training logic
│ └── eval.py # Evaluation entry point
├── tools
└── utils
# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
├── chat.py # Interactive inference example
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
├── sft.py # Supervised finetuning example
└── eval.sh # Evaluation script
```
## Training
A typical training entry script (for example, [`examples/llada/sft.py`](/examples/llada/sft.py)) looks like this:
```python
import transformers
import dllm
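# (assumed setup, omitted here: `parser` is a transformers.HfArgumentParser built over this
# script's model/data/training dataclasses)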
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."
# ----- Training --------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id,
),
)
trainer.train()
```
You can launch the training job locally with `accelerate`, or submit it to a [Slurm](https://slurm.schedmd.com/) cluster using `sbatch`.
```shell
# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA)
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/llada/sft.py \
--num_train_epochs 4 \
--load_in_4bit True --lora True
```
```shell
# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
```
See [Features](#features) for specific training recipes.
> Here are some useful tips for training (a combined example is shown after the list):
> 1. Use a subset of data:
> `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"`
> 2. Concatenate datasets:
> `--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk"`
> 3. Train with LoRA and 4bit quantization:
> `--load_in_4bit True --lora True`
> 4. Train with different distributed training methods:
> `--accelerate_config "ddp,zero-{1,2,3},fsdp"`
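Several of these tips can be combined in a single launch. The following sketch is illustrative only (the dataset subset and flag values are placeholders, not a recommended recipe):
```shell
# Illustrative: ZeRO-2 on 8 GPUs, 4bit quantization + LoRA, on a small subset of one dataset
accelerate launch \
    --config_file scripts/accelerate_configs/zero2.yaml \
    examples/llada/sft.py \
    --dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]" \
    --load_in_4bit True --lora True \
    --num_train_epochs 4
```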
## Inference
We provide unified [generators](/dllm/core/generation/generator.py) that abstract away inference details.
A typical inference entry script (for example, [`examples/llada/generate.py`](/examples/llada/generate.py)) looks like this:
```python
import dllm
from dllm import llada
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
# for other models, change your generator and keep others unchanged
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
```
You can also try the interactive chat script (for example, [`examples/llada/chat.py`](/examples/llada/chat.py)) for a visualized multi-turn dialogue:
<p align="center">
<img src="/assets/chat.gif" alt="chat" width="80%">
</p>
<!-- <p align="center"><em>EditFlow performing insertion (blue), substitution from mask tokens (black), substitution from non-mask tokens (red), and deletion (strikethrough → removed) during generation.</em></p> -->
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`LLaDA-8B-Instruct`](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) on [`MMLU_Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro), run:
```shell
accelerate launch --num_processes 4 \
dllm/pipelines/llada/eval.py \
--tasks "mmlu_pro" \
--model "llada" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256,cfg=0.0"
```
We also provide scripts to automatically evaluate [LLaDA](https://arxiv.org/abs/2502.09992), [Dream](https://arxiv.org/abs/2508.15487), and [BERT-Chat](https://huggingface.co/collections/dllm-collection/bert-chat) on all benchmarks.
For example, you can launch [`examples/llada/eval.sh`](/examples/llada/eval.sh) directly using the following commands:
```shell
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False
```
## Citation
```
@misc{dllm,
author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
title = {dLLM: Simple Diffusion Language Modeling},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}
```

Binary file not shown.

BIN
dllm/assets/chat.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.5 MiB

BIN
dllm/assets/logo.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 956 KiB

BIN
dllm/assets/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

119
dllm/assets/logo.py Normal file
View File

@ -0,0 +1,119 @@
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import os
# ---------- Configuration (smaller size) ----------
W, H = 480, 210 # lower resolution
TOTAL_DURATION = 3.0
FPS = 15 # lower fps
TEXT = "dLLM"
# TEXT_COLOR = (235, 235, 235)
TEXT_COLOR = (0, 0, 0)
OUTPUT = "logo.gif"
LAST_FRAME_PNG = "logo.png"
DIFFUSION_PORTION = 0.3 # fewer diffusion frames
SEED = 8
# ---------- Auto font size ----------
def load_font_auto_size(text, w, h, target_width_ratio=0.95, target_height_ratio=0.95):
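# Binary-search the largest font size whose rendered text still fits within the target
# width/height ratios of the canvas.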
lo, hi = 10, 2000
best_font, best_size = None, lo
while lo <= hi:
mid = (lo + hi) // 2
try:
font = ImageFont.truetype(
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=mid
)
except Exception:  # fall back to the default font if DejaVu is unavailable
font = ImageFont.load_default()
dummy = Image.new("L", (w, h), 0)
d = ImageDraw.Draw(dummy)
bbox = d.textbbox((0, 0), text, font=font)
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
width_ok = tw <= w * target_width_ratio
height_ok = th <= h * target_height_ratio
if width_ok and height_ok:
best_font, best_size = font, mid
lo = mid + 1
else:
hi = mid - 1
return best_font if best_font is not None else font
# ---------- Text rendering ----------
def render_text_mask(w, h, text, font):
img = Image.new("L", (w, h), 0)
d = ImageDraw.Draw(img)
bbox = d.textbbox((0, 0), text, font=font)
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
x = (w - tw) // 2 - bbox[0]
y = (h - th) // 2 - bbox[1]
d.text((x, y), text, font=font, fill=255)
return np.asarray(img, np.float32) / 255.0
# ---------- Initialization ----------
font = load_font_auto_size(TEXT, W, H)
mask = render_text_mask(W, H, TEXT, font)
num_frames = int(TOTAL_DURATION * FPS)
diffusion_frames = max(1, int(num_frames * DIFFUSION_PORTION))
hold_ms = int((TOTAL_DURATION - diffusion_frames / FPS) * 1000)
rng = np.random.default_rng(SEED)
frames = []
# ---------- Diffusion stage ----------
for i in range(diffusion_frames):
t = i / max(1, diffusion_frames - 1)
progress = t**0.9
noise_sigma = (1.0 - progress) ** 2.2
noise = rng.standard_normal((H, W, 1)).astype(np.float32)
noise_img = 1.0 - noise_sigma * 0.5 * np.abs(noise)
np.clip(noise_img, 0.0, 1.0, out=noise_img)
alpha = progress**2.0
alpha_map = (mask * alpha).astype(np.float32)[..., None]
text_rgb = np.zeros((H, W, 3), dtype=np.float32)
for c in range(3):
text_rgb[..., c] = (mask > 0).astype(np.float32) * (TEXT_COLOR[c] / 255.0)
frame = (1.0 - alpha_map) * noise_img + alpha_map * text_rgb
frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
frames.append(Image.fromarray(frame, mode="RGB"))
# ---------- Last frame ----------
final_frame = frames[-1]
# ---------- Save last frame as PNG ----------
final_frame.save(LAST_FRAME_PNG)
print(f"🖼️ Last frame saved as: {LAST_FRAME_PNG}")
# ---------- Quantization (reduce size) ----------
pal_frames = [f.convert("P", palette=Image.ADAPTIVE, colors=64) for f in frames]
pal_final = final_frame.convert("P", palette=Image.ADAPTIVE, colors=64)
# ---------- Save GIF ----------
normal_ms = int(1000 / FPS)
durations = [normal_ms] * len(pal_frames) + [hold_ms]
pal_frames[0].save(
OUTPUT,
save_all=True,
append_images=pal_frames[1:] + [pal_final],
duration=durations,
loop=0,
optimize=True,
)
print(f"✅ GIF saved: {OUTPUT}")
print(
f"Frames (diffusion only): {len(pal_frames)} at {FPS} FPS, final hold {hold_ms} ms, resolution {W}x{H}"
)

1
dllm/dllm/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import core, data, pipelines, utils

View File

@ -0,0 +1 @@
from dllm.core import trainers, schedulers, generation

View File

@ -0,0 +1 @@
from . import generator, visualizer

View File

@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from transformers import PreTrainedTokenizer, PreTrainedModel
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
@dataclass
class GeneratorOutput:
sequences: torch.Tensor
histories: list[torch.Tensor] | None = None
@dataclass
class GeneratorConfig:
return_dict_in_generate: bool = False
@dataclass
class BaseGenerator(ABC):
model: PreTrainedModel
tokenizer: PreTrainedTokenizer
scheduler: BaseAlphaScheduler | None = None
def __post_init__(self):
if self.scheduler is None:
self.scheduler = LinearAlphaScheduler()
@abstractmethod
@torch.no_grad()
def generate(
self,
prompts: list[torch.Tensor, list],
config: GeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput:
raise NotImplementedError
@abstractmethod
@torch.no_grad()
def infill(
self,
inputs: list[torch.Tensor, list],
config: GeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput:
raise NotImplementedError

View File

@ -0,0 +1,427 @@
from __future__ import annotations
import os
import re
import sys
import time
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Sequence, Optional
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer
@dataclass
class BaseVisualizer(ABC):
tokenizer: PreTrainedTokenizer
@abstractmethod
def visualize(self, history: list[torch.Tensor, list], **kwargs):
raise NotImplementedError
@dataclass
class VideoVisualizer(BaseVisualizer):
def visualize(
self,
history: list[torch.Tensor, list],
output_path: str = "visualization.gif",
**kwargs,
):
raise NotImplementedError
@dataclass
class TerminalVisualizer(BaseVisualizer):
# Configuration (adjust as needed)
HEADER_SIZE = 3 # Fixed number of lines for the header (0 if show_header is False)
PROGRESS_SIZE = 3 # Fixed number of lines for the progress bar
PANEL_PADDING_TOP = 1 # Top padding of the Panel (padding=(top, side))
PANEL_PADDING_BOTTOM = 1 # Bottom padding of the Panel
PANEL_PADDING_SIDE = 1 # Number of characters used for left and right padding
PANEL_BORDER = 2 # Number of columns taken by the Panel border (usually 2)
MIN_TOTAL_HEIGHT = 10 # Minimum terminal height (in lines)
MAX_TOTAL_HEIGHT = 60 # Maximum terminal height to prevent overflowing the terminal
DEFAULT_TERM_WIDTH = 120 # Default terminal width (in columns)
ansi_escape = re.compile(r"\x1b\[[0-9;]*m") # Regex to match ANSI escape codes
def visualize(
self,
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
fps: int = 16,
rich: bool = True,
title: str = "dllm",
max_chars: int = None,
every_n_steps: int = 1,
show_header: bool = True,
skip_special_tokens: bool = False,
) -> None:
"""
Visualize a masked-diffusion decoding trajectory stored in `history`.
If items have batch dimension [B, T], visualize each sequence separately.
"""
try:
# detect batch size
first_step = history[0]
if first_step.dim() > 1 and first_step.shape[0] > 1:
B = first_step.shape[0]
for b_idx in range(B):
# build per-sequence history
seq_history = [step[b_idx].unsqueeze(0) for step in history]
self.visualize_one_history(
seq_history,
fps,
rich,
title=f"{title} (Batch {b_idx})",
max_chars=max_chars,
every_n_steps=every_n_steps,
show_header=show_header,
skip_special_tokens=skip_special_tokens,
)
else:
# no batch, just visualize normally
self.visualize_one_history(
history,
fps,
rich,
title,
max_chars,
every_n_steps,
show_header,
skip_special_tokens,
)
except Exception as e:
print(f"(Visualization skipped due to error: {e})")
def visualize_one_history(
self,
history: list[torch.Tensor], # list of tokens per step: [T] or [B,T]
fps: int = 16,
rich: bool = True,
title: str = "dllm",
max_chars: int = None,
every_n_steps: int = 1, # re-render frequency (perf knob)
show_header: bool = True,
skip_special_tokens: bool = False,  # whether to skip special/pad/eos tokens when rendering
) -> None:
"""
Visualize a masked-diffusion decoding trajectory stored in `history`.
Args:
history: Sequence of token tensors for each step. Each item is [T] or [B,T].
fps: Frames per second for the live UI (Rich) or sleep cadence for tqdm fallback.
title: Header title.
max_chars: Cap on rendered characters to keep terminal snappy.
every_n_steps: Only redraw text every N steps (progress still updates every step).
show_header: Show the magenta header bar (Rich path).
skip_special_tokens: Whether to skip special/pad/eos tokens when rendering (default: False).
Notes:
- Masked positions are detected via `self.tokenizer.mask_token_id`.
- Special tokens are determined via `self.tokenizer.all_special_ids`.
- All layout, styling, and progress are encapsulated here.
"""
# --------- imports & env checks ----------
try:
from rich.console import Console
from rich.live import Live
from rich.text import Text
from rich.panel import Panel
from rich.progress import (
Progress,
BarColumn,
TextColumn,
TimeRemainingColumn,
MofNCompleteColumn,
SpinnerColumn,
)
from rich.layout import Layout
_RICH_IMPORTED = True
except Exception:
_RICH_IMPORTED = False
try:
from tqdm import tqdm
_TQDM_IMPORTED = True
except Exception:
_TQDM_IMPORTED = False
if self.tokenizer is None:
raise ValueError(
"TerminalVisualizer.tokenizer must be set to a valid tokenizer."
)
tokenizer = self.tokenizer
specials: set[int] = set(getattr(tokenizer, "all_special_ids", []) or [])
self._specials = specials # store for helpers
self._mask_token_id: Optional[int] = getattr(tokenizer, "mask_token_id", None)
self._pad_token_id: Optional[int] = getattr(tokenizer, "pad_token_id", None)
self._eos_token_id: Optional[int] = getattr(tokenizer, "eos_token_id", None)
# --------- helpers inside class scope ----------
# (keep everything inside this class as requested)
# throttle settings
sleep_s = 0.0 if fps <= 0 else 1.0 / float(max(1, fps))
total_steps = len(history)
every_n_steps = max(1, int(every_n_steps))
# decode final text up-front (used after render)
final_text = self._detok(history[-1], skip_special_tokens=skip_special_tokens)
final_text = self._truncate(final_text, max_chars)
# ------------------ new: estimate height from final_text ------------------
import textwrap
import shutil
def strip_ansi(s: str) -> str:
return self.ansi_escape.sub("", s) if s else ""
def estimate_height_from_text(text: str, console_width: int) -> int:
"""
Estimate how many terminal rows the panel with `text` will need given console_width.
Uses class constants for paddings/borders and header/progress sizes.
"""
plain = strip_ansi(text or "")
# inner width = console width minus left/right panel paddings & border
inner_width = max(
10, console_width - 2 * self.PANEL_PADDING_SIDE - self.PANEL_BORDER
)
lines = 0
# preserve existing newlines: wrap each paragraph separately
for para in plain.splitlines() or [""]:
if para.strip() == "":
lines += 1
continue
wrapped = textwrap.wrap(
para,
width=inner_width,
replace_whitespace=False,
drop_whitespace=False,
)
lines += max(1, len(wrapped))
text_block_lines = (
lines + self.PANEL_PADDING_TOP + self.PANEL_PADDING_BOTTOM
)
extra = 2 # for panel title / subtitle / small margin
header_h = self.HEADER_SIZE if show_header else 0
total = header_h + text_block_lines + self.PROGRESS_SIZE + extra
# clamp
total = max(self.MIN_TOTAL_HEIGHT, min(total, self.MAX_TOTAL_HEIGHT))
return int(total)
# try to detect terminal width; fallback to 100
try:
term_width = shutil.get_terminal_size().columns
if not isinstance(term_width, int) or term_width <= 0:
term_width = self.DEFAULT_TERM_WIDTH
except Exception:
term_width = self.DEFAULT_TERM_WIDTH
est_height = estimate_height_from_text(final_text, console_width=term_width)
# ------------------ end new ----------------------------------------------
# choose rich or tqdm
use_rich = bool(rich and _RICH_IMPORTED)
if not use_rich or not _RICH_IMPORTED:
# ---------- tqdm fallback ----------
if not _TQDM_IMPORTED:
for i, toks in enumerate(history, start=1):
if sleep_s > 0:
time.sleep(sleep_s)
print("\n✨ Generation complete!\n")
print(final_text)
return
pbar = tqdm(total=total_steps, desc="Diffusion", leave=True)
for i, toks in enumerate(history, start=1):
pbar.update(1)
pbar.set_postfix(
{
"masks": self._count_masks(toks),
"pct": f"{int(100 * i / max(total_steps, 1))}%",
}
)
if sleep_s > 0:
time.sleep(sleep_s)
pbar.close()
print("\n✨ Generation complete!\n")
if final_text:
print(final_text)
return
# ---------- rich live UI ----------
# replaced fixed height=100 with the estimated height from history[-1]
console = Console(
force_terminal=True,
color_system="truecolor",
width=term_width,
height=est_height,
)
layout = Layout()
layout.split_column(
(
Layout(name="header", size=3)
if show_header
else Layout(name="header", size=0)
),
Layout(name="text", ratio=1),
Layout(name="progress", size=3),
)
progress = Progress(
SpinnerColumn(),
TextColumn("[bold blue]Diffusion"),
BarColumn(),
MofNCompleteColumn(),
TextColumn(""),
TextColumn("[cyan]Masks: {task.fields[masks]}"),
TextColumn(""),
TextColumn("[magenta]{task.fields[pct]:>4s}"),
TimeRemainingColumn(),
expand=True,
)
init_masks = self._count_masks(history[0]) if history else 0
task_id = progress.add_task(
"Generating", total=total_steps, masks=init_masks, pct="0%"
)
with Live(layout, console=console, refresh_per_second=max(1, fps)):
for step_idx, toks in enumerate(history, start=1):
if show_header:
header = Text(title, style="bold magenta", justify="center")
layout["header"].update(Panel(header, border_style="bright_blue"))
# progress bar
masks_remaining = self._count_masks(toks)
pct = f"{int(100 * step_idx / max(total_steps, 1))}%"
progress.update(task_id, advance=1, masks=masks_remaining, pct=pct)
# text panel: decode whole sequence (avoids Ġ/Ċ artifacts)
if (
every_n_steps <= 1
or (step_idx % every_n_steps == 0)
or step_idx in (1, total_steps)
):
text_str = self._detok(
toks, skip_special_tokens=skip_special_tokens
)
text_str = self._truncate(text_str, max_chars)
text_rich = Text.from_ansi(text_str) if text_str else Text("")
layout["text"].update(
Panel(
(
text_rich
if text_rich.plain
else Text("[dim]— no tokens —[/dim]")
),
title="[bold]Generated Text",
subtitle=f"[dim]Step {step_idx}/{total_steps}[/dim]",
border_style="cyan",
padding=(1, 1),
)
)
layout["progress"].update(Panel(progress))
if sleep_s > 0:
time.sleep(sleep_s)
console.print("\n[bold green]✨ Generation complete![/bold green]\n")
# console.print(
# Panel(
# final_text if final_text else "[dim]— no decodable text —[/dim]",
# title="[bold]Final Generated Text",
# border_style="green",
# padding=(1, 2),
# )
# )
# ======================== helpers (kept inside class) ========================
def _has_tty(self) -> bool:
return sys.stdout.isatty() and os.environ.get("TERM", "") not in ("", "dumb")
def _first_item(self, x: torch.Tensor) -> torch.Tensor:
return x[0] if x.dim() > 1 else x
def _count_masks(self, toks: torch.Tensor) -> int:
if getattr(self, "_mask_token_id", None) is None:
return 0
t = self._first_item(toks)
return int((t == self._mask_token_id).sum().item())
def _detok(self, ids_or_tensor, *, skip_special_tokens: bool) -> str:
"""
Robust detokenize for list[int] / torch.Tensor([T]) / torch.Tensor([B,T]).
Decode the whole sequence to avoid byte-level artifacts like Ġ/Ċ.
"""
tokenizer = self.tokenizer
# normalize to python list[int]
if isinstance(ids_or_tensor, torch.Tensor):
t = self._first_item(ids_or_tensor).long()
ids = t.tolist()
elif isinstance(ids_or_tensor, (list, tuple)):
ids = list(ids_or_tensor)
else:
# unknown type
return ""
# Optionally drop specials/pad/eos *before* decode if desired
if skip_special_tokens:
keep = []
specials = getattr(self, "_specials", set())
pad_id = getattr(self, "_pad_token_id", None)
eos_id = getattr(self, "_eos_token_id", None)
for tid in ids:
if tid in specials:
continue
if pad_id is not None and tid == pad_id:
continue
if eos_id is not None and tid == eos_id:
continue
keep.append(tid)
ids = keep
# Prefer tokenizer.decode (handles Ġ/Ċ, merges properly)
text = ""
try:
if hasattr(tokenizer, "decode"):
text = tokenizer.decode(
ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=True,
)
else:
# fallback: tokens -> string
toks = tokenizer.convert_ids_to_tokens(ids)
if hasattr(tokenizer, "convert_tokens_to_string"):
text = tokenizer.convert_tokens_to_string(toks)
else:
text = " ".join(map(str, toks))
except Exception:
# extremely defensive fallback
try:
text = tokenizer.decode(ids, skip_special_tokens=True)
except Exception:
text = ""
# sanitize control chars for terminal
if text:
text = text.replace("\r", "")
return text
def _truncate(self, s: str, max_chars: Optional[int]) -> str:
if max_chars is None or (isinstance(max_chars, int) and max_chars < 0):
return s
return s[:max_chars]
if __name__ == "__main__":
pass
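# A standalone sketch of the row-count estimate performed by estimate_height_from_text above.
# The layout constants below are illustrative assumptions, not the class's actual values.
import textwrap

PANEL_PADDING_SIDE, PANEL_BORDER = 1, 2           # assumed side padding / border width
PANEL_PADDING_TOP = PANEL_PADDING_BOTTOM = 1      # assumed vertical padding
HEADER_SIZE, PROGRESS_SIZE, EXTRA = 3, 3, 2       # assumed header/progress rows + margin

def rows_needed(text: str, console_width: int = 100, show_header: bool = True) -> int:
    inner = max(10, console_width - 2 * PANEL_PADDING_SIDE - PANEL_BORDER)
    lines = 0
    for para in text.splitlines() or [""]:
        lines += 1 if not para.strip() else max(1, len(textwrap.wrap(para, width=inner)))
    header = HEADER_SIZE if show_header else 0
    return header + lines + PANEL_PADDING_TOP + PANEL_PADDING_BOTTOM + PROGRESS_SIZE + EXTRA

print(rows_needed("word " * 120))  # a few wrapped lines plus header/padding/progress rows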

View File

@ -0,0 +1,2 @@
from .alpha import *
from .kappa import *

View File

@ -0,0 +1,132 @@
from __future__ import annotations
import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
import torch
Number = Union[float, torch.Tensor]
# ---------------- Registry-enabled Base ---------------- #
@dataclasses.dataclass
class BaseAlphaScheduler:
__registry__: ClassVar[dict[str, type[BaseAlphaScheduler]]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
BaseAlphaScheduler.__registry__[cls.__name__] = cls
BaseAlphaScheduler.__registry__[cls.__name__.lower()] = cls
# Make instances callable (sched(i) -> alpha(i))
def __call__(self, t: Number) -> Number:
return self.alpha(t)
# ---- common API ----
def alpha(self, i: Number) -> Number:
i_t = torch.as_tensor(
i,
dtype=torch.float32,
device=i.device if isinstance(i, torch.Tensor) else None,
)
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
raise ValueError(f"i={i} not in [0,1]")
out = self._alpha(i_t)
return out.item() if isinstance(i, float) else out
def alpha_derivative(self, i: Number) -> Number:
i_t = torch.as_tensor(
i,
dtype=torch.float32,
device=i.device if isinstance(i, torch.Tensor) else None,
)
if not torch.all((0.0 <= i_t) & (i_t <= 1.0)):
raise ValueError(f"i={i} not in [0,1]")
out = self._alpha_derivative(i_t)
return out.item() if isinstance(i, float) else out
def reverse_mask_prob(self, s: Number, t: Number) -> Number:
t_t = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
s_t = torch.as_tensor(
s,
dtype=torch.float32,
device=s.device if isinstance(s, torch.Tensor) else None,
)
if not torch.all((0.0 <= s_t) & (s_t < 1.0) & (0.0 < t_t) & (t_t <= 1.0)):
raise ValueError(f"(t={t}, s={s}) out of range")
if not torch.all(s_t < t_t):
raise ValueError(f"Require s < t elementwise, but got (t={t}, s={s})")
out = (1 - self(s_t)) / (1 - self(t_t))
return out.item() if isinstance(t, float) and isinstance(s, float) else out
def weight(self, i: Number) -> Number:
# w(t) = - α'(t) / (1 - α(t))
return - self.alpha_derivative(i) / (1 - self.alpha(i) + 1e-6)
# ---- hooks implemented by subclasses ----
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
# ---------------- Implementations ---------------- #
@dataclasses.dataclass
class LinearAlphaScheduler(BaseAlphaScheduler):
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
return 1 - i
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
return -torch.ones_like(i)
@dataclasses.dataclass
class CosineAlphaScheduler(BaseAlphaScheduler):
def _alpha(self, i: torch.Tensor) -> torch.Tensor:
return 1 - torch.cos((math.pi / 2) * (1 - i))
def _alpha_derivative(self, i: torch.Tensor) -> torch.Tensor:
return -(math.pi / 2) * torch.sin((math.pi / 2) * (1 - i))
# ---------------- Factory helpers ---------------- #
def get_alpha_scheduler_class(name: str) -> type[BaseAlphaScheduler]:
"""Return the scheduler class by name (case-insensitive)."""
cls = BaseAlphaScheduler.__registry__.get(
name
) or BaseAlphaScheduler.__registry__.get(name.lower())
if cls is None:
available = sorted(k for k in BaseAlphaScheduler.__registry__ if k[0].isupper())
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
return cls
def make_alpha_scheduler(name: str, **kwargs: Any) -> BaseAlphaScheduler:
"""Instantiate a scheduler by name with optional kwargs."""
cls = get_alpha_scheduler_class(name)
return cls(**kwargs)
# ---------------- Example usage ---------------- #
if __name__ == "__main__":
lin_sched = make_alpha_scheduler("LinearAlphaScheduler")
print("Linear α(0.5):", lin_sched.alpha(0.5))
print("Linear w(0.5):", lin_sched.weight(0.5))
print("Linear α([.25,.5,.75]):", lin_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
print("==========================================")
cos_sched = make_alpha_scheduler("CosineAlphaScheduler")
print("Cosine α(0.5):", cos_sched.alpha(0.5))
print("Cosine w(0.5):", cos_sched.weight(0.5))
print("Cosine α([.25,.5,.75]):", cos_sched.alpha(torch.tensor([0.25, 0.5, 0.75])))
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import dataclasses
import math
from typing import ClassVar, Dict, Type, Any, Union
import torch
Number = Union[float, torch.Tensor]
# ---------------- Registry-enabled Base ---------------- #
@dataclasses.dataclass
class BaseKappaScheduler:
__registry__: ClassVar[dict[str, type[BaseKappaScheduler]]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
BaseKappaScheduler.__registry__[cls.__name__] = cls
BaseKappaScheduler.__registry__[cls.__name__.lower()] = cls
# Make instances callable (sched(t) -> kappa(t))
def __call__(self, t: Number) -> Number:
return self.kappa(t)
# ---- common API ----
def kappa(self, t: Number) -> Number:
t_tensor = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
raise ValueError(f"t={t} not in [0,1]")
out = self._kappa(t_tensor)
return out.item() if isinstance(t, float) else out
def kappa_derivative(self, t: Number) -> Number:
t_tensor = torch.as_tensor(
t,
dtype=torch.float32,
device=t.device if isinstance(t, torch.Tensor) else None,
)
if not torch.all((0.0 <= t_tensor) & (t_tensor <= 1.0)):
raise ValueError(f"t={t} not in [0,1]")
out = self._kappa_derivative(t_tensor)
return out.item() if isinstance(t, float) else out
def weight(self, t: Number) -> Number:
# w(t) = κ'(t) / (1 - κ(t))
return self.kappa_derivative(t) / (1 - self.kappa(t) + 1e-6)
# ---- hooks implemented by subclasses ----
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
# ---------------- Implementations ---------------- #
@dataclasses.dataclass
class CubicKappaScheduler(BaseKappaScheduler):
a: float = 1.0
b: float = 1.0
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
# κ(t) = (a+1) t^3 - (a+b+1) t^2 + (b+1) t
return (self.a + 1) * (t**3) - (self.a + self.b + 1) * (t**2) + (self.b + 1) * t
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
# κ'(t) = 3(a+1) t^2 - 2(a+b+1) t + (b+1)
return 3 * (self.a + 1) * (t**2) - 2 * (self.a + self.b + 1) * t + (self.b + 1)
@dataclasses.dataclass
class LinearKappaScheduler(CubicKappaScheduler):
# Special case: κ(t) = t corresponds to a=-1, b=0
a: float = -1.0
b: float = 0.0
@dataclasses.dataclass
class CosineKappaScheduler(BaseKappaScheduler):
def _kappa(self, t: torch.Tensor) -> torch.Tensor:
# κ(t) = 1 - cos((π/2) * t)
return 1.0 - torch.cos(0.5 * math.pi * t)
def _kappa_derivative(self, t: torch.Tensor) -> torch.Tensor:
# κ'(t) = (π/2) * sin((π/2) * t)
return 0.5 * math.pi * torch.sin(0.5 * math.pi * t)
# ---------------- Factory helpers ---------------- #
def get_kappa_scheduler_class(name: str) -> type[BaseKappaScheduler]:
"""Return the scheduler class by name (case-insensitive)."""
cls = BaseKappaScheduler.__registry__.get(
name
) or BaseKappaScheduler.__registry__.get(name.lower())
if cls is None:
available = sorted(k for k in BaseKappaScheduler.__registry__ if k[0].isupper())
raise ValueError(f"Unknown scheduler '{name}'. Available: {available}")
return cls
def make_kappa_scheduler(name: str, **kwargs: Any) -> BaseKappaScheduler:
"""Instantiate a scheduler by name with optional kwargs."""
cls = get_kappa_scheduler_class(name)
return cls(**kwargs)
# ---------------- Example usage ---------------- #
if __name__ == "__main__":
lin_sched = make_kappa_scheduler("LinearKappaScheduler")
print("Linear κ(0.5):", lin_sched.kappa(0.5))
print("Linear w(0.5):", lin_sched.weight(0.5))
print("Linear κ([.25,.5,.75]):", lin_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
print("Linear w([.25,.5,.75]):", lin_sched.weight(torch.tensor([0.25, 0.5, 0.75])))
print("==========================================")
cos_sched = make_kappa_scheduler("CosineKappaScheduler")
print("Cosine κ(0.5):", cos_sched.kappa(0.5))
print("Cosine w(0.5):", cos_sched.weight(0.5))
print("Cosine κ([.25,.5,.75]):", cos_sched.kappa(torch.tensor([0.25, 0.5, 0.75])))
print("Cosine w([.25,.5,.75]):", cos_sched.weight(torch.tensor([0.25, 0.5, 0.75])))

View File

@ -0,0 +1 @@
from dllm.core.trainers.mdlm import MDLMTrainer

View File

@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from typing import Any
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
class MDLMTrainer(transformers.Trainer):
"""
Masked Diffusion Language Model Trainer.
"""
def __init__(
self,
*args,
scheduler: BaseAlphaScheduler | None = None,
time_epsilon: float = 1e-3,
loss_weight_type: str = "scheduler", # "ones"
**kwargs,
):
super().__init__(*args, **kwargs)
self.scheduler = scheduler or LinearAlphaScheduler()
if not (0.0 < time_epsilon < 1.0):
raise ValueError("time_epsilon must be in (0, 1)")
self.time_epsilon = time_epsilon
self.loss_weight_type = loss_weight_type
def _preprocess_inputs(self, inputs):
pass
def _postprocess_outputs(self, outputs):
pass
def _compute_loss_weights(
self,
t: torch.Tensor,
inputs: dict[str, Any],
*args,
**kwargs,
) -> torch.Tensor:
"""Compute loss weights given timestep t and other arguments."""
b, l = inputs["input_ids"].shape
if self.loss_weight_type == "scheduler":
loss_weights = self.scheduler.weight(t).unsqueeze(1).repeat(1, l)  # (b, l)
elif self.loss_weight_type == "ones":
loss_weights = torch.ones_like(inputs["input_ids"])
else:
raise NotImplementedError
return loss_weights
@torch.no_grad()
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
if prediction_loss_only:
return (loss.detach(), None, None)
logits = getattr(outputs, "logits", outputs)
if isinstance(logits, torch.Tensor):
logits = logits.detach().contiguous()
labels = inputs.get("labels")
if isinstance(labels, torch.Tensor):
labels = labels.detach().contiguous()
return (loss.detach(), logits, labels)
def compute_loss(
self,
model: transformers.PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
**kwargs,
):
assert self.processing_class.padding_side == "right"
self._preprocess_inputs(inputs)
input_ids, labels, attention_mask = (
inputs["input_ids"],
inputs["labels"],
inputs.get("attention_mask", None),
)
b, l = input_ids.shape
# === 1. Sample diffusion timesteps ===
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
# The scheduler defines the masking rate α(t); we convert it to a masking probability p_mask = 1 - α(t).
t = self.time_epsilon + (1 - self.time_epsilon) * torch.rand(
b, device=input_ids.device
)
p_mask = 1 - self.scheduler(t).unsqueeze(1).expand(b, l)
# === 2. Apply stochastic masking ===
# Tokens are masked independently according to p_mask(t).
# Positions with label = -100 are excluded (ignored in loss).
masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
labels != -100
)
# Replace masked tokens with the special [MASK] token.
noised_input_ids = torch.where(
masked_indices, self.processing_class.mask_token_id, input_ids
)
# === 3. Forward pass through the model ===
# The model predicts clean tokens given noised inputs.
outputs = model(input_ids=noised_input_ids, attention_mask=attention_mask)
self._postprocess_outputs(outputs)
logits = outputs.logits
# === 4. Handle degenerate cases (no tokens masked) ===
# If no positions were masked, return a zero loss to keep gradients valid.
# This step is necessary for Deepspeed Zero-{2,3}
if not masked_indices.any():
return (
(logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
)
# === 5. Compute per-token loss weights ===
# Depending on the configuration, weights may depend on timestep t
# (e.g., scheduler-based) or be uniform (ones).
loss_weights = self._compute_loss_weights(
t=t, inputs=inputs, masked_indices=masked_indices
)
# === 6. Compute weighted cross-entropy ===
# Only masked tokens contribute to the loss.
assert (input_ids[masked_indices] == labels[masked_indices]).all()
token_loss = F.cross_entropy(
logits[masked_indices], input_ids[masked_indices], reduction="none"
)
token_loss = token_loss * loss_weights[masked_indices]
# === 7. Normalize loss per effective token length ===
# Normalize each sequence's contribution by its number of valid tokens,
# then average over the batch for stability across variable-length inputs.
effective_lengths = torch.sum(labels != -100, dim=1, keepdim=True).expand(b, l)
loss = torch.sum(token_loss / effective_lengths[masked_indices]) / b
# === 8. Return final loss (and optionally model outputs) ===
return (loss, outputs) if return_outputs else loss
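# Minimal, self-contained sketch of the noising step described in comments 1-2 of
# compute_loss above (mask id, vocabulary range, and prompt length are illustrative
# assumptions; a linear schedule α(t) = 1 - t is assumed, so p_mask = t).
import torch

b, l, mask_token_id, eps = 2, 8, 103, 1e-3
input_ids = torch.randint(1000, 2000, (b, l))
labels = input_ids.clone()
labels[:, :3] = -100                                  # e.g. prompt tokens excluded from the loss
t = eps + (1 - eps) * torch.rand(b)                   # per-example timestep, t in [eps, 1)
p_mask = t.unsqueeze(1).expand(b, l)                  # 1 - α(t) with α(t) = 1 - t
masked = (torch.rand(b, l) < p_mask) & (labels != -100)
noised = torch.where(masked, mask_token_id, input_ids)
print(noised.shape, int(masked.sum()), "tokens masked")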

View File

@ -0,0 +1 @@
from .utils import load_sft_dataset, load_pt_dataset

63
dllm/dllm/data/alpaca.py Normal file
View File

@ -0,0 +1,63 @@
from typing import Optional
from datasets import load_dataset, DatasetDict
def _build_alpaca_prompt(instruction: str, input_text: str | None) -> str:
"""Construct a clean text prompt from Alpaca fields.
We intentionally *do not* include Anthropic-style role tags (e.g., "Human:", "Assistant:")
in the returned prompt, to mirror the return shape of `load_hh_rlhf_dataset` which removes
those tags from the prompt it returns.
"""
instruction = (instruction or "").strip()
input_text = (input_text or "").strip()
if input_text:
# Keep instruction and input separated by a blank line for readability.
return f"{instruction}\n\n{input_text}"
else:
return instruction
def load_dataset_alpaca(dataset_name_or_path: str) -> DatasetDict:
"""Load the Alpaca dataset (tatsu-lab/alpaca) and expose unified fields.
Returns a `DatasetDict` where each split contains:
- messages: a two-turn chat list, i.e. a user turn built from the instruction
  (+ optional input) and an assistant turn containing the target output
Parameters
----------
dataset_name_or_path : str
Usually "tatsu-lab/alpaca" or a local path.
"""
dataset = load_dataset(dataset_name_or_path)
def map_fn(example):
prompt = _build_alpaca_prompt(
example.get("instruction", ""), example.get("input", "")
)
response = (example.get("output", "") or "").strip()
return {
"messages": [
{"role": "user", "content": prompt},
{"role": "assistant", "content": response},
]
}
dataset = dataset.map(
map_fn, remove_columns=dataset["train"].column_names, num_proc=4
)
# make train test split
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
return dataset
if __name__ == "__main__":
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"tatsu-lab/alpaca", "BASE_DATASETS_DIR"
)
dataset = load_dataset_alpaca(dataset_name_or_path)
breakpoint()
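# Example of the prompt construction above on toy strings (runs after continuing past the breakpoint):
print(repr(_build_alpaca_prompt("Summarize the text.", "The cat sat on the mat.")))
# -> 'Summarize the text.\n\nThe cat sat on the mat.'
print(repr(_build_alpaca_prompt("Write a haiku about rain.", None)))
# -> 'Write a haiku about rain.'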

133
dllm/dllm/data/opc.py Normal file
View File

@ -0,0 +1,133 @@
from typing import Optional, Text, List, Dict
from datasets import (
load_dataset,
get_dataset_config_names,
concatenate_datasets,
DatasetDict,
Dataset,
IterableDatasetDict,
)
from dllm.data.utils import (
_merge_datasetdicts,
_merge_iterabledatasetdicts,
_ensure_datasetdict,
_ensure_iterabledatasetdict,
)
def load_dataset_opc_sft(
dataset_name_or_path: str, name: str | None = None, lang: str | None = None
) -> DatasetDict:
"""
Load OpenCoder OPC SFT dataset(s) and produce a DatasetDict with a train/test split.
- If `name` is provided: load that specific config.
- If `name` is None: load *all* available configs and concatenate them.
"""
def _map_to_messages(ds: Dataset) -> Dataset:
def map_fn(example):
return {
"messages": [
{"role": "user", "content": example["instruction"]},
{"role": "assistant", "content": example["output"]},
]
}
# Remove all original columns after mapping
remove_cols = ds.column_names
return ds.map(map_fn, remove_columns=remove_cols, num_proc=4)
def _load_one_config(dataset_name_or_path: str, cfg_name: str) -> Dataset:
ds = load_dataset(dataset_name_or_path, cfg_name, split="train")
return _map_to_messages(ds)
if name is not None:
train_ds = _load_one_config(dataset_name_or_path, name)
else:
# Enumerate and load all configs, then concatenate
cfgs: list[str] = get_dataset_config_names(dataset_name_or_path)
if not cfgs:
raise ValueError(f"No configs found for dataset: {dataset_name_or_path}")
parts = [_load_one_config(dataset_name_or_path, c) for c in cfgs]
train_ds = concatenate_datasets(parts)
# Final split
ds_dict = train_ds.train_test_split(test_size=0.1, seed=42)
if lang is not None:
ds_dict = ds_dict.filter(lambda row: lang in row["messages"][1]["content"])
return DatasetDict(ds_dict)
def load_dataset_opc_annealing(
dataset_name_or_path: str,
name: str | None = None,
lang: str | None = None,
streaming: bool = True,
) -> DatasetDict:
def _load_one_config(_name):
ds = load_dataset(
dataset_name_or_path, _name, split="train", streaming=streaming
)
if lang:
if _name in ["synthetic_code_snippet", "algorithmic_corpus"]:
ds = ds.filter(lambda row: row["lang"] == lang)
elif _name in ["synthetic_qa"]:
ds = ds.filter(lambda row: row["program_lang"] == lang)
else:
raise NotImplementedError
# return IterableDatasetDict({"train": ds})
if streaming:
return _ensure_iterabledatasetdict(ds)
return _ensure_datasetdict(ds)
if name is not None:
return _load_one_config(name)
if streaming:
parts = [
_load_one_config(name)
for name in get_dataset_config_names(dataset_name_or_path)
]
merged = parts[0]
for p in parts[1:]:
merged = _merge_iterabledatasetdicts(merged, p)
return merged
else:
parts = [
_load_one_config(name)
for name in get_dataset_config_names(dataset_name_or_path)
]
if len(parts) == 1:
return _ensure_datasetdict(parts[0])
merged = parts[0]
for p in parts[1:]:
merged = _merge_datasetdicts(merged, p)
return _ensure_datasetdict(merged)
if __name__ == "__main__":
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"OpenCoder-LLM/opc-sft-stage1", "BASE_DATASETS_DIR"
)
# If you want a specific config:
dataset_edu = load_dataset_opc_sft(dataset_name_or_path, "realuser_instruct")
# Otherwise, all configs concatenated:
dataset_all = load_dataset_opc_sft(dataset_name_or_path, None)
dataset_all_python = load_dataset_opc_sft(dataset_name_or_path, None, "python")
breakpoint()
# streaming = True
# dataset_name_or_path = resolve_with_base_env(
# "OpenCoder-LLM/opc-annealing-corpus", "BASE_DATASETS_DIR"
# )
# # If you want a specific config:
# dataset_alg_all = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus")
# dataset_alg_python = load_dataset_opc_annealing(dataset_name_or_path, "algorithmic_corpus", "python")
# # Otherwise, all configs concatenated:
# dataset_all_python = load_dataset_opc_annealing(dataset_name_or_path, None, "python")
# dataset_all_all = load_dataset_opc_annealing(dataset_name_or_path, None)
# breakpoint()

108
dllm/dllm/data/ultrachat.py Normal file
View File

@ -0,0 +1,108 @@
from typing import Optional, List, Dict
from datasets import load_dataset, DatasetDict
def _extract_first_turn(messages: list[dict[str, str]]) -> dict[str, str] | None:
"""
Given a list of chat messages like:
[{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
...]
return a dict with the first user/assistant exchange as:
{"prompt": <user content>, "response": <assistant content>}
If no valid first turn exists, return None.
"""
if not isinstance(messages, list) or len(messages) < 2:
return None
# Find the first user message and the first assistant *after* that user msg
# (Most entries start as [user, assistant, ...], but we guard anyway.)
user_idx = None
for i, m in enumerate(messages):
if (
isinstance(m, dict)
and m.get("role") == "user"
and isinstance(m.get("content"), str)
):
user_idx = i
break
if user_idx is None:
return None
# Find first assistant after that user
for j in range(user_idx + 1, len(messages)):
m = messages[j]
if (
isinstance(m, dict)
and m.get("role") == "assistant"
and isinstance(m.get("content"), str)
):
user_text = messages[user_idx]["content"].strip()
assistant_text = m["content"].strip()
if user_text and assistant_text:
return {"prompt": user_text, "response": assistant_text}
return None
return None
def load_dataset_ultrachat(dataset_name_or_path: str) -> DatasetDict:
"""
Load the UltraChat 200k dataset (HuggingFaceH4/ultrachat_200k) and keep only the *first turn*
(first user message and the assistant reply).
Returns a `DatasetDict` where each split contains:
- prompt: first user message content
- response: first assistant reply content
Parameters
----------
dataset_name_or_path : str
Typically "HuggingFaceH4/ultrachat_200k" or a local path.
"""
dataset = load_dataset(dataset_name_or_path)
# We only keep examples that have a valid first (user, assistant) turn.
def has_first_turn(example):
messages = example.get("messages")
return _extract_first_turn(messages) is not None
dataset = dataset.filter(has_first_turn, num_proc=4)
def map_fn(example):
first = _extract_first_turn(example["messages"])
# Fallbacks for robustness (shouldn't be hit after filter, but just in case)
if first is None:
first = {"prompt": (example.get("prompt") or "").strip(), "response": ""}
return {"prompt": first["prompt"], "response": first["response"]}
# Remove original columns for a clean schema (infer from any available split)
cols_to_remove = None
for split_name in dataset.keys():
cols_to_remove = dataset[split_name].column_names
break
dataset = dataset.map(map_fn, remove_columns=cols_to_remove, num_proc=4)
dataset = DatasetDict(
{
new: dataset[old]
for old, new in {
"train_sft": "train",
"test_sft": "test",
}.items()
if old in dataset
}
)
return dataset
if __name__ == "__main__":
# Mirrors the style from your previous loaders: resolve path via env helper if available.
from dllm.utils import resolve_with_base_env
dataset_name_or_path = resolve_with_base_env(
"HuggingFaceH4/ultrachat_200k", "BASE_DATASETS_DIR"
)
dataset = load_dataset_ultrachat(dataset_name_or_path)
breakpoint()
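# The first-turn extraction above on a toy chat (contents made up):
msgs = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
    {"role": "user", "content": "And 3 + 3?"},
]
print(_extract_first_turn(msgs))
# -> {'prompt': 'What is 2 + 2?', 'response': '4.'}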

377
dllm/dllm/data/utils.py Normal file
View File

@ -0,0 +1,377 @@
from datasets import (
Dataset,
DatasetDict,
IterableDatasetDict,
IterableDataset,
load_dataset,
load_from_disk,
)
from dllm.utils.utils import resolve_with_base_env, parse_spec, get_default_logger
logger = get_default_logger(__name__)
def load_sft_dataset(
dataset_args: str, load_preprocessed_data: bool = False
) -> DatasetDict:
"""
Examples of dataset_args:
- "tatsu-lab/alpaca"
- "OpenCoder-LLM/opc-sft-stage2[name:educational_instruct]"
- "tatsu-lab/alpaca[train:5000]"
- "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
"""
from dllm.data.alpaca import load_dataset_alpaca
from dllm.data.opc import load_dataset_opc_sft
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
all_parts = []
for raw in specs:
dataset_name_or_path, kvs = parse_spec(raw)
dataset_name_or_path = resolve_with_base_env(
dataset_name_or_path, "BASE_DATASETS_DIR"
)
if load_preprocessed_data:
logger.info("Load preprocessed data from disk.")
ds = load_from_disk(dataset_name_or_path)
# Implement your customized dataset here
elif _match(dataset_name_or_path, "tatsu-lab/alpaca"):
ds = load_dataset_alpaca(dataset_name_or_path)
elif _match(dataset_name_or_path, "allenai/tulu-3-sft-mixture"):
ds = load_dataset(dataset_name_or_path)
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
elif _match(dataset_name_or_path, "HuggingFaceTB/smoltalk"):
name = kvs.pop("name", "all")
ds = load_dataset(dataset_name_or_path, name=name)
elif _match(dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage1") or _match(
dataset_name_or_path, "OpenCoder-LLM/opc-sft-stage2"
):
name = kvs.pop("name", None)
lang = kvs.pop("lang", None)
ds = load_dataset_opc_sft(dataset_name_or_path, name=name, lang=lang)
elif _match(dataset_name_or_path, "HuggingFaceH4/ultrachat_200k"):
ds = load_dataset(dataset_name_or_path)
ds = DatasetDict({"train": ds["train_sft"], "test": ds["test_sft"]})
else:
ds = load_dataset(dataset_name_or_path)
# Normalize to DatasetDict and apply per-split limits
ds = _ensure_datasetdict(ds)
ds = _truncate_dataset(ds, kvs)
all_parts.append(ds)
# If only one part, return as DatasetDict
if len(all_parts) == 1:
return _ensure_datasetdict(all_parts[0])
# Merge all parts into a single DatasetDict
merged = all_parts[0]
for part in all_parts[1:]:
merged = _merge_datasetdicts(merged, part)
return _ensure_datasetdict(merged)
def load_pt_dataset(
dataset_args: str, streaming: bool = True, load_preprocessed_data: bool = False
) -> DatasetDict | IterableDatasetDict:
"""
Examples of dataset_args:
- "mlfoundations/dclm-baseline-1.0"
- "OpenCoder-LLM/opc-fineweb-code-corpus"
- "OpenCoder-LLM/opc-fineweb-math-corpus"
- "OpenCoder-LLM/opc-annealing-corpus[lang:python]"
- "wikitext[name:wikitext-103-v1}]"
"""
from dllm.data.opc import load_dataset_opc_annealing
specs = [p.strip() for p in dataset_args.split("|") if p.strip()]
if not specs:
raise ValueError("Empty dataset_args for load_pt_dataset.")
# ---------- Shared loader (only differs by streaming flag) ----------
def _load_base_dataset(
raw: str, *, streaming: bool
) -> tuple[DatasetDict | IterableDatasetDict, dict, str]:
"""
Returns: (base, kvs, dataset_name_or_path)
- Pops 'name' from kvs when applicable (e.g., wikitext).
- Applies identical matching logic for both streaming/non-streaming.
"""
dataset_name_or_path, kvs = parse_spec(raw)
dataset_name_or_path = resolve_with_base_env(
dataset_name_or_path, "BASE_DATASETS_DIR"
)
name = kvs.pop("name", None)
if load_preprocessed_data:
base = load_from_disk(dataset_name_or_path)
elif _match(dataset_name_or_path, ["OpenCoder-LLM/opc-annealing-corpus"]):
lang = kvs.pop("lang", None)
base = load_dataset_opc_annealing(
dataset_name_or_path, name=name, lang=lang, streaming=streaming
)
else:
base = load_dataset(dataset_name_or_path, name=name, streaming=streaming)
return base, kvs, dataset_name_or_path
# ---------- Streaming path ----------
def _load_one_streaming_spec(raw: str) -> IterableDatasetDict:
base, kvs, dataset_name_or_path = _load_base_dataset(raw, streaming=True)
split_names = list(base.keys())
single_split = len(split_names) == 1
single_split_name = split_names[0] if single_split else None
n_train = kvs.get("train")
n_test = kvs.get("test")
if (n_train is not None) or (n_test is not None):
if (n_train is not None) and (n_test is not None):
if single_split:
stream = base[single_split_name]
head = stream.take(n_train + n_test)
test = head.take(n_test)
train = head.skip(n_test).take(n_train)
return IterableDatasetDict({"train": train, "test": test})
else:
if "train" not in base or "test" not in base:
raise ValueError(
f"{dataset_name_or_path}: require 'train' and 'test' splits for train+test limits."
)
train = base["train"].take(n_train)
test = base["test"].take(n_test)
return IterableDatasetDict({"train": train, "test": test})
if n_train is not None:
if single_split:
train = base[single_split_name].take(n_train)
else:
if "train" not in base:
raise ValueError(
f"{dataset_name_or_path}: missing 'train' split for train limit."
)
train = base["train"].take(n_train)
return IterableDatasetDict({"train": train})
if n_test is not None:
if single_split:
test = base[single_split_name].take(n_test)
else:
if "test" not in base:
raise ValueError(
f"{dataset_name_or_path}: missing 'test' split for test limit."
)
test = base["test"].take(n_test)
return IterableDatasetDict({"test": test})
return base # already an IterableDatasetDict
# ---------- Non-streaming path (mirror load_sft_dataset; NO shuffle) ----------
def _load_one_nonstreaming_spec(raw: str) -> DatasetDict:
base, kvs, _ = _load_base_dataset(raw, streaming=False)
ds = _ensure_datasetdict(base) # normalize
ds = _truncate_dataset(ds, kvs) # apply limits (train/test/...)
return ds
# ---------- Load & Merge ----------
if streaming:
logger.info("Loading dataset in streaming mode.")
parts = [_load_one_streaming_spec(raw) for raw in specs]
merged = parts[0]
for p in parts[1:]:
merged = _merge_iterabledatasetdicts(merged, p)
# repeat streaming dataset infinitely
merged = IterableDatasetDict(
{k: (v.repeat(None) if k == "train" else v) for k, v in merged.items()}
)
return merged
else:
logger.info("Loading dataset in non-streaming mode.")
parts = [_load_one_nonstreaming_spec(raw) for raw in specs]
if len(parts) == 1:
return _ensure_datasetdict(parts[0])
merged = parts[0]
for p in parts[1:]:
merged = _merge_datasetdicts(merged, p)
return _ensure_datasetdict(merged)
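# Toy sketch of the streaming train/test carving used above: take a head of n_train + n_test
# examples, serve the first n_test as 'test' and the next n_train as 'train'
# (the sizes and the demo helper below are illustrative, not part of the pipeline).
def _demo_streaming_carve() -> None:
    stream = Dataset.from_dict({"x": list(range(10))}).to_iterable_dataset()
    head = stream.take(3 + 2)
    test, train = head.take(2), head.skip(2).take(3)
    print([r["x"] for r in test], [r["x"] for r in train])  # [0, 1] [2, 3, 4]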
def _truncate_split(split_data, n: int):
if n is None:
return split_data
try:
if hasattr(split_data, "select"):
# Hugging Face Dataset path
total = getattr(split_data, "num_rows", None)
if total is None:
# some Dataset types expose len(...)
total = len(split_data)
idx = list(range(min(n, total)))
return split_data.select(idx)
except Exception:
pass
try:
return split_data[:n]
except Exception:
# Last resort: iterate
return type(split_data)(item for i, item in enumerate(split_data) if i < n)
def _truncate_dataset(ds, limits: dict):
"""
Ensure and return a DatasetDict, truncating splits mentioned in `limits`.
"""
ds = _ensure_datasetdict(ds) # normalize first
out = {}
for split, data in ds.items():
n = limits.get(split, None)
out[split] = _truncate_split(data, n) if n is not None else data
return DatasetDict(out)
def _concat_splits(a, b):
"""
Concatenate two split objects (prefer 🤗 datasets).
"""
if a is b:
return a
if a is None:
return b
if b is None:
return a
# Prefer datasets' concatenate_datasets when both are Datasets
try:
from datasets import concatenate_datasets
if isinstance(a, Dataset) and isinstance(b, Dataset):
return concatenate_datasets([a, b])
except Exception:
pass
# Fallbacks
try:
return a + b
except Exception:
pass
try:
return type(a)(list(a) + list(b))
except Exception:
pass
raise TypeError(
f"Cannot concatenate split objects of types {type(a)} and {type(b)}"
)
def _merge_datasetdicts(d1, d2):
"""
Merge two DatasetDict-like mappings by concatenating splits present in either.
Always returns a DatasetDict.
"""
d1 = _ensure_datasetdict(d1)
d2 = _ensure_datasetdict(d2)
all_splits = set(d1.keys()) | set(d2.keys())
out = {}
for split in all_splits:
a = d1.get(split, None)
b = d2.get(split, None)
if a is None:
out[split] = b
elif b is None:
out[split] = a
else:
out[split] = _concat_splits(a, b)
return DatasetDict(out)
def _ensure_datasetdict(ds):
"""
Normalize various loader outputs into a DatasetDict.
- If loader returns a DatasetDict, return as is.
- If loader returns a mapping (e.g., dict of splits), wrap into DatasetDict.
- If loader returns a single Dataset/list/etc., assume it's 'train'.
"""
if isinstance(ds, DatasetDict):
return ds
if isinstance(ds, dict):
# Try to convert each split value to a Dataset if they aren't already.
# If they are already Datasets, DatasetDict will accept them directly.
return DatasetDict(ds)
# Single split -> assume train
return DatasetDict({"train": ds})
def _match(name: str, needle) -> bool:
"""
Returns True if `name` matches any of the provided needles.
Accepts a single string or a list/tuple of strings.
Match condition: name endswith(needle) or needle in name.
"""
if isinstance(needle, (list, tuple)):
return any(name.endswith(n) or n in name for n in needle)
return name.endswith(needle) or needle in name
def _concat_iterable_datasets(parts: list[IterableDataset]) -> IterableDataset:
"""
Concatenate IterableDatasets sequentially without materialization.
Preserves streaming nature; supports downstream .take()/.skip()/.shuffle().
"""
if not parts:
raise ValueError("No IterableDatasets to concatenate.")
# Try to reuse features from the first dataset when available
features = getattr(parts[0], "features", None)
def _gen():
for ds in parts:
yield from ds
return IterableDataset.from_generator(_gen, features=features)
def _ensure_iterabledatasetdict(obj) -> IterableDatasetDict:
if isinstance(obj, IterableDatasetDict):
return obj
if isinstance(obj, dict):
return IterableDatasetDict(obj)
# Single stream -> assume train
return IterableDatasetDict({"train": obj})
def _merge_iterabledatasetdicts(
d1: IterableDatasetDict, d2: IterableDatasetDict
) -> IterableDatasetDict:
"""
Merge by concatenating any overlapping splits (streaming-safe).
"""
d1 = _ensure_iterabledatasetdict(d1)
d2 = _ensure_iterabledatasetdict(d2)
all_splits = set(d1.keys()) | set(d2.keys())
out = {}
for split in all_splits:
a = d1.get(split, None)
b = d2.get(split, None)
if a is None:
out[split] = b
elif b is None:
out[split] = a
else:
out[split] = _concat_iterable_datasets([a, b])
return IterableDatasetDict(out)
def _truncate_stream(ds: IterableDataset, n: int | None) -> IterableDataset:
if n is None:
return ds
return ds.take(n)
if __name__ == "__main__":
breakpoint()
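# Illustrative use of the load_sft_dataset spec strings documented above (dataset names come
# from the docstring; the 5000-row limits are assumptions). Each "|"-separated part is loaded,
# truncated per split, and the parts are concatenated into one DatasetDict.
ds = load_sft_dataset(
    "tatsu-lab/alpaca[train:5000] | HuggingFaceH4/ultrachat_200k[train:5000]"
)
print(ds)  # expected: a DatasetDict whose 'train' split merges ~5000 rows from each source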

View File

@ -0,0 +1 @@
from . import llada, dream, rnd, editflow

View File

@ -0,0 +1,362 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/bert/eval.py \
--tasks gsm8k \
--batch_size 1 \
--model bert \
--device cuda \
--num_fewshot 8 \
--model_args "pretrained=dllm-collection/ModernBERT-base-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
"""
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
@dataclass
class BERTEvalConfig(LLaDAGeneratorConfig):
max_new_tokens: int = 128
max_length: int = 512
steps: int = 128
block_length: int = 128
pretrained: str = ""
dtype: str | torch.dtype = "auto"
batch_size: int = 32
mc_num: int = 128
is_check_greedy: bool = True
device: str = "cuda"
@register_model("bert")
class BERTEvalHarness(LM):
def __init__(
self,
config: BERTEvalConfig | None = None,
**kwargs,
):
super().__init__()
# Initialize config if not provided
if config is None:
config = BERTEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
dtype = kwargs.get("dtype", config.dtype)
batch_size = kwargs.get("batch_size", config.batch_size)
mc_num = kwargs.get("mc_num", config.mc_num)
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
device = kwargs.get("device", config.device)
cfg = kwargs.get("cfg", config.cfg_scale)
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
block_length = kwargs.get("block_length", config.block_length)
max_length = kwargs.get("max_length", config.max_length)
remasking = kwargs.get("remasking", config.remasking)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.batch_size = int(batch_size)
self.max_length = int(max_length)
self.max_new_tokens = int(max_new_tokens)
self.block_length = int(block_length)
self.steps = int(steps)
self.cfg = float(cfg)
self.remasking = remasking
self.is_check_greedy = is_check_greedy
# loglikelihood params
self.mc_num = int(mc_num)
assert mc_num % self.batch_size == 0
self.sampling_eps = 0.0
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def _forward_process(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
target_len = (l - prompt_index.sum()).item()
k = torch.randint(1, target_len + 1, (), device=batch.device)
x = torch.round(
torch.linspace(
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
)
).long()
x = ((x - 1) % target_len) + 1
assert x.min() >= 1 and x.max() <= target_len
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
is_mask = indices < x.unsqueeze(1)
for i in range(b):
is_mask[i] = is_mask[i][torch.randperm(target_len)]
is_mask = torch.cat(
(
torch.zeros(
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
),
is_mask,
),
dim=1,
)
noisy_batch = torch.where(is_mask, self.mask_id, batch)
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
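# Note: the routine above draws a single count k, spreads per-example mask counts x_i
# roughly evenly over [1, target_len] via linspace, then scatters each row's x_i masks
# uniformly at random over the non-prompt positions; prompt tokens are never masked, and
# x / target_len is returned as the per-position masking rate used to reweight the loss.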
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
if self.cfg > 0.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
logits = self.model(batch).logits
if self.cfg > 0.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
loss_acc = []
for _ in range(self.mc_num // self.batch_size):
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return -sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def suffix_greedy_prediction(
self, prefix: torch.Tensor, target: torch.Tensor
) -> bool:
if not self.is_check_greedy:
return False
seq = torch.full(
(1, len(prefix) + len(target)), self.mask_id, device=self.device
)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
prefix, target = prefix.to(self.device), target.to(self.device)
seq[0, : len(prefix)] = prefix
for i in range(len(target)):
mask_index = seq == self.mask_id
logits = self.get_logits(seq, prompt_index)[mask_index]
x0 = torch.argmax(logits, dim=-1)
p = torch.softmax(logits.to(torch.float32), dim=-1)
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
dim=-1
)
_, index = torch.sort(confidence, descending=True)
x0[index[1:]] = self.mask_id
seq[mask_index] = x0.clone()
correct = target == seq[0, len(prefix) :]
correct = torch.all(correct)
return correct
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer(context + continuation)["input_ids"]
context_enc = self.tokenizer(context)["input_ids"]
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
ds = []
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
assert max(prompt_len) <= 4096
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
ll = self.get_loglikelihood(prefix, target)
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
torch.cuda.empty_cache()
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
def generate_until(self, requests: list[Instance]):
def _tokenize(e):
return {
"question": self.tokenizer(e["question"])["input_ids"],
"question_text": e["question"],
"until": e["until"],
}
ds = [
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
for elem in tqdm(ds, desc="Generating..."):
prompt = [elem["question"][1:-1].to(self.device)]
stop_tokens = elem["until"]
generated_ids = generator.generate(
inputs=prompt,
steps=self.steps,
max_new_tokens=self.max_new_tokens,
block_length=self.block_length,
temperature=0.0,
cfg_scale=self.cfg,
remasking=self.remasking,
)
generated_answer = self.tokenizer.decode(
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
)
for stop_seq in stop_tokens:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]
# remove special tokens
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
generated_answer = self.tokenizer.decode(
generated_answer_ids, skip_special_tokens=True
)
out.append(generated_answer)
if self.accelerator is not None:
self.accelerator.wait_for_everyone()
return out
if __name__ == "__main__":
cli_evaluate()

View File

@ -0,0 +1,6 @@
from . import generator, models, trainer, utils
from .models.modeling_dream import DreamModel
from .models.configuration_dream import DreamConfig
from .models.tokenization_dream import DreamTokenizer
from .generator import DreamGeneratorConfig, DreamGenerator
from .trainer import DreamTrainer

View File

@ -0,0 +1,533 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/dream/eval.py \
--tasks gsm8k \
--batch_size 1 \
--model dream \
--device cuda \
--num_fewshot 0 \
--model_args "pretrained=Dream-org/Dream-v0-Base-7B,mc_num=1,max_new_tokens=512,max_length=512,steps=512,temperature=0.2,top_p=0.95,add_bos_token=true,escape_until=true"
"""
import logging
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.dream import DreamGenerator, DreamGeneratorConfig
eval_logger = logging.getLogger(__name__)
@dataclass
class DreamEvalConfig(DreamGeneratorConfig):
top_p: float | None = None
top_k: float | None = None
max_new_tokens: int = 128
max_length: int = 2048
steps: int = 128
temperature: float = 0.0
alg: str = "entropy"
pretrained: str = ""
batch_size: int = 1
device: str = "cuda"
dtype: str | torch.dtype = "auto"
add_bos_token: bool = False
nll_type: str = "mc"
log_type: str = "ftb"
mc_num: int = 128
classifier_free_guidance: float = 1.0
sampling_eps: float = 1e-3
escape_until: bool = False
@register_model("dream")
class DreamEvalHarness(LM):
def __init__(
self,
config: DreamEvalConfig | None = None,
**kwargs,
) -> None:
super().__init__()
# Initialize config if not provided
if config is None:
config = DreamEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
batch_size = kwargs.get("batch_size", config.batch_size)
device = kwargs.get("device", config.device)
dtype = kwargs.get("dtype", config.dtype)
max_length = kwargs.get("max_length", config.max_length)
add_bos_token = kwargs.get("add_bos_token", config.add_bos_token)
nll_type = kwargs.get("nll_type", config.nll_type)
log_type = kwargs.get("log_type", config.log_type)
mc_num = kwargs.get("mc_num", config.mc_num)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
classifier_free_guidance = kwargs.get(
"classifier_free_guidance", config.classifier_free_guidance
)
sampling_eps = kwargs.get("sampling_eps", config.sampling_eps)
steps = kwargs.get("steps", config.steps)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("alg_temp", config.alg_temp)
escape_until = kwargs.get("escape_until", config.escape_until)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.max_length = max_length
self.add_bos_token = add_bos_token
self.batch_size = int(batch_size)
self.max_new_tokens = max_new_tokens
self.steps = steps
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.alg = alg
self.alg_temp = alg_temp
self.escape_until = escape_until
# loglikelihood params
self.nll_type = nll_type
self.log_type = log_type
self.mc_num = mc_num
self.classifier_free_guidance = classifier_free_guidance
self.sampling_eps = sampling_eps
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_decode(
self, tokens: torch.Tensor | list[int], skip_special_tokens: bool = True
) -> str:
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def tok_encode(self, text: str, add_special_tokens: bool = True) -> torch.Tensor:
return self.tokenizer(
text, return_tensors="pt", add_special_tokens=add_special_tokens
).input_ids
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
def generate_until(
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
generator = DreamGenerator(model=self.model, tokenizer=self.tokenizer)
for batch_idx in range(0, len(requests), self.batch_size):
batch_requests = requests[batch_idx : batch_idx + self.batch_size]
contexts, gen_args = zip(*[req.arguments for req in batch_requests])
# ====== BEGIN merged _generate_batch logic ======
prompts = list(contexts)
if self.add_bos_token:
prompts = [self.tokenizer.bos_token + p for p in prompts]
# tokenize
prompt_ids = [
self.tokenizer(
p, return_tensors="pt", padding=False
).input_ids.squeeze()
for p in prompts
]
prompt_lens = [len(p_id) for p_id in prompt_ids]
if max(prompt_lens) > self.max_length - self.max_new_tokens:
cutoff_len = self.max_length - self.max_new_tokens
eval_logger.warning(
f"Prompt length {max(prompt_lens)} exceeds {cutoff_len}, cutoff on the left side"
)
# ✅ Correct: trim from the left side (keep the last cutoff_len tokens)
prompt_ids = [p_id[-cutoff_len:] for p_id in prompt_ids]
# generation
generation_ids = generator.generate(
max_new_tokens=self.max_new_tokens,
inputs=prompt_ids,
steps=self.steps,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
alg=self.alg,
alg_temp=self.alg_temp,
output_history=False,
return_dict_in_generate=False,
)
# decode and cleanup
cleaned_generation_ids = [
(
seq[seq.ne(self.tokenizer.eos_token_id).float().argmax().long() :]
if (seq != self.tokenizer.eos_token_id).any()
else seq[-1:]
)
for seq in generation_ids
]
truncated_generation_ids = [
seq[prompt_lens[i] :] for i, seq in enumerate(cleaned_generation_ids)
]
responses = [
g.lstrip("<|endoftext|>").split(self.tokenizer.eos_token, 1)[0]
for g in self.tokenizer.batch_decode(truncated_generation_ids)
]
# ====== END merged _generate_batch logic ======
# handle "until" truncation
if not self.escape_until:
for i, r in enumerate(responses):
for s in gen_args[0]["until"]:
r = r.split(s)[0]
responses[i] = r
res.extend(responses)
pbar.update(len(contexts))
return res
def _forward_process(
self, batch: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
# sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
indices = torch.arange(b, device=batch.device).float()
t = (u0 + indices / b) % 1
p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
p_mask = p_mask[:, None].repeat(1, l)
mask_indices = torch.rand((b, l), device=batch.device) < p_mask
# always unmask bos and eos
mask_indices[:, 0] = False
mask_indices[:, -1] = False
noisy_batch = torch.where(mask_indices, self.mask_id, batch)
return noisy_batch, p_mask
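# Note: the draw above is stratified over the batch: with b = 4 and u0 = 0.1,
# t = (0.1 + [0, 1, 2, 3] / 4) mod 1 = [0.10, 0.35, 0.60, 0.85], so masking rates
# cover (0, 1) roughly evenly within every batch; the BOS and final positions are
# always kept unmasked.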
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
"""
prompt_index : 1D bool tensor, length=batch.shape[1]
"""
if self.classifier_free_guidance > 1.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
input = batch
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
logits = self.model(input).logits
# since bos always unmask, the first logits will not be used
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
if self.classifier_free_guidance > 1.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def _eval_target_nll_mc(
self, prefix: torch.Tensor | None, target: torch.Tensor
) -> float:
if prefix is None:
seq = target[None, :]
else:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
if self.log_type == "ftb":
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
else:
prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
loss_acc = []
for _ in range(max(self.mc_num // self.batch_size, 1)):
perturbed_seq = seq.clone()
# eval_logger.info("before noising")
perturbed_seq_, p_mask = self._forward_process(seq)
# eval_logger.info("end noising")
if self.log_type == "ftb":
perturbed_seq[:, -len(target) :] = perturbed_seq_[:, -len(target) :]
elif self.log_type == "btf":
perturbed_seq[:, : len(prefix)] = perturbed_seq_[:, : len(prefix)]
elif self.log_type == "union":
perturbed_seq = perturbed_seq_
else:
raise NotImplementedError(self.log_type)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
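            # the 1 / p_mask importance weight compensates for how often each position is masked,
            # so rarely-masked positions are not under-counted in the NLL estimate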
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def _eval_target_nll_ar(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
assert self.log_type in ["ftb", "btf"]
assert self.nll_type in ["ar_ftb", "ar_btf"]
if self.log_type == "ftb":
prompt_index = (
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
< prefix.shape[1]
)
else:
prompt_index = (
torch.arange(prefix.shape[1] + target.shape[1], device=self.device)
>= prefix.shape[1]
)
if self.log_type == "ftb":
perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
else:
perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
mask_index = torch.ones(
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
)
if self.nll_type == "ar_ftb":
mask_index = torch.triu(mask_index)
else:
mask_index = torch.tril(mask_index)
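        # row k of the L x L mask hides a progressively larger suffix (triu) or prefix (tril);
        # after the strict-triangular part is cleared below, each row scores exactly one token,
        # emulating left-to-right / right-to-left autoregressive factorization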
perturbed_[mask_index] = self.mask_id
if self.log_type == "ftb":
perturbed_seq = torch.cat(
[prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1
)
else:
perturbed_seq = torch.cat(
[perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1
)
logits_ = []
num = (
len(perturbed_seq) // self.batch_size
if len(perturbed_seq) % self.batch_size == 0
else len(perturbed_seq) // self.batch_size + 1
)
for i in range(num):
end = (
(i + 1) * self.batch_size
if (i + 1) * self.batch_size < len(perturbed_seq)
else len(perturbed_seq)
)
perturbed_seq_ = perturbed_seq[i * self.batch_size : end]
perturbed_seq_ = perturbed_seq_.to(self.device)
if len(perturbed_seq_.shape) == 1:
perturbed_seq_ = perturbed_seq_.unsqueeze(0)
logits = self.get_logits(perturbed_seq_, prompt_index)
logits_.append(logits.cpu())
logits = torch.cat(logits_, dim=0)
temp_index = torch.ones(
(perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool
)
if self.nll_type == "ar_ftb":
temp_index = torch.triu(temp_index, diagonal=1)
else:
temp_index = torch.tril(temp_index, diagonal=-1)
mask_index[temp_index] = False
if self.log_type == "ftb":
logits_index = torch.cat(
[
torch.zeros(
(perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool
),
mask_index,
],
dim=-1,
)
else:
logits_index = torch.cat(
[
mask_index,
torch.zeros(
(perturbed_.shape[1], target.shape[1]), dtype=torch.bool
),
],
dim=-1,
)
if self.log_type == "ftb":
loss = (
F.cross_entropy(logits[logits_index], target[0], reduction="sum")
.cpu()
.item()
)
else:
loss = (
F.cross_entropy(logits[logits_index], prefix[0], reduction="sum")
.cpu()
.item()
)
return loss
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
if self.add_bos_token:
context = self.tokenizer.bos_token + context
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer.encode(context + continuation) + [
self.tokenizer.eos_token_id
]
context_enc = self.tokenizer.encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
# by default truncate on the left
cutoff_length = max(len(whole_enc) - self.max_length, 0)
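        # trimming removes the oldest context tokens first; the continuation is only shortened
        # if the context alone cannot absorb the overflow (handled in the else branch below)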
if cutoff_length > 0:
eval_logger.warning(
f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side"
)
context_remain = context_enc_len - cutoff_length
if context_remain > 0:
context_enc = context_enc[-context_remain:]
else:
eval_logger.warning(f"All context (prompt) is truncated.")
context_enc = ""
continuation_enc = whole_enc[-self.max_length :]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
# likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
if self.nll_type == "mc":
ll = -self._eval_target_nll_mc(prefix, target)
if self.log_type == "union":
ll = ll / (len(target) + len(prefix))
elif self.nll_type == "ar_ftb" or self.nll_type == "ar_btf":
ll = -self._eval_target_nll_ar(prefix, target)
else:
raise NotImplementedError(self.nll_type)
# TODO: greedy decoding
is_target_greedy_dec = False
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
if __name__ == "__main__":
cli_evaluate()
View File
@ -0,0 +1,426 @@
"""
reference: https://huggingface.co/Dream-org/Dream-v0-Base-7B/blob/main/generation_utils.py
"""
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torch.distributions as dists
from dllm.utils.generation_utils import get_num_transfer_tokens
from dllm.pipelines.dream.utils import top_p_logits, top_k_logits
from dllm.core.generation.generator import (
GeneratorOutput,
GeneratorConfig,
BaseGenerator,
)
def sample_tokens(
logits: torch.Tensor,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
margin_confidence: bool = False,
neg_entropy: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
logits = top_p_logits(logits, top_p)
if top_k is not None:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
except Exception:
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if margin_confidence:
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
top1_probs = sorted_probs[:, 0]
top2_probs = sorted_probs[:, 1]
confidence = top1_probs - top2_probs
if neg_entropy:
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
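        # sum(p * log p) is the negative entropy: closest to 0 for peaked (confident)
        # distributions, so ranking by it favors low-entropy positions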
return confidence, x0
@dataclass
class DreamGeneratorConfig(GeneratorConfig):
max_new_tokens: int = 20
    max_length: int | None = (
        None  # if None, resolved as input_ids.shape[1] + max_new_tokens: generation_config.max_length = generation_config.max_length + input_ids_length
    )
steps: int = 512
eps: float = 1e-3
alg: str = "origin"
alg_temp: float = 0.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
stochastic_transfer: bool = False
@dataclass
class DreamGenerator(BaseGenerator):
@torch.no_grad()
def generate(
self,
        inputs: list[torch.Tensor] | list[list[int]],
config: DreamGeneratorConfig | None = None,
generation_tokens_hook_func=lambda step, x, logits: x,
generation_logits_hook_func=lambda step, x, logits: logits,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
"""
Diffusion-style masked decoding for *generation from inputs*.
(docstring unchanged)
"""
if config is None:
config = DreamGeneratorConfig()
# ----- pull args from config, allow kwargs to override -----
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
max_length = kwargs.get("max_length", config.max_length)
steps = kwargs.get("steps", config.steps)
eps = kwargs.get("eps", config.eps)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("eps", config.alg_temp)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
# generation_tokens_hook_func = kwargs.get("generation_tokens_hook_func", config.generation_tokens_hook_func)
# generation_logits_hook_func = kwargs.get("generation_logits_hook_func", config.generation_logits_hook_func)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
# --- Initialization ---
mask_token_id = self.tokenizer.mask_token_id
eos_token_id = self.tokenizer.eos_token_id
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
prompt_lens = [p.shape[0] for p in inputs]
if max_new_tokens:
max_length = max_new_tokens + max(prompt_lens)
else:
max_new_tokens = max_length - max(prompt_lens)
B = len(inputs)
T = max_length
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
seq_length = []
for i, p in enumerate(inputs):
total_len = prompt_lens[i] + max_new_tokens
seq_length.append(total_len)
start = T - total_len
x[i, start : start + prompt_lens[i]] = p
x[i, start + prompt_lens[i] : T] = mask_token_id
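        # each row of the canvas is right-aligned: [EOS padding | prompt | max_new_tokens MASK slots]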
attention_mask = torch.zeros(
(B, T), dtype=torch.float32, device=self.model.device
)
for j, L in enumerate(seq_length):
if L > 0:
                attention_mask[j, -L:] = 1.0  # sequences are left-padded, so valid tokens sit on the right
if attention_mask is not None and torch.any(attention_mask == 0.0):
pos_id = attention_mask.long().cumsum(-1) - 1
pos_id.masked_fill_(attention_mask == 0, 1)
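            # with left padding, rebuild position ids so the first real token gets position 0;
            # padded slots get a dummy id (they are excluded by the attention mask)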
else:
pos_id = None
mask_index = x == mask_token_id
num_transfer_tokens_list = get_num_transfer_tokens(
mask_index=mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
effective_steps = num_transfer_tokens_list.size(1)
# --- Iterative refinement ---
x = generation_tokens_hook_func(None, x, None)
histories = [x.clone()] if return_dict_in_generate else None
for i in range(effective_steps):
mask_index = x == mask_token_id
logits = self.model(x, attention_mask, pos_id).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
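            # AR-shift: after the concat, logits[:, i] scores the token at position i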
logits = generation_logits_hook_func(i, x, logits)
mask_logits = logits[mask_index]
if alg == "maskgit_plus":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "topk_margin":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
margin_confidence=True,
)
elif alg == "entropy":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
neg_entropy=True,
)
else:
raise RuntimeError(f"Unknown alg: {alg}")
full_confidence = torch.full_like(
x, -torch.inf, device=self.model.device, dtype=logits.dtype
)
full_confidence[mask_index] = confidence
for j in range(full_confidence.shape[0]):
number_transfer_tokens = num_transfer_tokens_list[j, i]
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(
full_confidence[j], number_transfer_tokens
)
else:
fc = full_confidence[j] / alg_temp
fc = F.softmax(fc, dim=-1)
transfer_index = torch.multinomial(
fc, num_samples=number_transfer_tokens
)
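                    # write sampled candidates back onto a mask-shaped canvas, then commit only
                    # the indices chosen above; all other positions keep their current tokens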
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
x_[mask_index] = x0.clone()
x[j, transfer_index] = x_[j, transfer_index]
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
@torch.no_grad()
def infill(
self,
        inputs: list[torch.Tensor] | list[list[int]],
config,
generation_tokens_hook_func=lambda step, x, logits: x,
generation_logits_hook_func=lambda step, x, logits: logits,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
"""
Fill in-place the tokenizer's `<mask>` tokens contained in `inputs`.
The whole (right-aligned) canvas is denoised iteratively: at each step, a scheduler
decides how many masked positions to commit, and a confidence rule (`alg`)
selects *which* positions to reveal (MaskGIT-style). Non-mask tokens are never changed.
High-level:
1) Build a right-aligned canvas per sample (left side padded with EOS).
2) Compute a per-sample transfer schedule via `scheduler` and `steps`.
3) At each step: forward pass → AR-shift logits → score masked positions
via `alg` → choose indices to commit (top-k or soft sampling) → write tokens.
Notes:
- Right padding uses EOS (serves as pad here).
- Only `[MASK]` positions are updated; original tokens remain intact.
- Logits are AR-shifted to preserve next-token prediction alignment.
Args:
model:
Mask predictor; returns logits of shape [B, T, V] when called as
`model(x, attention_mask, pos_id)`.
tokenizer:
Must provide `mask_token_id` and `eos_token_id`.
inputs:
List of 1D LongTensors (token ids). Each may contain `<mask>` tokens
to be filled; other tokens are treated as fixed context.
scheduler (BaseAlphaScheduler):
Controls how many masks to commit per step (deterministic or stochastic).
generation_tokens_hook_func / generation_logits_hook_func:
Optional hooks to intercept tokens/logits at each step.
output_history (bool):
If True, save intermediate canvases at each step.
return_dict_in_generate (bool):
If True, return `DreamModelOutput(sequences, history)`, else only `[B, T]`.
steps (int):
                Total reverse-diffusion steps (quality/speed trade-off).
alg (str):
Confidence rule to rank masked positions:
- "maskgit_plus": softmax probs
- "topk_margin": top1 - top2 margin
- "entropy": negative entropy
alg_temp (float):
Temperature for *confidence-based index sampling* (when > 0, soft selection).
temperature / top_p / top_k:
Token sampling hyperparameters within `sample_tokens`.
stochastic_transfer (bool):
If True, sample the number of transfers per step (Binomial); else use expectation.
Returns:
DreamModelOutput | torch.LongTensor:
If `return_dict_in_generate=True`, returns
- sequences: `[B, T]` final tokens
- history: optional list of intermediate canvases
Otherwise returns only `[B, T]`.
"""
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
eps = kwargs.get("eps", config.eps)
alg = kwargs.get("alg", config.alg)
alg_temp = kwargs.get("eps", config.alg_temp)
temperature = kwargs.get("temperature", config.temperature)
top_p = kwargs.get("top_p", config.top_p)
top_k = kwargs.get("top_k", config.top_k)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
# generation_tokens_hook_func = kwargs.get("stochastic_transfer", config.generation_tokens_hook_func)
# generation_logits_hook_func = kwargs.get("stochastic_transfer", config.generation_logits_hook_func)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
# --- Initialization ---
mask_token_id = self.tokenizer.mask_token_id
eos_token_id = self.tokenizer.eos_token_id
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
B = len(inputs)
seq_lens = [t.shape[0] for t in inputs]
T = max(seq_lens)
# Build right-aligned canvas; left side padded with EOS (used as pad)
x = torch.full((B, T), eos_token_id, dtype=torch.long, device=self.model.device)
for i, t in enumerate(inputs):
L = seq_lens[i]
x[i, -L:] = t
# Build 1D attention mask (valid tokens on the right)
attention_mask = torch.zeros((B, T), dtype=torch.bool, device=self.model.device)
for j, L in enumerate(seq_lens):
if L > 0:
attention_mask[j, -L:] = True
        # If left padding is present, build position ids that skip the pads; otherwise attend everywhere ("full")
if torch.any(attention_mask == 0.0):
pos_id = attention_mask.long().cumsum(-1) - 1
pos_id.masked_fill_(attention_mask == 0, 1)
else:
pos_id = None
attention_mask = "full"
# Precompute per-sample transfer schedule (how many to commit per step)
mask_index = x == mask_token_id
num_transfer_tokens_list = get_num_transfer_tokens(
mask_index=mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
effective_steps = num_transfer_tokens_list.size(1)
# Optional initial token hook
x = generation_tokens_hook_func(None, x, None)
histories = [x.clone()] if return_dict_in_generate else None
for i in range(effective_steps):
mask_index = x == mask_token_id
# Forward pass, then AR-shift to predict token at position i+1
logits = self.model(x, attention_mask, pos_id).logits
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
logits = generation_logits_hook_func(i, x, logits)
# Logits restricted to current `[MASK]` positions
mask_logits = logits[mask_index]
# Confidence scoring for masked positions
if alg == "maskgit_plus":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "topk_margin":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
margin_confidence=True,
)
elif alg == "entropy":
confidence, x0 = sample_tokens(
mask_logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
neg_entropy=True,
)
else:
raise RuntimeError(f"Unknown alg: {alg}")
# Scatter per-position confidence back to full canvas
full_confidence = torch.full_like(
x, -torch.inf, device=self.model.device, dtype=logits.dtype
)
full_confidence[mask_index] = confidence
# Commit the scheduled number of tokens per sample
for j in range(B):
number_transfer_tokens = num_transfer_tokens_list[j, i]
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(
full_confidence[j], number_transfer_tokens
)
else:
fc = full_confidence[j] / alg_temp
fc = F.softmax(fc, dim=-1)
transfer_index = torch.multinomial(
fc, num_samples=number_transfer_tokens
)
# Candidate tokens at masked positions only
x_ = torch.full_like(x, mask_token_id, device=self.model.device)
x_[mask_index] = x0.clone()
x[j, transfer_index] = x_[j, transfer_index]
# Optional token hook + history logging
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
View File
@ -0,0 +1,13 @@
from .configuration_dream import DreamConfig
from .modeling_dream import DreamModel
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("Dream", DreamConfig)
AutoModel.register(DreamConfig, DreamModel)
AutoModelForMaskedLM.register(DreamConfig, DreamModel)
except ImportError:
# transformers not available or Auto classes not imported
pass
View File
@ -0,0 +1,85 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dream model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DreamConfig(PretrainedConfig):
model_type = "Dream"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=False, # cache not used in diffusion
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
mask_token_id=151666,
pad_token_id=151643,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
View File
@ -0,0 +1,465 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import (
GenerationConfig
)
from transformers.utils import (
ModelOutput,
is_torchdynamo_compiling,
logging,
)
logger = logging.get_logger(__name__)
def top_p_logits(logits, top_p=None):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits, top_k=None):
top_k = min(top_k, logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1:
logits = top_p_logits(logits, top_p)
if top_k is not None:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if margin_confidence:
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
# Extract top1 and top2 probabilities
top1_probs = sorted_probs[:, 0]
top2_probs = sorted_probs[:, 1]
# Calculate confidence as top1 - top2
confidence = top1_probs - top2_probs
if neg_entropy:
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
return confidence, x0
@dataclass
class DreamModelOutput(ModelOutput):
sequences: torch.LongTensor = None
history: Optional[Tuple[torch.FloatTensor]] = None
class DreamGenerationConfig(GenerationConfig):
def __init__(self, **kwargs):
self.temperature: float = kwargs.pop("temperature", 0.0)
self.top_p: Optional[float] = kwargs.pop("top_p", None)
self.top_k: Optional[int] = kwargs.pop("top_k", None)
self.max_length = kwargs.pop("max_length", 20)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
# diffusion specific params
self.eps: float = kwargs.pop("eps", 1e-3)
self.steps: int = kwargs.pop("steps", 512)
self.alg: str = kwargs.pop("alg", 'origin')
self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
# Parameters that define the output variables of `generate`
self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
self.output_history: bool = kwargs.pop("output_history", False)
# Special tokens that can be used at generation time
self.mask_token_id = kwargs.pop("mask_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface.
self._from_model_config = kwargs.pop("_from_model_config", False)
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)
# Additional attributes without default values
if not self._from_model_config:
# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
# model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# Validate the values of the attributes
self.validate(is_init=True)
def validate(self, is_init=False, **kwargs):
pass
class DreamGenerationMixin:
@staticmethod
def _expand_inputs_for_generation(
expand_size: int = 1,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
# Do not call torch.repeat_interleave if expand_size is 1 because it clones
# the input tensor and thus requires more memory although no change is applied
if expand_size == 1:
return input_ids, attention_mask
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
if attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
return input_ids, attention_mask
def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""
# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
return
# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
warnings.warn(
f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
"generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
"generation.",
UserWarning,
)
if input_ids_length >= generation_config.max_length:
input_ids_string = "input_ids"
raise ValueError(
f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_length` or, better yet, setting `max_new_tokens`."
)
def _prepare_generated_length(
self,
generation_config,
has_default_max_length,
input_ids_length,
):
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None:
logger.warning(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
elif has_default_max_length:
if generation_config.max_length == DreamGenerationConfig().max_length:
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
return generation_config
def _prepare_generation_config(
self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
) -> DreamGenerationConfig:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs. This
function handles retrocompatibility with respect to configuration files.
"""
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
using_model_generation_config = False
if generation_config is None:
generation_config = DreamGenerationConfig.from_model_config(self.config)
using_model_generation_config = True
# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.mask_token_id is None:
generation_config.mask_token_id = self.generation_config.mask_token_id
return generation_config
def _prepare_special_tokens(
self,
generation_config: DreamGenerationConfig,
device: Optional[Union[torch.device, str]] = None,
):
"""
Prepares the special tokens for generation, overwriting the generation config with their processed versions
converted to tensor.
Note that `generation_config` is changed in place and stops being serializable after this method is called.
That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""
# Convert special tokens to tensors
def _tensor_or_none(token, device=None):
if token is None:
return token
device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)
# Set pad token if unset (and there are conditions to do so)
if pad_token_tensor is None and eos_token_tensor is not None:
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# Update generation config with the updated special tokens tensors
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
# (in their non-tensor form), in order to enable end-to-end compilation. See
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
generation_config._bos_token_tensor = bos_token_tensor
generation_config._eos_token_tensor = eos_token_tensor
generation_config._pad_token_tensor = pad_token_tensor
generation_config._mask_token_tensor = mask_token_tensor
@torch.no_grad()
def diffusion_generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[DreamGenerationConfig] = None,
**kwargs,
) -> Union[DreamModelOutput, torch.LongTensor]:
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
generation_config = self._prepare_generation_config(generation_config, **kwargs)
generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
# 2. Define model inputs
assert inputs is not None
input_ids = inputs
device = input_ids.device
attention_mask = kwargs.pop("attention_mask", None)
self._prepare_special_tokens(generation_config, device=device)
# 3. Prepare `max_length`.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
input_ids_length=input_ids_length,
)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 4. Check input_ids
if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
" Please make sure that you have put `input_ids` to the"
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
" running `.generate()`.",
UserWarning,
)
if (
hasattr(generation_config, "pad_token_id") and
torch.any(input_ids == generation_config.pad_token_id) and
attention_mask is None
):
warnings.warn(
"Padding was detected but no attention mask is passed here. For correct "
"generation results, please set `attention_mask` when batch-padding inputs.",
UserWarning,
)
input_ids, attention_mask = self._expand_inputs_for_generation(
expand_size=generation_config.num_return_sequences,
input_ids=input_ids,
attention_mask=attention_mask
)
result = self._sample(
input_ids,
attention_mask=attention_mask,
generation_config=generation_config,
generation_tokens_hook_func=generation_tokens_hook_func,
generation_logits_hook_func=generation_logits_hook_func
)
return result
def _sample(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor],
generation_config: DreamGenerationConfig,
generation_tokens_hook_func,
generation_logits_hook_func
) -> Union[DreamModelOutput, torch.LongTensor]:
# init values
output_history = generation_config.output_history
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
mask_token_id = generation_config.mask_token_id
steps = generation_config.steps
eps = generation_config.eps
alg = generation_config.alg
alg_temp = generation_config.alg_temp
temperature = generation_config.temperature
top_p = generation_config.top_p
top_k = generation_config.top_k
histories = [] if (return_dict_in_generate and output_history) else None
# pad input_ids to max_length
x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
if attention_mask is not None and torch.any(attention_mask == 0.0):
# we do not mask the [MASK] tokens so value = 1.0
attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
tok_idx = attention_mask.long().cumsum(-1) - 1
tok_idx.masked_fill_(attention_mask == 0, 1)
# attention_mask is of shape [B, N]
# broadcast to [B, 1, N, N]
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
else:
tok_idx = None
attention_mask = "full"
timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
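        # reverse-diffusion time grid: steps + 1 points linearly spaced from t = 1 (fully masked) down to eps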
# this allows user-defined token control of the intermediate steps
x = generation_tokens_hook_func(None, x, None)
for i in range(steps):
mask_index = (x == mask_token_id)
logits = self(x, attention_mask, tok_idx).logits
logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
# this allows user-defined logits control of the intermediate steps
logits = generation_logits_hook_func(i, x, logits)
mask_logits = logits[mask_index]
t = timesteps[i]
s = timesteps[i + 1]
if alg == 'origin':
p_transfer = 1 - s / t if i < steps - 1 else 1
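                # 'origin': each still-masked token is revealed independently with probability
                # 1 - s/t at this step; the final step reveals everything that remains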
x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                _, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
x[mask_index] = x0.clone()
else:
if alg == 'maskgit_plus':
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
elif alg == 'topk_margin':
confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
elif alg == 'entropy':
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
else:
raise RuntimeError(f"Unknown alg: {alg}")
num_mask_token = mask_index.sum() / mask_index.shape[0]
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
full_confidence[mask_index] = confidence
if number_transfer_tokens > 0:
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
else:
full_confidence = full_confidence / alg_temp
full_confidence = F.softmax(full_confidence, dim=-1)
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
x_[mask_index] = x0.clone()
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
# this allows user-defined token control of the intermediate steps
x = generation_tokens_hook_func(i, x, logits)
if histories is not None:
histories.append(x.clone())
if return_dict_in_generate:
return DreamModelOutput(
sequences=x,
history=histories,
)
else:
return x
View File
@ -0,0 +1,850 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT and Qwen implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Dream model."""
import math
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from transformers import PretrainedConfig
from .configuration_dream import DreamConfig
from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Dream-7B"
_CONFIG_FOR_DOC = "DreamConfig"
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
class DreamRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DreamRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
class DreamRotaryEmbedding(nn.Module):
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config: Optional[DreamConfig] = None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def reset_parameters(self):
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def forward(self, x, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
class DreamMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class DreamAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = False
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = DreamRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class DreamSdpaAttention(DreamAttention):
"""
Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from DreamAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# causal_mask = attention_mask
# if attention_mask is not None: # no matter the length, we just slice it
# causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
# is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,  # hard-coded: Dream attends bidirectionally, so no causal mask is applied
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class DreamDecoderLayer(nn.Module):
def __init__(self, config: DreamConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if config.sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
# self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = DreamSdpaAttention(config, layer_idx)
self.mlp = DreamMLP(config)
self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class DreamPreTrainedModel(PreTrainedModel):
config_class = DreamConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["DreamDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
weights_only: bool = True,
**kwargs,
):
_model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
# NOTE(Lin): we need to override the generation config
# because the generation config loaded in `from_pretrained`
# does not include all the attributes of DreamGenerationConfig
resume_download = kwargs.get("resume_download", None)
proxies = kwargs.get("proxies", None)
subfolder = kwargs.get("subfolder", "")
from_auto_class = kwargs.get("_from_auto", False)
from_pipeline = kwargs.get("_from_pipeline", None)
_model.generation_config = DreamGenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
)
return _model
class DreamBaseModel(DreamPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
Args:
config: DreamConfig
"""
def __init__(self, config: DreamConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DreamRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = DreamBaseModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def reset_rope_parameters(self):
self.model.rotary_emb.reset_parameters()
for layer in self.model.layers:
layer.self_attn.rotary_emb.reset_parameters()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MaskedLMOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if (isinstance(attention_mask, str) and attention_mask == "full") or attention_mask is None:
            # full (or absent) attention mask: nothing to convert
pass
elif isinstance(attention_mask, torch.Tensor):
if not torch.any(attention_mask == 0.0):
attention_mask = 'full'
elif attention_mask.dim() == 2:
# [B, L] → [B, 1, L, L]
attention_mask = torch.logical_and(
attention_mask.unsqueeze(1).unsqueeze(-2),
attention_mask.unsqueeze(1).unsqueeze(-1),
)
attention_mask = attention_mask.to(torch.bool)
elif attention_mask.dim() in (3, 4):
# already extended/broadcasted form
if attention_mask.dtype != torch.bool:
attention_mask = attention_mask.to(torch.bool)
else:
raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}")
else:
raise TypeError(f"Unsupported attention_mask type: {type(attention_mask)}")
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@ -0,0 +1,346 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on Qwen's implementations in this library.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Dream."""
import json
import os
import unicodedata
from functools import lru_cache
from typing import Optional, Tuple
import regex as re
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768}
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@lru_cache()
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""
    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
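# Round-trip sketch (added for illustration, not part of the upstream tokenizer): the map is a
# bijection over all 256 byte values, so encoding a token's UTF-8 bytes and decoding them back
# is lossless.
def _bytes_to_unicode_roundtrip(token: str = "héllo") -> str:
    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    mapped = "".join(byte_encoder[b] for b in token.encode("utf-8"))
    return bytearray(byte_decoder[c] for c in mapped).decode("utf-8")  # == token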
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class DreamTokenizer(PreTrainedTokenizer):
"""
Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding.
Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
be encoded differently whether it is at the beginning of the sentence (without space) or not:
```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True)
>>> tokenizer("Hello world")["input_ids"]
[9707, 1879]
>>> tokenizer(" Hello world")["input_ids"]
[21927, 1879]
```
This is expected.
You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*):
The beginning of sequence token. Not applicable for this tokenizer.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding, for example when batching sequences of different lengths.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
split_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the special tokens should be split during the tokenization process. The default behavior is
            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then
            `tokenizer.tokenize("<|endoftext|>")` returns `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`,
            then `tokenizer.tokenize("<|endoftext|>")` will give `['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This argument
            is only supported for `slow` tokenizers for the moment.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
clean_up_tokenization_spaces=False,
split_special_tokens=False,
**kwargs,
):
# Dream vocab does not contain control tokens; added tokens need to be special
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
if isinstance(pad_token, str)
else pad_token
)
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_merges = []
with open(merges_file, encoding="utf-8") as merges_handle:
for i, line in enumerate(merges_handle):
line = line.strip()
if (i == 0 and line.startswith("#version:")) or not line:
continue
bpe_merges.append(tuple(line.split()))
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
# NOTE: the cache can grow without bound and will get really large for long running processes
# (esp. for texts of language that do not use space between word, e.g. Chinese); technically
# not a memory leak but appears as one.
# GPT2Tokenizer has the same problem, so let's be consistent.
self.cache = {}
self.pat = re.compile(PRETOKENIZE_REGEX)
if kwargs.get("add_prefix_space", False):
logger.warning_once(
f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
)
super().__init__(
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
unk_token=unk_token,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
split_special_tokens=split_special_tokens,
**kwargs,
)
@property
def vocab_size(self) -> int:
return len(self.encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
def get_vocab(self):
return dict(self.encoder, **self.added_tokens_encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
# `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
# and cannot be configured elsewhere, but it should default to False for DreamTokenizer
return super().decode(
token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
text = unicodedata.normalize("NFC", text)
return (text, kwargs)
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
from .configuration_dream import DreamConfig
TOKENIZER_MAPPING.register(DreamConfig, (DreamTokenizer, None))

View File

@ -0,0 +1,84 @@
from typing import Any
import torch
from dllm.core.trainers import MDLMTrainer
def cart_weight(
masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3
) -> torch.Tensor:
"""
Optimized CART weight computation using matrix operations.
Args:
masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions.
t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART.
p (float): Parameter of geometric distribution (0 < p <= 1).
Returns:
torch.Tensor: (b, l) float tensor of weights.
"""
b, l = masked_indices.shape
device = masked_indices.device
idx = torch.arange(l, device=device)
dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1
dist_matrix = torch.clamp(dist_matrix, min=0) # (l, l)
geo_matrix = (
torch.log(torch.tensor(p, device=device))
+ (dist_matrix - 1).clamp(min=0) * torch.log(torch.tensor(1 - p, device=device))
    ).exp() * 0.5  # computed in log space, then exponentiated, for numerical stability
geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0
valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked
weights = valid_mask @ geo_matrix.T # (b, l)
weights = weights * masked_indices.float()
return weights
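# Minimal usage sketch (added for illustration; tiny shapes chosen for clarity):
def _cart_weight_example() -> torch.Tensor:
    masked = torch.tensor([[False, True, True, False]])  # positions 1 and 2 are masked
    t = torch.zeros(1)  # unused by CART weighting
    weights = cart_weight(masked, t, p=0.3)  # -> shape (1, 4); zero at unmasked positions
    return weights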
class DreamTrainer(MDLMTrainer):
"""
DreamTrainer: specialization of MDLMTrainer for Dream training.
"""
def __init__(
self,
*args,
loss_weight_type: str = "cart[geo_p:0.3]",
**kwargs,
):
super().__init__(*args, loss_weight_type=loss_weight_type, **kwargs)
def _preprocess_inputs(self, inputs):
labels = inputs["labels"]
assert (labels[:, 0] == -100).all()
def _postprocess_outputs(self, outputs):
logits = outputs.logits
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
def _compute_loss_weights(
self,
t: torch.Tensor,
inputs: dict[str, Any],
masked_indices: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
if self.loss_weight_type.startswith("cart"):
# parse geo_p
import re
match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type)
geo_p = float(match.group(1)) if match else 0.3
loss_weights = cart_weight(masked_indices, t, p=geo_p)
else:
loss_weights = super()._compute_loss_weights(
t=t,
inputs=inputs,
masked_indices=masked_indices,
*args,
**kwargs,
)
return loss_weights

View File

@ -0,0 +1,180 @@
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
import transformers
def top_p_logits(logits, top_p=None):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits, top_k=None):
top_k = min(top_k, logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
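# Minimal sampling sketch (added for illustration): filter a batch of logits with top-k,
# then nucleus (top-p) filtering, and sample from the resulting distribution.
def _filtered_sampling_example() -> torch.Tensor:
    logits = torch.randn(2, 32)               # [batch, vocab]
    logits = top_k_logits(logits, top_k=10)   # keep only the 10 largest logits per row
    logits = top_p_logits(logits, top_p=0.9)  # then keep the smallest nucleus covering 0.9
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [batch, 1] sampled token ids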
@dataclass
class DreamSFTCollator(transformers.DataCollatorForSeq2Seq):
"""
Randomly crop response length to reduce length bias during generation.
Reference: https://github.com/DreamLM/Dream/blob/main/src/trainer/fsdp_sft_trainer.py
"""
    perbatch_cutoff: bool = True  # Use per-batch (pre-collation) truncation if True
resp_cutoff_ratio: float = 0.0 # Prob. of post-collation truncation
# -------------------------------------------------------------------------
# 1) Pre-collation truncation (per-sample)
# -------------------------------------------------------------------------
def apply_perbatch_cutoff(self, features):
"""
Randomly pick a response length from batch (`kept_len`) and trim other responses.
Before:
[<--promptA----><------responseA------>]
[<--promptB-><---responseB--->]
[<---promptC----><--respC-->]
After:
[<--promptA----><---respA--->]
[<--promptB-><--respB-->]
[<---promptC----><--respC-->]
kept_len = 10 → trim each response to ≤10 tokens (before padding)
"""
resp_lens = torch.tensor(
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
)
kept_len = int(np.random.choice(resp_lens))
for f, r_len in zip(features, resp_lens):
remove_len = max(r_len - kept_len, 0)
if remove_len > 0:
# f["input_ids"] = f["input_ids"][:-remove_len]
# f["attention_mask"] = f["attention_mask"][:-remove_len]
# f["labels"] = f["labels"][:-remove_len]
for key in ["input_ids", "labels", "attention_mask"]:
if key in f:
f[key] = f[key][:-remove_len]
return features
# -------------------------------------------------------------------------
# 2) Post-collation truncation
# -------------------------------------------------------------------------
def apply_resp_cutoff(self, batch, features):
"""
Uniformly chop tail *after padding*. All sequences truncated to new_seq_len.
Before:
[<--promptA----><-----respA----->] 40
[<--promptB-><respB><----pad---->] 40
[<---promptC----><--respC--><pad>] 40
cutoff_len = 5
After:
[<--promptA----><--respA--->] 35
[<--promptB-><respB><--pad->] 35
[<---promptC----><--respC-->] 35
"""
orig_seq_lens = [len(f["input_ids"]) for f in features]
resp_lens = torch.tensor(
[len(f["input_ids"]) - f["prompt_len"] for f in features], dtype=torch.long
)
min_resp_len = resp_lens.min().item()
if min_resp_len <= 1:
return batch
cutoff_len = int(np.random.randint(1, min_resp_len))
new_seq_len = max(orig_seq_lens) - cutoff_len
for key in ["input_ids", "labels", "attention_mask"]:
if key in batch:
batch[key] = batch[key][:, :new_seq_len].contiguous()
return batch
# -------------------------------------------------------------------------
# 3) Main call: pick truncation mode
# -------------------------------------------------------------------------
def __call__(self, features, return_tensors=None):
# optional pre-collation truncation
if self.perbatch_cutoff:
features = self.apply_perbatch_cutoff(features)
# always collate only the needed fields
base = [
{k: f[k] for k in ("input_ids", "labels", "attention_mask") if k in f}
for f in features
]
batch = super().__call__(base, return_tensors=return_tensors)
# optional post-collation truncation
if (
not self.perbatch_cutoff
and self.resp_cutoff_ratio > 0
and np.random.rand() < self.resp_cutoff_ratio
):
batch = self.apply_resp_cutoff(batch, features)
batch.pop("prompt_len", None)
return batch
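# Usage sketch (illustrative feature dicts; the real pipeline builds these from tokenized
# SFT data and a HF tokenizer):
#   collator = DreamSFTCollator(tokenizer=tokenizer, perbatch_cutoff=True)
#   batch = collator([
#       {"input_ids": [1, 5, 6, 7, 8], "labels": [-100, -100, 6, 7, 8], "attention_mask": [1] * 5, "prompt_len": 2},
#       {"input_ids": [1, 5, 6], "labels": [-100, -100, 6], "attention_mask": [1] * 3, "prompt_len": 2},
#   ])
#   # every response is trimmed to a response length sampled from the batch before padding.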
@dataclass
class DreamPTCollator(transformers.DataCollatorForSeq2Seq):
random_length_ratio: float = 0.01
def __call__(self, features, return_tensors=None):
outputs = super().__call__(features, return_tensors=return_tensors)
input_ids, labels, attention_mask = (
outputs["input_ids"],
outputs["labels"],
outputs["attention_mask"],
)
bsz, seq_len = input_ids.shape
# --- Random truncation for robustness ---
if torch.rand(1).item() < self.random_length_ratio:
random_len = torch.randint(1, seq_len + 1, (1,)).item()
input_ids = input_ids[:, :random_len]
labels = labels[:, :random_len]
attention_mask = attention_mask[:, :random_len]
# --- Add BOS token to the beginning of input_ids ---
bos = torch.full(
(bsz, 1),
self.tokenizer.bos_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
)
input_ids = torch.cat([bos, input_ids], dim=1)
# --- Prepend zeros to labels instead of BOS ---
ignore_labels = self.label_pad_token_id * torch.ones(
(bsz, 1), dtype=labels.dtype, device=labels.device
)
labels = torch.cat([ignore_labels, labels], dim=1)
# --- Prepend ones to attention_mask ---
bos_attention = torch.ones(
(bsz, 1), dtype=attention_mask.dtype, device=attention_mask.device
)
attention_mask = torch.cat([bos_attention, attention_mask], dim=1)
# --- Update and return ---
outputs["input_ids"] = input_ids
outputs["labels"] = labels
outputs["attention_mask"] = attention_mask
# Check if attention_mask is all ones and set it to None
if torch.all(outputs["attention_mask"] == 1):
outputs.pop("attention_mask")
return outputs

View File

@ -0,0 +1,14 @@
from . import trainer, utils
from .models.dream.modelling_dream import (
EditFlowDreamConfig,
EditFlowDreamModel,
)
from .models.llada.modelling_llada import (
EditFlowLLaDAConfig,
EditFlowLLaDAModel,
)
from .models.bert.modelling_modernbert import (
EditFlowModernBertConfig,
EditFlowModernBertModel,
)
from dllm.pipelines.editflow.trainer import EditFlowTrainer

View File

@ -0,0 +1,89 @@
import torch
from torch import nn
import transformers
class EditFlowModernBertConfig(transformers.ModernBertConfig):
model_type = "editflow-modernbert" # <- NEW model_type
class EditFlowModernBertModel(transformers.ModernBertForMaskedLM):
config_class = EditFlowModernBertConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
    }  # fully finetuned even when using LoRA
def __init__(self, config):
        # flash_attention_2 has bugs when calling forward(output_hidden_states=True)
config._attn_implementation = "sdpa"
super().__init__(config)
in_lm, out_lm = self.decoder.in_features, self.decoder.out_features
use_bias = self.decoder.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(in_lm, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
h = self.head(h)
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-modernbert", EditFlowModernBertConfig)
AutoModel.register(EditFlowModernBertConfig, EditFlowModernBertModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"answerdotai/ModernBERT-base", "BASE_MODELS_DIR"
)
config = EditFlowModernBertConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowModernBertModel(config)
model.save_pretrained("models-tmp/editflow-modernbert")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-modernbert")

View File

@ -0,0 +1,97 @@
import copy
from typing import Optional
import torch
from torch import nn
from dllm.pipelines import dream
class EditFlowDreamConfig(dream.DreamConfig):
model_type = "editflow-dream" # <- NEW model_type
class EditFlowDreamModel(dream.DreamModel):
config_class = EditFlowDreamConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
    }  # fully finetuned even when using LoRA
def __init__(self, config):
super().__init__(config)
in_lm, out_lm = self.lm_head.in_features, self.lm_head.out_features
use_bias = self.lm_head.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.ins_logits = nn.Linear(in_lm, out_lm, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
sub_log = torch.concatenate(
[torch.zeros_like(sub_log)[:, :1], sub_log[:, :-1]], dim=1
) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
sub_rate_hat = torch.concatenate(
[torch.zeros_like(sub_rate_hat[:, :1]), sub_rate_hat[:, :-1]], dim=1
) # [B, L]
del_rate_hat = torch.concatenate(
[torch.zeros_like(del_rate_hat[:, :1]), del_rate_hat[:, :-1]], dim=1
) # [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-dream", EditFlowDreamConfig)
AutoModel.register(EditFlowDreamConfig, EditFlowDreamModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"Dream-org/Dream-v0-Base-7B", "BASE_MODELS_DIR"
)
config = EditFlowDreamConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowDreamModel(config)
model.save_pretrained("models-tmp/editflow-dream")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-dream")

View File

@ -0,0 +1,91 @@
import copy
from typing import Optional
import torch
from torch import nn
from dllm.pipelines import llada
class EditFlowLLaDAConfig(llada.LLaDAConfig):
model_type = "editflow-llada" # <- NEW model_type
class EditFlowLLaDAModel(llada.LLaDAModelLM):
config_class = EditFlowLLaDAConfig
modules_to_save = {
"rate_heads",
"sub_logits",
"ins_logits",
    }  # fully finetuned even when using LoRA
def __init__(self, config):
# TODO: time embedding
super().__init__(config)
ff = self.model.transformer.ff_out
in_f, out_f = ff.in_features, ff.out_features
use_bias = ff.bias is not None
# Create new, independent heads (no deepcopy)
self.sub_logits = nn.Linear(in_f, out_f, bias=use_bias)
self.ins_logits = nn.Linear(in_f, out_f, bias=use_bias)
self.rate_heads = nn.Sequential(nn.Linear(config.hidden_size, 3), nn.Softplus())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor | None = None,
t: torch.Tensor | None = None,
**kwargs,
):
# TODO: time embedding
output = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
h = output["hidden_states"][-1] # final hidden states
# Position heads
sub_log = self.sub_logits(h) # [B, L, V]
ins_log = self.ins_logits(h) # [B, L, V]
rates = self.rate_heads(h)
sub_rate_hat, del_rate_hat, ins_rate_hat = rates.unbind(
-1
) # [B, L], [B, L], [B, L]
return dict(
sub_rate_hat=sub_rate_hat, # [B,L]
del_rate_hat=del_rate_hat, # [B,L]
ins_rate_hat=ins_rate_hat, # [B,L]
ins_logits=ins_log, # [B,L,V]
sub_logits=sub_log, # [B,L,V]
)
from transformers.models.auto import AutoModel, AutoConfig
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoConfig.register("editflow-llada", EditFlowLLaDAConfig)
AutoModel.register(EditFlowLLaDAConfig, EditFlowLLaDAModel)
if __name__ == "__main__":
import dllm
import torch
from transformers import AutoConfig, AutoModel
# Load a config from a local path (either a directory containing config.json, or the file itself)
config_path = dllm.utils.resolve_with_base_env(
"GSAI-ML/LLaDA-8B-Base", "BASE_MODELS_DIR"
)
config = EditFlowLLaDAConfig.from_pretrained(config_path)
if hasattr(config, "auto_map"):
delattr(config, "auto_map")
if hasattr(config, "architectures"):
delattr(config, "architectures")
torch.set_default_device("cuda")
model = EditFlowLLaDAModel(config)
model.save_pretrained("models-tmp/editflow-llada")
auto_model = AutoModel.from_pretrained("models-tmp/editflow-llada")

View File

@ -0,0 +1,407 @@
from typing import Any, Dict, Union, List, Tuple, Optional
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
from dllm.pipelines.editflow.utils import pad_1d
BLANK = -1
def align_with_blanks(
x0: List[int], x1: List[int], sub_cost: int = 1, gap_cost: int = 1
) -> Dict:
"""
    Needleman-Wunsch global alignment of two integer sequences with:
match cost = 0, substitution cost = sub_cost, gap cost = gap_cost.
Returns aligned sequences (z0, z1) of equal length containing BLANK = ε where gaps occur.
"""
n, m = len(x0), len(x1)
# DP tables
dp = [[0] * (m + 1) for _ in range(n + 1)]
ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag', 'up', 'left'
for i in range(1, n + 1):
dp[i][0] = i * gap_cost
ptr[i][0] = "up"
for j in range(1, m + 1):
dp[0][j] = j * gap_cost
ptr[0][j] = "left"
for i in range(1, n + 1):
for j in range(1, m + 1):
cost_diag = dp[i - 1][j - 1] + (0 if x0[i - 1] == x1[j - 1] else sub_cost)
cost_up = dp[i - 1][j] + gap_cost
cost_left = dp[i][j - 1] + gap_cost
best = min(cost_diag, cost_up, cost_left)
dp[i][j] = best
if best == cost_diag:
ptr[i][j] = "diag"
elif best == cost_up:
ptr[i][j] = "up"
else:
ptr[i][j] = "left"
# traceback
z0, z1 = [], []
i, j = n, m
while i > 0 or j > 0:
p = ptr[i][j]
if p == "diag":
z0.append(x0[i - 1])
z1.append(x1[j - 1])
i -= 1
j -= 1
elif p == "up":
z0.append(x0[i - 1])
z1.append(BLANK)
i -= 1
else: # 'left'
z0.append(BLANK)
z1.append(x1[j - 1])
j -= 1
z0.reverse()
z1.reverse()
# return Alignment(z0=z0, z1=z1)
# return {"z0": z0, "z1": z1}
return dict(z0=z0, z1=z1)
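# Tiny worked example (added for illustration): aligning x0 = [7, 8] with x1 = [7, 9, 8]
# inserts one BLANK into z0 where x1 has the extra token.
def _align_with_blanks_example() -> Dict:
    out = align_with_blanks([7, 8], [7, 9, 8])
    # out["z0"] == [7, BLANK, 8]; out["z1"] == [7, 9, 8]
    return out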
# def align_with_blanks(
# x0: list[int], x1: list[int], sub_cost: int = 1, gap_cost: int = 1
# ) -> dict:
# """
# NeedlemanWunsch with a secondary objective that defers gaps to the end:
# - 'up' (gap in z1) is penalized if j < m
# - 'left' (gap in z0) is penalized if i < n
# This pushes blanks (-1) to the *right* whether x0 > x1 or x0 < x1.
# """
# n, m = len(x0), len(x1)
# dp_cost = [[0] * (m + 1) for _ in range(n + 1)]
# dp_pen = [[0] * (m + 1) for _ in range(n + 1)]
# ptr = [[None] * (m + 1) for _ in range(n + 1)] # 'diag' | 'up' | 'left'
# # Left edge: all 'up' moves with j=0 (< m) → penalize each step
# for i in range(1, n + 1):
# dp_cost[i][0] = i * gap_cost
# dp_pen[i][0] = i # i early 'up' moves
# ptr[i][0] = "up"
# # Top edge: all 'left' moves with i=0 (< n) → penalize each step
# for j in range(1, m + 1):
# dp_cost[0][j] = j * gap_cost
# dp_pen[0][j] = j # j early 'left' moves
# ptr[0][j] = "left"
# for i in range(1, n + 1):
# xi = x0[i - 1]
# for j in range(1, m + 1):
# yj = x1[j - 1]
# # diag
# cost_diag = dp_cost[i - 1][j - 1] + (0 if xi == yj else sub_cost)
# pen_diag = dp_pen[i - 1][j - 1]
# cand_diag = (cost_diag, pen_diag)
# # up: add blank to z1, penalize if j < m (early)
# cost_up = dp_cost[i - 1][j] + gap_cost
# pen_up = dp_pen[i - 1][j] + (1 if j < m else 0)
# cand_up = (cost_up, pen_up)
# # left: add blank to z0, penalize if i < n (early)
# cost_left = dp_cost[i][j - 1] + gap_cost
# pen_left = dp_pen[i][j - 1] + (1 if i < n else 0)
# cand_left = (cost_left, pen_left)
# # choose (cost,pen) min; deterministic tie-break: diag > left > up
# best = min(cand_diag, cand_left, cand_up)
# dp_cost[i][j], dp_pen[i][j] = best
# if best == cand_diag:
# ptr[i][j] = "diag"
# elif best == cand_left:
# ptr[i][j] = "left"
# else:
# ptr[i][j] = "up"
# # traceback
# z0, z1 = [], []
# i, j = n, m
# while i > 0 or j > 0:
# p = ptr[i][j]
# if p == "diag":
# z0.append(x0[i - 1])
# z1.append(x1[j - 1])
# i -= 1
# j -= 1
# elif p == "up":
# z0.append(x0[i - 1])
# z1.append(BLANK)
# i -= 1
# else: # 'left'
# z0.append(BLANK)
# z1.append(x1[j - 1])
# j -= 1
# z0.reverse()
# z1.reverse()
# return dict(z0=z0, z1=z1)
def strip_blanks(z: list[int]) -> list[int]:
# IMPORTANT: do NOT strip BOS; we only remove BLANKs
return [t for t in z if t != BLANK]
@dataclass
class Edit:
kind: str # "SUB" | "DEL" | "INS"
pos: int # position (for SUB/DEL) or token-row idx for INS (incl. BOS row 0)
token: int | None # token for SUB/INS, else None
def build_remaining_edits(zt: list[int], z1: list[int]) -> list[Edit]:
edits: list[Edit] = []
def count_nonblank_prefix(z: list[int], j: int) -> int:
c = 0
for k in range(j):
if z[k] != BLANK:
c += 1
return c
for j, (a, b) in enumerate(zip(zt, z1)):
if a == b:
continue
        nb = count_nonblank_prefix(
            zt, j
        )  # counts BOS, so nb == 1 at the column of the first content token
if a == BLANK and b != BLANK:
# INSERT after row (nb-1): BOS insert => nb=1 -> gap=0; general case works too
gap = max(nb - 1, 0)
edits.append(Edit("INS", gap, b))
elif a != BLANK and b == BLANK:
# DELETE token at row nb (first content token => nb=1, allowed; BOS is never BLANK so nb>=1)
pos = nb
# if pos > 0: # forbid BOS (row 0)
edits.append(Edit("DEL", pos, None))
else: # a != BLANK, b != BLANK, a != b
# SUB token at row nb
pos = nb
# if pos > 0: # forbid BOS (row 0)
edits.append(Edit("SUB", pos, b))
return edits
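# Worked sketch (added for illustration), continuing the alignment example above: with
# zt = [7, BLANK, 8] and z1 = [7, 9, 8], the only remaining edit is inserting token 9
# after non-blank row 0.
def _build_remaining_edits_example() -> list[Edit]:
    return build_remaining_edits([7, BLANK, 8], [7, 9, 8])  # -> [Edit("INS", 0, 9)]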
class EditFlowTrainer(transformers.Trainer):
"""
Trainer for Edit Flows where the model returns:
- sub_logits: [B,L,V] (token dist for SUB)
- ins_logits: [B,L,V] (token dist for INS)
- sub_rate_hat: [B,L] (normalized rates; NO kappa factor)
- del_rate_hat: [B,L]
- ins_rate_hat: [B,L]
True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)).
"""
def __init__(
self,
*args,
scheduler: BaseKappaScheduler | None = None,
normalize_per_position: bool = True,
time_epsilon: float = 1e-3,
max_w: float | None = None,
**kwargs,
):
self.scheduler = scheduler or CubicKappaScheduler()
self.normalize_per_position = normalize_per_position
self.time_epsilon = time_epsilon
self.max_w = max_w
super().__init__(*args, **kwargs)
def compute_loss(
self,
model: transformers.PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs: bool = False,
**kwargs,
):
device = self.model.device
B = len(inputs["x0_ids"])
# -------- 1) Align with blanks (z0,z1) and sample time t --------
aligns = [
align_with_blanks(x0, x1)
for x0, x1 in zip(inputs["x0_ids"], inputs["x1_ids"])
]
z0_list = [a["z0"] for a in aligns]
z1_list = [a["z1"] for a in aligns]
assert all(len(z0) == len(z1) for z0, z1 in zip(z0_list, z1_list))
assert all(z0[0] != BLANK for z0 in z0_list) # BOS must remain
t = (1 - self.time_epsilon) * torch.rand(B, 1, device=device) # [B,1]
k = self.scheduler.kappa(t).to(device) # [B,1]
w = self.scheduler.weight(t).squeeze(1).to(device) # [B]
if self.max_w:
w = w.clamp(max=self.max_w)
# -------- 2) Sample z_t by κ-mixing (vectorized per example) --------
# Keep python lists -> tensors per-example to reuse build_remaining_edits
zt_list: list[list[int]] = []
for z0, z1, kb in zip(z0_list, z1_list, k.squeeze(1).tolist()):
# per-column Bernoulli(κ) mix; BOS is equal in z0/z1 so it stays BOS
choose_target = torch.rand(len(z0)) < kb
zt = [b if choose_target[j] else a for j, (a, b) in enumerate(zip(z0, z1))]
zt_list.append(zt)
# -------- 3) Strip blanks to x_t and compute remaining edits --------
xt_list = [strip_blanks(zt) for zt in zt_list]
edits_list: list[list[Edit]] = [
build_remaining_edits(zt, z1) for zt, z1 in zip(zt_list, z1_list)
]
# -------- 4) Collate x_t for the model --------
x_tok, x_mask = pad_1d(
xt_list, pad_val=self.processing_class.pad_token_id
) # [B,Lmax], [B,Lmax]
x_tok = x_tok.to(device)
x_mask = x_mask.to(device)
# -------- 5) Forward pass --------
out = model(input_ids=x_tok, attention_mask=x_mask, t=t.to(device))
# Rename for clarity: model returns normalized rates (no kappa)
sub_rate_hat = out["sub_rate_hat"] # [B,L]
del_rate_hat = out["del_rate_hat"] # [B,L]
ins_rate_hat = out["ins_rate_hat"] # [B,L]
logQ_sub = F.log_softmax(out["sub_logits"], dim=-1) # [B,L,V]
logQ_ins = F.log_softmax(out["ins_logits"], dim=-1) # [B,L,V]
# *** NEW: zero-cost anchor to "touch" every head even if unused this step ***
# Using .sum() * 0.0 keeps a graph dependency without changing the loss value.
# Include both logits (for SUB/INS heads) and rates (for SUB/DEL/INS heads).
# This is important for Deepspeed ZeRO stage 2/3 to avoid skipping unused parameters.
anchor = (
sub_rate_hat.sum() * 0.0
+ del_rate_hat.sum() * 0.0
+ ins_rate_hat.sum() * 0.0
+ logQ_sub.sum() * 0.0
+ logQ_ins.sum() * 0.0
)
# Utility
def safe_log(x: torch.Tensor) -> torch.Tensor:
return torch.log(x.clamp_min(1e-12))
# -------- 6) Survival term --------
# Survival = E[sum of true intensities over valid rows]
# true intensity = w[b] * rate_hat[b, i]
mask_f = x_mask.float()
# L = mask_f.sum(dim=1).clamp_min(1.0) # [B] number of positions (incl. BOS)
L1 = torch.tensor(
[len(x1) for x1 in inputs["x1_ids"]], device=device, dtype=torch.float
).clamp_min(1.0)
denom = L1 if self.normalize_per_position else torch.ones_like(L1)
Lambda_hat = ((sub_rate_hat + del_rate_hat + ins_rate_hat) * mask_f).sum(
dim=1
) # [B]
loss_surv = ((w * Lambda_hat) / denom).mean()
# -------- 7) Positive edit terms --------
# For each remaining edit e: -log true rate(e) - log token prob(e) if tokenized
# loss_pos_per = sub_rate_hat.new_zeros(B) # [B]
# for b, edits in enumerate(edits_list):
# if not edits:
# continue
# cur_len = int(x_mask[b].sum().item())
# for e in edits:
# pos = e.pos
# assert 0 <= pos < cur_len, f"pos {pos} out of range {cur_len}"
# if e.kind == "SUB":
# loss_pos_per[b] -= logQ_sub[b, pos, e.token] + safe_log(
# sub_rate_hat[b, pos]
# )
# elif e.kind == "DEL":
# loss_pos_per[b] -= safe_log(del_rate_hat[b, pos])
# else: # "INS"
# loss_pos_per[b] -= logQ_ins[b, pos, e.token] + safe_log(
# ins_rate_hat[b, pos]
# )
# -------- 7) Positive edit terms (vectorized) --------
pos_sub, tok_sub, pos_ins, tok_ins, pos_del = [], [], [], [], []
for b, edits in enumerate(edits_list):
cur_len = int(x_mask[b].sum().item())
ps, ts, pi, ti, pd = [], [], [], [], []
for e in edits:
if not (0 <= e.pos < cur_len):
raise AssertionError(
f"pos {e.pos} out of range {cur_len} for b={b}"
)
if e.kind == "SUB":
ps.append(e.pos)
ts.append(e.token)
elif e.kind == "INS":
pi.append(e.pos)
ti.append(e.token)
else:
pd.append(e.pos)
pos_sub.append(
torch.tensor(ps, device=x_tok.device, dtype=torch.long) if ps else None
)
tok_sub.append(
torch.tensor(ts, device=x_tok.device, dtype=torch.long) if ts else None
)
pos_ins.append(
torch.tensor(pi, device=x_tok.device, dtype=torch.long) if pi else None
)
tok_ins.append(
torch.tensor(ti, device=x_tok.device, dtype=torch.long) if ti else None
)
pos_del.append(
torch.tensor(pd, device=x_tok.device, dtype=torch.long) if pd else None
)
loss_pos_terms = []
for b in range(B):
lp = x_tok.new_zeros(())
if pos_sub[b] is not None:
lp = (
lp
- (
logQ_sub[b, pos_sub[b], tok_sub[b]]
+ safe_log(sub_rate_hat[b, pos_sub[b]])
).sum()
)
if pos_ins[b] is not None:
lp = (
lp
- (
logQ_ins[b, pos_ins[b], tok_ins[b]]
+ safe_log(ins_rate_hat[b, pos_ins[b]])
).sum()
)
if pos_del[b] is not None:
lp = lp - safe_log(del_rate_hat[b, pos_del[b]]).sum()
loss_pos_terms.append(lp)
loss_pos_per = torch.stack(loss_pos_terms) # [B]
        # Average positive term per sequence (MC estimator across the batch)
loss_pos = ((w * loss_pos_per) / denom).mean()
# -------- 8) Total --------
loss = loss_surv + loss_pos + anchor
return (loss, out) if return_outputs else loss
if __name__ == "__main__":
pass

View File

@ -0,0 +1,218 @@
import math
import random
from dataclasses import dataclass
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Text
from collections.abc import Callable
import torch
import transformers
from dllm.utils.utils import parse_spec
# ------------------------------- Collator (x0 source) --------------------------------
@dataclass
class X0Sampler:
def __call__(self, *args, **kwargs) -> list[int]:
raise NotImplementedError("Subclasses must implement __call__.")
@dataclass
class SampleX0Empty(X0Sampler):
"""Return BOS-only (i.e., empty tail)."""
def __call__(self, *args, **kwargs) -> list[int]:
return []
@dataclass
class SampleX0Masks(X0Sampler):
"""Return a run of mask tokens of given length."""
length: int = 128
tokenizer: transformers.PreTrainedTokenizer = None
def __call__(self, *args, **kwargs) -> list[int]:
mask_id = getattr(self.tokenizer, "mask_token_id", None)
if mask_id is None:
raise ValueError("tokenizer needs mask_token_id for mask-based sampler")
return [int(mask_id)] * self.length
# ---------------- Factory ---------------- #
_X0_SAMPLER_CLASSES: dict[str, type[X0Sampler]] = {
"empty": SampleX0Empty,
"masks": SampleX0Masks,
}
def make_x0_sampler(name: str, tokenizer: Any, **kwargs) -> X0Sampler:
try:
name, kvs = parse_spec(name)
cls = _X0_SAMPLER_CLASSES[name.lower()]
except KeyError:
raise ValueError(
f"Unknown x0 sampler '{name}'. Available: {list(_X0_SAMPLER_CLASSES)}"
)
# merged_kwargs = {**kvs, **kwargs}
return cls(tokenizer=tokenizer, **kvs, **kwargs)
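# Factory sketch (illustrative; assumes `parse_spec` accepts a bare sampler name and returns
# it together with an empty kwargs dict):
#   sampler = make_x0_sampler("empty", tokenizer=tokenizer)
#   sampler()  # -> []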
@dataclass
class EditFlowCollator:
tokenizer: transformers.PreTrainedTokenizer = None
x0_sampler: Callable | str | None = X0Sampler # can be func OR name
def __post_init__(self):
if isinstance(self.x0_sampler, str):
self.x0_sampler = make_x0_sampler(self.x0_sampler, self.tokenizer)
def __call__(self, features: list[dict[str, Any]]) -> dict[str, list[Any]]:
if not features:
return {}
keys = features[0].keys()
batch = {k: [ex[k] for ex in features] for k in keys}
batch["x1_ids"] = batch["input_ids"]
if "prompt_len" not in batch:
assert self.tokenizer.bos_token_id is not None
bos = self.tokenizer.bos_token_id
batch["x1_ids"] = [
x if x and x[0] == bos else [bos] + x for x in batch["x1_ids"]
]
batch["x0_ids"] = [
x1_ids[:1] + self.x0_sampler(x1_ids=x1_ids[1:])
for x1_ids in batch["x1_ids"]
]
else:
batch["x0_ids"] = [
x1_ids[:prompt_len] + self.x0_sampler(x1_ids=x1_ids[prompt_len:])
for x1_ids, prompt_len in zip(batch["x1_ids"], batch["prompt_len"])
]
batch["return_loss"] = True
return batch
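# Usage sketch (illustrative; assumes the tokenizer defines a BOS token id):
#   collator = EditFlowCollator(tokenizer=tokenizer, x0_sampler=SampleX0Empty())
#   batch = collator([{"input_ids": [5, 6, 7]}])
#   # batch["x1_ids"][0] == [bos, 5, 6, 7]; batch["x0_ids"][0] == [bos]  (empty tail)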
# ------------------------------- Trainer utils --------------------------------
def pad_1d(
batch_lists: list[list[int]], pad_val: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pads a list of variable-length integer lists into:
- out: tensor of shape [B, Lmax] with padding value `pad_val`
- mask: tensor of shape [B, Lmax] with 1 for real tokens and 0 for padding (int mask)
"""
B = len(batch_lists)
Lmax = max((len(x) for x in batch_lists), default=0)
out = torch.full((B, Lmax), pad_val, dtype=torch.long)
mask = torch.zeros((B, Lmax), dtype=torch.long) # 0/1 mask (int)
for b, x in enumerate(batch_lists):
if not x:
continue
L = len(x)
out[b, :L] = torch.tensor(x, dtype=torch.long)
mask[b, :L] = 1 # mark valid positions with 1
return out, mask
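# A small, hedged example of pad_1d on toy data: two ragged rows are padded to
# the batch maximum and the returned 0/1 mask marks the real tokens.
def _demo_pad_1d():
    out, mask = pad_1d([[5, 6, 7], [8]], pad_val=0)
    # out  -> tensor([[5, 6, 7],
    #                 [8, 0, 0]])
    # mask -> tensor([[1, 1, 1],
    #                 [1, 0, 0]])
    return out, mask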
def init_editflow_from_src(
ef_model, src_model, lm_head_key: str = "lm_head", verbose: bool = True
):
"""
Initialize an EditFlowModel (ef_model) from a pretrained source model.
If DeepSpeed ZeRO-3 is enabled (detected via HF's `is_deepspeed_zero3_enabled()`),
this function temporarily gathers full parameters for both models on rank 0,
performs the copy there, and then returns to sharded mode automatically.
Otherwise it behaves like a normal CPU/GPU single-process copy.
Returns (missing_keys, unexpected_keys) from load_state_dict(strict=False).
"""
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled
dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
rank = torch.distributed.get_rank() if dist_ok else 0
def _copy_once():
src_sd = src_model.state_dict()
tgt_sd = ef_model.state_dict()
new_sd = OrderedDict()
# 1) copy matching backbone tensors
for k, v in src_sd.items():
if k in tgt_sd and tgt_sd[k].shape == v.shape:
new_sd[k] = v
# 2) duplicate lm_head -> sub_logits & ins_logits (weight + optional bias)
lm_w = f"{lm_head_key}.weight"
lm_b = f"{lm_head_key}.bias"
if lm_w in src_sd:
if "sub_logits.weight" in tgt_sd:
new_sd["sub_logits.weight"] = src_sd[lm_w]
if "ins_logits.weight" in tgt_sd:
new_sd["ins_logits.weight"] = src_sd[lm_w]
if lm_b in src_sd:
if "sub_logits.bias" in tgt_sd:
new_sd["sub_logits.bias"] = src_sd[lm_b]
if "ins_logits.bias" in tgt_sd:
new_sd["ins_logits.bias"] = src_sd[lm_b]
# 3) non-strict load so new rate heads remain randomly initialized
missing, unexpected = ef_model.load_state_dict(new_sd, strict=False)
return new_sd, missing, unexpected
if is_deepspeed_zero3_enabled():
# All ranks enter/exit together; only rank 0 materializes full tensors.
params = list(ef_model.parameters()) + list(src_model.parameters())
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
if rank == 0:
new_sd, missing, unexpected = _copy_once()
else:
new_sd, missing, unexpected = OrderedDict(), [], []
if dist_ok:
torch.distributed.barrier()
if verbose and rank == 0:
_p = getattr(globals().get("dllm", None), "utils", None)
printer = getattr(_p, "print_main", print) if _p else print
printer(
f"[EditFlow init][ZeRO-3] Copied {len(new_sd)} tensors from Src Model."
)
if missing:
printer(" Missing (expected for new rate heads, etc.):")
for k in missing:
printer(" -", k)
if unexpected:
printer(" Unexpected (check key names):")
for k in unexpected:
printer(" -", k)
return missing, unexpected
# --- Non-ZeRO (or DS not present) path ---
new_sd, missing, unexpected = _copy_once()
if verbose:
_p = getattr(globals().get("dllm", None), "utils", None)
printer = getattr(_p, "print_main", print) if _p else print
printer(f"[EditFlow init] Copied {len(new_sd)} tensors from Src Model.")
if missing:
printer(" Missing (expected for new rate heads, etc.):")
for k in missing:
printer(" -", k)
if unexpected:
printer(" Unexpected (check key names):")
for k in unexpected:
printer(" -", k)
return missing, unexpected
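# A hedged sketch of how the initializer above is typically called; `ef_model`
# and `src_model` are assumed to be an EditFlow-style model (with
# sub_logits/ins_logits heads) and a pretrained source transformer that share
# the same backbone layout, loaded elsewhere.
def _demo_init_editflow(ef_model, src_model):
    missing, unexpected = init_editflow_from_src(
        ef_model, src_model, lm_head_key="lm_head", verbose=True
    )
    # `missing` should list only the freshly initialized rate heads;
    # `unexpected` should normally be empty.
    return missing, unexpected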
if __name__ == "__main__":
pass

View File

@ -0,0 +1,7 @@
from . import generator, trainer
from .models.modeling_llada import LLaDAModelLM
from .models.configuration_llada import LLaDAConfig
from .models.modeling_lladamoe import LLaDAMoEModelLM
from .models.configuration_lladamoe import LLaDAMoEConfig
from .generator import LLaDAGeneratorConfig, LLaDAGenerator
from .trainer import LLaDATrainer

View File

@ -0,0 +1,357 @@
"""
accelerate launch \
--num_processes 2 \
dllm/pipelines/llada/eval.py \
--tasks gsm8k \
--model llada \
--num_fewshot 8 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Base,is_check_greedy=False,mc_num=1,max_new_tokens=1024,steps=1024,block_length=32,cfg=0.0"
"""
from types import SimpleNamespace
from dataclasses import dataclass
import accelerate
import torch
import torch.nn.functional as F
from datasets import Dataset
from tqdm import tqdm
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import get_dtype
import dllm
from dllm.pipelines.llada import LLaDAGenerator, LLaDAGeneratorConfig
@dataclass
class LLaDAEvalConfig(LLaDAGeneratorConfig):
max_new_tokens: int = 1024
max_length: int = 4096
steps: int = 1024
block_length: int = 1024
pretrained: str = ""
dtype: str | torch.dtype = "auto"
batch_size: int = 32
mc_num: int = 128
is_check_greedy: bool = True
device: str = "cuda"
@register_model("llada")
class LLaDAEvalHarness(LM):
def __init__(
self,
config: LLaDAEvalConfig | None = None,
**kwargs,
):
super().__init__()
if config is None:
config = LLaDAEvalConfig()
# Pull args from config, allow kwargs to override
pretrained = kwargs.get("pretrained", config.pretrained)
dtype = kwargs.get("dtype", config.dtype)
batch_size = kwargs.get("batch_size", config.batch_size)
mc_num = kwargs.get("mc_num", config.mc_num)
is_check_greedy = kwargs.get("is_check_greedy", config.is_check_greedy)
device = kwargs.get("device", config.device)
cfg = kwargs.get("cfg", config.cfg_scale)
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
block_length = kwargs.get("block_length", config.block_length)
max_length = kwargs.get("max_length", config.max_length)
remasking = kwargs.get("remasking", config.remasking)
accelerator = accelerate.Accelerator()
# Get GLOBAL rank from torch.distributed (not accelerator)
if torch.distributed.is_initialized():
self._rank = torch.distributed.get_rank() # ← GLOBAL rank (0-15)
self._world_size = (
torch.distributed.get_world_size()
) # ← GLOBAL world size (16)
else:
self._rank = 0
self._world_size = 1
# Use accelerator for device placement
self.model = dllm.utils.get_model(
SimpleNamespace(model_name_or_path=pretrained, dtype=get_dtype(dtype))
)
self.model.eval()
if accelerator.num_processes > 1:
# Let accelerator handle device placement
self.model = accelerator.prepare(self.model)
self.device = (
accelerator.device
) # ← Accelerator figures out local device correctly
self.accelerator = accelerator
else:
# Single GPU
self.model = self.model.to(device)
self.device = torch.device(device)
self.accelerator = None
self.tokenizer = dllm.utils.get_tokenizer(
SimpleNamespace(model_name_or_path=pretrained, model=self.model)
)
# generation params
self.mask_id = self.tokenizer.mask_token_id
self.batch_size = int(batch_size)
self.max_length = max_length
self.max_new_tokens = int(max_new_tokens)
self.block_length = int(block_length)
self.steps = int(steps)
self.cfg = float(cfg)
self.remasking = remasking
self.is_check_greedy = is_check_greedy
# loglikelihood params
self.mc_num = int(mc_num)
assert mc_num % self.batch_size == 0
self.sampling_eps = 0.0
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Apply a chat template to a chat history (a list of user/model messages).
"""
chat_templated = self.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
return chat_templated
@property
def tokenizer_name(self) -> str:
return self.tokenizer.name_or_path.replace("/", "__")
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def _forward_process(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
b, l = batch.shape
target_len = (l - prompt_index.sum()).item()
k = torch.randint(1, target_len + 1, (), device=batch.device)
x = torch.round(
torch.linspace(
float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device
)
).long()
x = ((x - 1) % target_len) + 1
assert x.min() >= 1 and x.max() <= target_len
indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
is_mask = indices < x.unsqueeze(1)
for i in range(b):
is_mask[i] = is_mask[i][torch.randperm(target_len)]
is_mask = torch.cat(
(
torch.zeros(
b, prompt_index.sum(), dtype=torch.bool, device=batch.device
),
is_mask,
),
dim=1,
)
noisy_batch = torch.where(is_mask, self.mask_id, batch)
return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)
@torch.no_grad()
def get_logits(
self, batch: torch.Tensor, prompt_index: torch.Tensor
) -> torch.Tensor:
if self.cfg > 0.0:
assert len(prompt_index) == batch.shape[1]
prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
un_batch = batch.clone()
un_batch[prompt_index] = self.mask_id
batch = torch.cat([batch, un_batch])
logits = self.model(batch).logits
if self.cfg > 0.0:
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (self.cfg + 1) * (logits - un_logits)
return logits[:, : batch.shape[1]]
@torch.no_grad()
def get_loglikelihood(self, prefix: torch.Tensor, target: torch.Tensor) -> float:
seq = torch.concatenate([prefix, target])[None, :]
seq = seq.repeat((self.batch_size, 1)).to(self.device)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
loss_acc = []
for _ in range(self.mc_num // self.batch_size):
perturbed_seq, p_mask = self._forward_process(seq, prompt_index)
mask_indices = perturbed_seq == self.mask_id
logits = self.get_logits(perturbed_seq, prompt_index)
loss = (
F.cross_entropy(
logits[mask_indices], seq[mask_indices], reduction="none"
)
/ p_mask[mask_indices]
)
loss = loss.sum() / self.batch_size
loss_acc.append(loss.item())
return -sum(loss_acc) / len(loss_acc)
@torch.no_grad()
def suffix_greedy_prediction(
self, prefix: torch.Tensor, target: torch.Tensor
) -> bool:
if not self.is_check_greedy:
return False
seq = torch.full(
(1, len(prefix) + len(target)), self.mask_id, device=self.device
)
prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
prefix, target = prefix.to(self.device), target.to(self.device)
seq[0, : len(prefix)] = prefix
for i in range(len(target)):
mask_index = seq == self.mask_id
logits = self.get_logits(seq, prompt_index)[mask_index]
x0 = torch.argmax(logits, dim=-1)
p = torch.softmax(logits.to(torch.float32), dim=-1)
confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(
dim=-1
)
_, index = torch.sort(confidence, descending=True)
x0[index[1:]] = self.mask_id
seq[mask_index] = x0.clone()
correct = target == seq[0, len(prefix) :]
correct = torch.all(correct)
return correct
def _encode_pair(
self, context: str, continuation: str
) -> tuple[torch.Tensor, torch.Tensor]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tokenizer(context + continuation)["input_ids"]
context_enc = self.tokenizer(context)["input_ids"]
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
def _tokenize(e):
prefix, target = self._encode_pair(e["prefix"], e["target"])
return {
"prefix_text": e["prefix"],
"target_text": e["target"],
"prefix": prefix,
"target": target,
}
ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]
assert max(prompt_len) <= 4096
out = []
with torch.no_grad():
for elem in tqdm(ds, desc="Computing likelihood..."):
prefix = elem["prefix"]
target = elem["target"]
ll = self.get_loglikelihood(prefix, target)
is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)
out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
torch.cuda.empty_cache()
return out
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
raise NotImplementedError
def generate_until(self, requests: list[Instance]) -> list[str]:
def _tokenize(e):
return {
"question": self.tokenizer(e["question"])["input_ids"],
"question_text": e["question"],
"until": e["until"],
}
ds = [
{"question": req.args[0], "until": req.args[1]["until"]} for req in requests
]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize)
ds = ds.with_format("torch")
out = []
generator = LLaDAGenerator(model=self.model, tokenizer=self.tokenizer)
for elem in tqdm(ds, desc="Generating..."):
prompt = [elem["question"].to(self.device)]
stop_tokens = elem["until"]
generated_ids = generator.generate(
inputs=prompt,
steps=self.steps,
max_new_tokens=self.max_new_tokens,
block_length=self.block_length,
temperature=0.0,
cfg_scale=self.cfg,
remasking=self.remasking,
)
generated_answer = self.tokenizer.decode(
generated_ids[0][prompt[0].shape[0] :], skip_special_tokens=False
)
for stop_seq in stop_tokens:
if stop_seq in generated_answer:
generated_answer = generated_answer.split(stop_seq)[0]
# remove special tokens
generated_answer_ids = self.tokenizer(generated_answer)["input_ids"]
generated_answer = self.tokenizer.decode(
generated_answer_ids, skip_special_tokens=True
)
out.append(generated_answer)
if self.accelerator is not None:
self.accelerator.wait_for_everyone()
return out
if __name__ == "__main__":
cli_evaluate()

View File

@ -0,0 +1,379 @@
"""
reference: https://github.com/ML-GSAI/LLaDA/blob/main/generate.py
"""
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn.functional as F
from dllm.utils.generation_utils import get_num_transfer_tokens
from dllm.core.generation.generator import (
GeneratorOutput,
GeneratorConfig,
BaseGenerator,
)
def add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor:
"""
The Gumbel max is a method for sampling categorical distributions.
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
Thus, we use float64.
"""
if temperature == 0:
return logits
logits = logits.to(torch.float64)
noise = torch.rand_like(logits, dtype=torch.float64)
gumbel_noise = (-torch.log(noise)) ** temperature
return logits.exp() / gumbel_noise
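# Hedged note on the trick above: for u ~ Uniform(0, 1), -log(-log(u)) is a
# standard Gumbel sample, so log(logits.exp() / (-log u) ** T) equals
# logits + T * Gumbel noise. Its argmax is therefore a draw from
# softmax(logits / T); at temperature == 0 the raw logits are returned and the
# argmax below reduces to greedy decoding.
def _demo_gumbel_max(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    noisy = add_gumbel_noise(logits, temperature)
    return torch.argmax(noisy, dim=-1)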
@dataclass
class LLaDAGeneratorConfig(GeneratorConfig):
max_new_tokens: int = 128
max_length: int | None = None  # no explicit length limit beyond the tokenizer/model context
block_length: int = 128
steps: int = 128
temperature: float = 0.0
remasking: str = "low_confidence"
stochastic_transfer: bool = False
cfg_scale: float = 0.0
cfg_keep_tokens: list[int] | None = None
@dataclass
class LLaDAGenerator(BaseGenerator):
@torch.no_grad()
def generate(
self,
inputs: list[torch.Tensor | list],
config: LLaDAGeneratorConfig | None = None,
**kwargs,
) -> GeneratorOutput | torch.Tensor:
if config is None:
config = LLaDAGeneratorConfig()
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
max_new_tokens = kwargs.get("max_new_tokens", config.max_new_tokens)
max_length = kwargs.get("max_length", config.max_length)
block_length = kwargs.get("block_length", config.block_length)
temperature = kwargs.get("temperature", config.temperature)
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
remasking = kwargs.get("remasking", config.remasking)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
assert 1 <= block_length
assert 1 <= steps
mask_id = self.tokenizer.mask_token_id
eos_id = self.tokenizer.eos_token_id
# ----- Shape bookkeeping: per-sample prompt lengths and final canvas width -----
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
prompt_lens = [p.shape[0] for p in inputs]
if max_new_tokens:
max_length = max_new_tokens + max(prompt_lens)
else:
max_new_tokens = max_length - max(prompt_lens)
B = len(inputs)
T = max_length
# ----- Initialize canvas with EOS, copy inputs, and append mask tail -----
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
for i, p in enumerate(inputs):
x[i, : prompt_lens[i]] = p # keep original prompt tokens
x[i, prompt_lens[i] : prompt_lens[i] + max_new_tokens] = (
mask_id # append `max_new_tokens` masks to be generated
)
attention_mask = (x != eos_id).long() if B > 1 else None
# Tokens that were *given* at the start (non-mask, non-EOS).
# These will be masked in the unconditional forward pass for CFG.
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
unmasked_index = (x != mask_id) & (x != eos_id)
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
keep_mask = torch.isin(
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
)
unmasked_index = unmasked_index & ~keep_mask
# ----- Block scheduling over the appended mask tail -----
num_blocks = math.ceil(max_new_tokens / block_length)
steps = math.ceil(steps / num_blocks) # per-block step budget
histories = [x.clone()] if return_dict_in_generate else None
for b in range(num_blocks):
# Build a per-sample mask *within this block* (aligned to each prompt's tail)
block_mask_index = torch.zeros(
(B, block_length), dtype=torch.bool, device=x.device
)
for j in range(B):
start = prompt_lens[j] + b * block_length
end = min(start + block_length, prompt_lens[j] + max_new_tokens, T)
if start < end:
width = end - start
block_mask_index[j, :width] = (
x[j, start:end] == mask_id
) # which positions in this block are still masked
# Decide how many tokens to reveal per step in this block
num_transfer_tokens = get_num_transfer_tokens(
mask_index=block_mask_index,
steps=steps,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
# Some steps may be skipped if there are no transfers
effective_steps = num_transfer_tokens.size(1)
# ----- Iterative reveal inside the current block -----
for i in range(effective_steps):
mask_index = x == mask_id # current global mask map
# Optional CFG: second forward where original prompt tokens are masked out
if cfg_scale > 0.0:
un_x = x.clone()
un_x[unmasked_index] = mask_id
x_ = torch.cat([x, un_x], dim=0)
logits = self.model(
x_, attention_mask=attention_mask
).logits # Use attention mask here
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
else:
logits = self.model(
x, attention_mask=attention_mask
).logits # Use attention mask here
# Argmax decoding with optional Gumbel-Max noise for exploration
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(
logits_with_noise, dim=-1
) # [B, T] predicted token ids
# Per-position confidence used to pick which masks to commit this step
if remasking == "low_confidence":
p = F.softmax(logits, dim=-1)
x0_p = torch.squeeze(
torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
) # [B, T] confidence of predicted token
elif remasking == "random":
x0_p = torch.rand(
(x0.shape[0], x0.shape[1]), device=x0.device
) # random scores
else:
raise NotImplementedError(remasking)
# Restrict selection window to the *current block's* tail region
for j in range(B):
x0_p[j, prompt_lens[j] + (b + 1) * block_length :] = -np.inf
# Only allow updates at currently masked positions; keep others fixed
x0 = torch.where(mask_index, x0, x)
confidence = torch.where(
mask_index, x0_p, -np.inf
) # consider masked positions only
# Pick exactly `num_transfer_tokens[j, i]` highest-confidence positions per sample
transfer_index = torch.zeros_like(
x0, dtype=torch.bool, device=x0.device
)
for j in range(confidence.shape[0]):
_, select_index = torch.topk(
confidence[j], k=num_transfer_tokens[j, i]
)
transfer_index[j, select_index] = True
# Commit chosen predictions into the canvas
x[transfer_index] = x0[transfer_index]
if histories is not None:
histories.append(x.clone())
# ----- Output format -----
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
@torch.no_grad()
def infill(
self, inputs: list[torch.Tensor | list], config, **kwargs
) -> GeneratorOutput | torch.Tensor:
"""
Fill in-place the <|mdm_mask|> tokens contained in `inputs`.
The whole (padded) sequence is split into block windows of length
`block_length`; within each window we progressively "unmask" positions
according to the scheduler and chosen remasking strategy.
Notes:
- Right padding uses EOS.
- CFG masks out *originally known* (non-mask, non-EOS) tokens in the
unconditional branch, identical to `generate`.
- Only masked positions are ever updated; non-mask tokens are left intact.
"""
# ----- pull args from config, allow kwargs to override -----
steps = kwargs.get("steps", config.steps)
block_length = kwargs.get("block_length", config.block_length)
temperature = kwargs.get("temperature", config.temperature)
cfg_scale = kwargs.get("cfg_scale", config.cfg_scale)
cfg_keep_tokens = kwargs.get("cfg_keep_tokens", config.cfg_keep_tokens)
remasking = kwargs.get("remasking", config.remasking)
stochastic_transfer = kwargs.get(
"stochastic_transfer", config.stochastic_transfer
)
return_dict_in_generate = kwargs.get(
"return_dict_in_generate", config.return_dict_in_generate
)
mask_id = self.tokenizer.mask_token_id
eos_id = self.tokenizer.eos_token_id
# ----- Build canvas: right-pad with EOS to the max length in the batch -----
if isinstance(inputs[0], list):
inputs = [
torch.as_tensor(p, dtype=torch.long, device=self.model.device)
for p in inputs
]
B = len(inputs)
seq_lens = [t.shape[0] for t in inputs]
T = max(seq_lens)
# Default to a single block spanning the whole sequence
if block_length is None:
block_length = T
assert 1 <= block_length
assert 1 <= steps
x = torch.full((B, T), eos_id, dtype=torch.long, device=self.model.device)
for i, t in enumerate(inputs):
x[i, : seq_lens[i]] = t
attention_mask = (x != eos_id).long() if B > 1 else None
# Tokens that were *given* at the start (non-mask, non-EOS).
# These will be masked in the unconditional forward pass for CFG.
# Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
unmasked_index = (x != mask_id) & (x != eos_id)
if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
keep_mask = torch.isin(
x, torch.as_tensor(cfg_keep_tokens, device=self.model.device)
)
unmasked_index = unmasked_index & ~keep_mask
# ----- Blockwise schedule over the *entire* (padded) sequence -----
num_blocks = math.ceil(T / block_length)
steps_per_block = math.ceil(steps / num_blocks)
histories = [x.clone()] if return_dict_in_generate else None
# Create attention mask where eos_token_id is masked (set to 0)
attention_mask = (x != eos_id).long()
for b in range(num_blocks):
start = b * block_length
stop = min(start + block_length, T)
# Per-sample view of which positions in this block are masks
block_mask_index = torch.zeros(
(B, block_length), dtype=torch.bool, device=self.model.device
)
widths = []
for j in range(B):
# Width limited by sample's true length and sequence end
width = max(0, min(seq_lens[j], stop) - start)
widths.append(width)
if width > 0:
block_mask_index[j, :width] = x[j, start : start + width] == mask_id
# Decide how many tokens to reveal at each step in this block
num_transfer_tokens = get_num_transfer_tokens(
mask_index=block_mask_index,
steps=steps_per_block,
scheduler=self.scheduler,
stochastic=stochastic_transfer,
)
# Some blocks may have no masks => effective_steps == 0
effective_steps = num_transfer_tokens.size(1)
for s in range(effective_steps):
mask_index_full = x == mask_id
# ----- Forward pass (+ optional CFG) -----
if cfg_scale > 0.0:
un_x = x.clone()
un_x[unmasked_index] = mask_id
x_ = torch.cat([x, un_x], dim=0)
logits = self.model(
x_, attention_mask=attention_mask
).logits # Use attention mask here
logits, un_logits = torch.chunk(logits, 2, dim=0)
logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
else:
logits = self.model(
x, attention_mask=attention_mask
).logits # Use attention mask here
# Greedy with optional Gumbel-Max noise
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0 = torch.argmax(logits_with_noise, dim=-1) # [B, T]
# Confidence used for choosing which masks to commit this step
if remasking == "low_confidence":
p = F.softmax(logits, dim=-1)
x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(
-1
) # [B, T]
elif remasking == "random":
x0_p = torch.rand((B, T), device=self.model.device)
else:
raise NotImplementedError(remasking)
# Restrict selection to the *current* block only
for j in range(B):
end_j = start + widths[j]
# Outside current block => impossible to select
x0_p[j, :start] = -np.inf
x0_p[j, end_j:] = -np.inf
# Only consider currently-masked positions as candidates
x0 = torch.where(mask_index_full, x0, x)
confidence = torch.where(mask_index_full, x0_p, -np.inf)
# Pick exactly num_transfer_tokens[j, s] positions per sample
transfer_index = torch.zeros_like(x, dtype=torch.bool)
for j in range(B):
k = int(num_transfer_tokens[j, s].item())
if k > 0:
_, select_idx = torch.topk(confidence[j], k=k)
transfer_index[j, select_idx] = True
# Commit selected predictions into the canvas
x[transfer_index] = x0[transfer_index]
if histories is not None:
histories.append(x.clone())
# ----- Output format -----
if not return_dict_in_generate:
return x
else:
return GeneratorOutput(sequences=x, histories=histories)
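# A hedged usage sketch of the generator above, mirroring how the eval harness
# drives it; `model`, `tokenizer`, and `prompt_ids` are assumed to be a LLaDA
# checkpoint, its tokenizer, and a 1-D tensor of prompt token ids loaded elsewhere.
def _demo_llada_generate(model, tokenizer, prompt_ids: torch.Tensor) -> str:
    generator = LLaDAGenerator(model=model, tokenizer=tokenizer)
    sequences = generator.generate(
        inputs=[prompt_ids],
        steps=128,
        max_new_tokens=128,
        block_length=32,
        temperature=0.0,
        cfg_scale=0.0,
        remasking="low_confidence",
    )
    # Strip the prompt and decode only the generated tail.
    return tokenizer.decode(sequences[0][prompt_ids.shape[0]:], skip_special_tokens=True)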

View File

@ -0,0 +1,19 @@
from .configuration_llada import LLaDAConfig
from .modeling_llada import LLaDAModelLM
from .configuration_lladamoe import LLaDAMoEConfig
from .modeling_lladamoe import LLaDAMoEModelLM
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("llada", LLaDAConfig)
AutoModel.register(LLaDAConfig, LLaDAModelLM)
AutoModelForMaskedLM.register(LLaDAConfig, LLaDAModelLM)
AutoConfig.register("lladamoe", LLaDAMoEConfig)
AutoModel.register(LLaDAMoEConfig, LLaDAMoEModelLM)
AutoModelForMaskedLM.register(LLaDAMoEConfig, LLaDAMoEModelLM)
except ImportError:
# transformers not available or Auto classes not imported
pass
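# A hedged example of what the registrations above enable: once this module has
# been imported, the HF Auto classes can resolve the "llada" model type locally
# (no Hub auto_map needed). Building from the default config is illustrative only.
def _demo_auto_resolution():
    from transformers import AutoConfig, AutoModel

    cfg = AutoConfig.for_model("llada")  # resolves to LLaDAConfig
    model = AutoModel.from_config(cfg)   # builds an LLaDAModelLM from scratch
    return model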

View File

@ -0,0 +1,459 @@
"""
LLaDA configuration
"""
from transformers import PretrainedConfig
from enum import Enum
from os import PathLike
from typing import Union
from dataclasses import asdict, dataclass, field
from glob import glob
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
__all__ = [
"ActivationType",
"ActivationCheckpointingStrategy",
"BlockType",
"LayerNormType",
"InitFnType",
"ModelConfig",
]
PathOrStr = Union[str, PathLike]
class StrEnum(str, Enum):
"""
This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
We include this here for compatibility with older versions of Python.
"""
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return f"'{str(self)}'"
class LayerNormType(StrEnum):
default = "default"
"""
The default LayerNorm implementation, equivalent to PyTorch's built-in version.
"""
low_precision = "low_precision"
"""
A low-precision version of the default LayerNorm.
"""
rms = "rms"
"""
An RMSNorm implementation. When using ``torch.compile`` this is
probably the fastest implementation.
"""
gemma_rms = "gemma_rms"
"""
An RMSNorm implementation by Gemma. When using ``torch.compile`` this is
probably the fastest implementation.
"""
amd_compatible = "amd_compatible"
"""
LayerNorm implemented manually to work around an issue with ROCm.
"""
class ActivationType(StrEnum):
gelu = "gelu"
relu = "relu"
silu = "silu"
swiglu = "swiglu"
class BlockType(StrEnum):
sequential = "sequential"
parallel = "parallel"
llama = "llama"
"""
A block similar to the sequential block with slightly different
implementations of operations like attention to imitate the behavior of Llama.
"""
class InitFnType(StrEnum):
mitchell = "mitchell"
"""
The strategy suggested to us by Mitchell Wortsman from UW.
This uses a truncated normal distribution with an adaptive standard deviation that depends
on the size of the weights as well as the depth of the layer.
"""
normal = "normal"
"""
All weights are initialized from the same normal distribution.
"""
kaiming_normal = "kaiming_normal"
"""
All weights are initialized with the Kaiming method from a normal distribution.
Note this currently won't work with FSDP.
"""
fan_in = "fan_in"
"""
"Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
is the input dimensionality of the kernel.
"""
full_megatron = "full_megatron"
"""
This is what metaseq calls "full megatron init". It is the init used for Llama 2.
"""
@dataclass
class ModelConfig():
"""
LLaDA (model) configuration.
"""
# Note that the defaults for these attributes are equivalent to the base GPT2 model.
d_model: int = 768
"""
The hidden size of the model.
"""
n_heads: int = 12
"""
The number of self-attention heads.
"""
n_kv_heads: Optional[int] = None
"""
The number of heads to use for keys and values. Defaults to `n_heads`.
Set this to ``None`` or ``n_heads`` for normal multi-head attention.
Set this to 1 for multi-query attention.
Set it to some in-between value for Llama2-style grouped query attention.
"""
n_layers: int = 12
"""
The number of layers/blocks.
"""
mlp_ratio: int = 4
"""
The ratio of the inner MLP dimensionality to ``d_model``.
This is only used when ``mlp_hidden_size`` is not set.
"""
mlp_hidden_size: Optional[int] = None
"""
Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
"""
activation_type: ActivationType = ActivationType.swiglu
"""
The activation function to use within the MLP layers.
"""
block_type: BlockType = BlockType.sequential
"""
The transformer block implementation.
"""
block_group_size: int = 1
"""
The number of blocks to group together into a single parent block.
This has no effect on the number of parameters in the model and is only used to wrap groups
of blocks together with a single FSDP wrapper during training.
"""
alibi: bool = False
"""
If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
"""
alibi_bias_max: float = 8.0
"""
Maximum absolute value of ALiBi bias.
"""
rope: bool = False
"""
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
"""
rope_full_precision: bool = True
"""
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
apply RoPE at the precision of the input.
"""
flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
"""
attention_dropout: float = 0.1
"""
The dropout probability within the attention modules.
"""
multi_query_attention: Optional[bool] = None
"""
Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
and is more efficient during inference.
"""
attention_layer_norm: bool = False
"""
Apply layer norm to the keys and queries within the attention mechanism.
This can help stabilize training.
"""
residual_dropout: float = 0.1
"""
The dropout probability for the MLP and attention output within each block.
"""
embedding_dropout: float = 0.1
"""
The dropout probability for embeddings.
"""
input_emb_norm: bool = False
"""
An input hidden_states norm implementation by Gemma.
"""
layer_norm_type: LayerNormType = LayerNormType.default
"""
The layernorm implementation to use.
"""
layer_norm_with_affine: bool = True
"""
Whether to include bias and weight parameters for the layer norms.
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
to ``False``.
"""
rms_norm_eps: float = 1e-05
"""
The rms layernorm eps param.
"""
attention_layer_norm_with_affine: bool = True
"""
Toggle affine transform for the QK norms.
"""
max_sequence_length: int = 1024
"""
The maximum input sequence length supported by the model.
"""
rope_theta: float = 10000.0
"""
The rope base param.
"""
include_qkv_bias: Optional[bool] = False
"""
Whether or not to include bias parameters in qkv linear layers.
"""
include_bias: bool = False
"""
Whether or not to include bias parameters in linear layers.
In PaLM, they got rid of all bias terms because they found that large
models tend to have near 0 bias terms anyway.
"""
bias_for_layer_norm: Optional[bool] = None
"""
Whether or not to include bias parameters in layer norm.
This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
layer norm.
When this is None (the default), it inherits the setting from include_bias.
"""
scale_logits: bool = False
"""
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
"""
vocab_size: int = 50257
"""
Vocabulary size of the model.
"""
embedding_size: Optional[int] = 50304
"""
The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
next multiple of 128 that's greater than ``vocab_size`` can improve throughput
substantially.
"""
weight_tying: bool = True
"""
Whether to tie output linear weights to the input embedding.
"""
eos_token_id: int = 50256
"""
The ID of the end-of-sentence special token.
"""
pad_token_id: int = 50256
"""
The ID of the token to use for padding. Defaults to the ID of the EOS token.
"""
mask_token_id: Optional[int] = 50256
"""
The ID of the token to use as the mask token. Defaults to the ID of the EOS token.
"""
init_device: Optional[str] = None
"""
The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
"""
init_fn: InitFnType = InitFnType.normal
"""
The weight initialization strategy.
"""
init_std: float = 0.02
"""
The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
as "normal".
"""
init_cutoff_factor: Optional[float] = None
"""
A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
as "normal". Setting this to None means values are not cutoff.
"""
precision: Optional[str] = None
"""
Precision used to train/evaluate with. You shouldn't set this directly.
See :data:`TrainConfig.precision` instead.
"""
@property
def effective_n_kv_heads(self) -> int:
if self.n_kv_heads is None:
if self.multi_query_attention is True:
return 1
else:
return self.n_heads
else:
if self.multi_query_attention is None:
return self.n_kv_heads
if self.multi_query_attention:
n_kv_heads_should_be = 1
else:
n_kv_heads_should_be = self.n_heads
if self.n_kv_heads == n_kv_heads_should_be:
return n_kv_heads_should_be
else:
raise Exception(
"You can't set `multi_query_attention` and `n_kv_heads` at the same time."
)
class ActivationCheckpointingStrategy(StrEnum):
whole_layer = "whole_layer"
"""
Checkpoint every transformer layer.
"""
one_in_two = "one_in_two"
"""
Checkpoint one in two transformer layers.
"""
one_in_three = "one_in_three"
"""
Checkpoint one in three transformer layers.
"""
one_in_four = "one_in_four"
"""
Checkpoint one in four transformer layers.
"""
two_in_three = "two_in_three"
"""
Checkpoint two out of every three transformer layers.
"""
three_in_four = "three_in_four"
"""
Checkpoint three out of every four transformer layers.
"""
four_in_five = "four_in_five"
"""
Checkpoint four out of every five transformer layers.
"""
nine_in_ten = "nine_in_ten"
"""
Checkpoint nine out of every ten transformer layers.
"""
fine_grained = "fine_grained"
"""
Focus checkpointing on the operations that are cheap to recompute and save the most memory.
"""
class LLaDAConfig(PretrainedConfig):
model_type = "llada"
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
def __init__(self, use_cache: bool = False, **kwargs):
model_config = ModelConfig()
all_kwargs = model_config.__dict__
all_kwargs.update(kwargs)
all_kwargs.update({"use_cache": use_cache})
all_kwargs.update(
{
"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
}
)
super().__init__(**all_kwargs)
@property
def num_attention_heads(self):
return self.n_heads
@property
def num_hidden_layers(self):
return self.n_layers
@property
def hidden_size(self):
return self.d_model
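# A hedged construction example: LLaDAConfig starts from the ModelConfig
# defaults above and lets keyword overrides replace individual fields, while
# exposing the HF-style aliases (hidden_size, num_attention_heads, ...).
def _demo_llada_config():
    cfg = LLaDAConfig(d_model=1024, n_heads=16, n_layers=24)
    assert cfg.hidden_size == 1024
    assert cfg.num_attention_heads == 16
    assert cfg.num_hidden_layers == 24
    return cfg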

View File

@ -0,0 +1,96 @@
"""
LLaDA MoE configuration
"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
class LLaDAMoEConfig(PretrainedConfig):
model_type = "lladamoe"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=-1,
hidden_size=-1,
dense_intermediate_size=-1,
expert_intermediate_size=-1,
shared_expert_intermediate_size=-1,
num_hidden_layers=-1,
num_attention_heads=-1,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=False,
pad_token_id=1,
bos_token_id=None,
eos_token_id=50279,
tie_word_embeddings=False,
rope_theta=-1,
partial_rotary_factor=-1,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
clip_qkv=None,
num_experts_per_tok=-1,
num_experts=-1,
output_router_logits=False,
router_aux_loss_coef=0.01,
norm_topk_prob=None,
qk_layernorm=None,
moe_layer_freq=[],
moe_router_enable_expert_bias=None,
moe_router_score_function=None,
routed_scaling_factor=1,
router_num_group=-2,
router_topk_group=-2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.expert_intermediate_size = expert_intermediate_size
self.dense_intermediate_size = dense_intermediate_size
self.shared_expert_intermediate_size = shared_expert_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.clip_qkv = clip_qkv
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.norm_topk_prob = norm_topk_prob
self.qk_layernorm = qk_layernorm
self.moe_layer_freq = moe_layer_freq
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
self.moe_router_score_function = moe_router_score_function
self.partial_rotary_factor = partial_rotary_factor
self.routed_scaling_factor = routed_scaling_factor
self.router_num_group = router_num_group
self.router_topk_group = router_topk_group
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,3 @@
from dllm.core.trainers import MDLMTrainer
LLaDATrainer = MDLMTrainer

View File

@ -0,0 +1,7 @@
# from dllm.pipelines.rnd import generate, trainer
from . import models
from .models import RND1LM, RND1Config, RND1GenerationConfig
# from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
# from dllm.pipelines.rnd.models.configuration_rnd import RND1Config
from .trainer import RNDTrainer

View File

@ -0,0 +1,53 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
"""
from .configuration_rnd import RND1Config
from .modeling_rnd import (
RND1LM,
RND1Model,
RND1PreTrainedModel,
RND1Attention,
RND1DecoderLayer,
RND1SparseMoeBlock,
)
from .generation_config import RND1GenerationConfig
from .generation_utils import RND1GenerationMixin
from .sampling import (
diffusion_sample,
apply_top_k_filtering,
apply_top_p_filtering,
)
from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar
__version__ = "0.1.0"
__all__ = [
"RND1Config",
"RND1GenerationConfig",
"RND1LM",
"RND1Model",
"RND1PreTrainedModel",
"RND1Attention",
"RND1DecoderLayer",
"RND1SparseMoeBlock",
"RND1GenerationMixin",
"TerminalVisualizer",
"SimpleProgressBar",
]
# Register with HuggingFace Auto classes for local usage
try:
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM
AutoConfig.register("rnd1", RND1Config)
AutoModel.register(RND1Config, RND1LM)
AutoModelForMaskedLM.register(RND1Config, RND1LM)
except ImportError:
# transformers not available or Auto classes not imported
pass

View File

@ -0,0 +1,124 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Model Configuration.
This module defines the configuration class for RND1 models.
The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
with RND1-specific parameters.
"""
from transformers.configuration_utils import PretrainedConfig
# Qwen3-30B-A3B / checkpoint defaults
CONFIG_DEFAULTS = {
"attention_bias": False,
"attention_dropout": 0.0,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 40960,
"max_window_layers": 48,
"mlp_only_layers": [],
"moe_intermediate_size": 768,
"norm_topk_prob": True,
"num_attention_heads": 32,
"num_experts": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 48,
"num_key_value_heads": 4,
"output_router_logits": False,
"pad_token_id": 151643,
"rms_norm_eps": 1e-06,
"rope_scaling": False,
"rope_theta": 1000000.0,
"router_aux_loss_coef": 0.001,
"sliding_window": False,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"use_cache": False,
"use_sliding_window": False,
"vocab_size": 151936,
}
class RND1Config(PretrainedConfig):
"""
Configuration class for RND1 models.
This configuration extends Qwen3MoeConfig with additional parameters
specific to the RND1 (Radical Numerics Diffusion v1) architecture.
Args:
moe_backend: Backend for MoE computation ("hf", "vllm", "sglang" or "flashinfer")
num_diffusion_steps: Default number of diffusion steps for generation
mask_token_id: Token ID used for masking (default: 151669 for Qwen)
**kwargs: Additional arguments passed to Qwen3MoeConfig
"""
model_type = "rnd1"
def __init__(
self,
moe_backend: str = "hf",
num_diffusion_steps: int = 256,
mask_token_id: int = 151669,
**kwargs,
):
# Force non-causal and no caching for RND1
kwargs["use_cache"] = False
kwargs["is_causal"] = False
super().__init__(**kwargs)
# Set defaults after pretrained init to prevent overrides
self.set_config_defaults()
# QoL: set attn impl directly from config
if "attn_implementation" in kwargs:
self._attn_implementation = kwargs["attn_implementation"]
# RND1-specific parameters
self.moe_backend = moe_backend
self.num_diffusion_steps = num_diffusion_steps
self.mask_token_id = mask_token_id
# Ensure bidirectional attention and no caching
self.is_causal = False
self.use_cache = False
def set_config_defaults(self):
"""
Ensure model defaults are set according to the final training checkpoint.
Qwen3MoeConfig defaults don't match the Qwen/Qwen3-30B-A3B settings from which
RND1 is derived.
"""
for k, v in CONFIG_DEFAULTS.items():
setattr(self, k, v)
def to_dict(self):
"""
Serializes configuration to dictionary with auto_map for Hub.
The auto_map ensures that when users load from HuggingFace Hub,
the correct custom classes are automatically resolved.
"""
data = super().to_dict()
data.setdefault(
"auto_map",
{
"AutoConfig": "configuration_rnd.RND1Config",
"AutoModel": "modeling_rnd.RND1Model",
"AutoModelForMaskedLM": "modeling_rnd.RND1LM",
},
)
return data
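# A hedged construction example for the config above: RND1-specific fields sit
# alongside the Qwen3-30B-A3B-derived defaults, and caching/causality are
# forced off regardless of what the caller passes.
def _demo_rnd1_config():
    cfg = RND1Config(moe_backend="hf", num_diffusion_steps=128, use_cache=True)
    assert cfg.use_cache is False and cfg.is_causal is False
    assert cfg.num_experts == 128 and cfg.mask_token_id == 151669
    return cfg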

View File

@ -0,0 +1,77 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Generation Configuration.
This module defines the generation configuration for RND1 models,
controlling the diffusion-based generation process.
"""
from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig
class RND1GenerationConfig(GenerationConfig):
"""
Configuration class for RND1 generation parameters.
This class extends the base GenerationConfig to include parameters
specific to diffusion-based language generation.
Args:
max_length: Maximum sequence length
num_diffusion_steps: Number of denoising steps in the diffusion process
mask_token_id: Token ID used for masking during diffusion
temperature: Temperature for sampling (higher = more random)
top_k: Optional top-k filtering
top_p: Optional nucleus (top-p) filtering
greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
**kwargs: Additional arguments passed to GenerationConfig
"""
def __init__(
self,
max_length: int = 256,
num_diffusion_steps: int = 256,
mask_token_id: int = 151669,
temperature: float = 0.1,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
greedy: bool = False,
bos_token_id: int = None,
eos_token_id: int = None,
pad_token_id: int = None,
use_cache: bool = False,
**kwargs,
):
# Force no caching for RND generation
# kwargs['use_cache'] = False
kwargs.pop('use_cache', None)
super().__init__(
max_length=max_length,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=not greedy,
use_cache=False,
**kwargs,
)
# RND-specific parameters
self.num_diffusion_steps = num_diffusion_steps
self.mask_token_id = mask_token_id
self.greedy = greedy
def to_dict(self):
"""Convert configuration to dictionary."""
output = super().to_dict()
output["num_diffusion_steps"] = self.num_diffusion_steps
output["mask_token_id"] = self.mask_token_id
output["greedy"] = self.greedy
return output
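# A hedged construction example for the generation config above: `greedy`
# toggles do_sample off, caching is always disabled, and the RND-specific
# fields survive serialization via to_dict().
def _demo_rnd1_generation_config():
    gen_cfg = RND1GenerationConfig(max_length=256, num_diffusion_steps=64, greedy=True)
    assert gen_cfg.do_sample is False and gen_cfg.use_cache is False
    assert gen_cfg.to_dict()["num_diffusion_steps"] == 64
    return gen_cfg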

View File

@ -0,0 +1,187 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 Generation Utilities.
This module provides generation utilities and mixins for RND1 models,
including the main GenerationMixin class that integrates with HuggingFace.
"""
import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Any
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig
from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
class RND1GenerationMixin(HFGenerationMixin):
"""
Generation mixin for RND1 models.
This mixin provides generation methods compatible with HuggingFace's
generation API while using RND1's diffusion-based sampling internally.
"""
def generate(
self,
inputs: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
# RND1-specific parameters
prefix_ids: Optional[torch.LongTensor] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
return_dict_in_generate: Optional[bool] = None,
**kwargs, # Accept all kwargs to be compatible with pipelines
) -> Union[torch.LongTensor, Dict[str, Any]]:
"""
Generate text using RND1's diffusion-based sampling.
Follows HuggingFace's standard generate API, using diffusion sampling
internally. Supports both standard generation and infilling.
Args:
inputs: Input token IDs to use as prefix (standard HF parameter)
generation_config: Generation configuration object
prefix_ids: Alternative to inputs for infilling tasks
suffix_ids: Optional suffix for infilling tasks
infill_length: Length of infill region (for infilling)
return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
**kwargs: Additional arguments (accepted for compatibility)
Returns:
Generated token IDs or GenerateDecoderOnlyOutput
"""
if generation_config is not None:
gen_config = generation_config
model_kwargs = kwargs.copy()
else:
# Only prepare config from kwargs if no config was provided
gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
device = next(self.parameters()).device
if inputs is not None:
prefix_ids = inputs.to(device)
elif prefix_ids is not None:
prefix_ids = prefix_ids.to(device)
else:
prefix_ids = None
if suffix_ids is not None:
suffix_ids = suffix_ids.to(device)
eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
eos_token_id = None if eos_token_id == -1 else eos_token_id
pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
if infill_length is not None and prefix_ids is not None:
# Infilling mode: use specified infill_length
prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
seq_len = prefix_len + infill_length + suffix_len
else:
# Standard generation mode
if prefix_ids is not None:
prefix_len = prefix_ids.shape[1]
if gen_config.max_new_tokens is not None:
seq_len = prefix_len + gen_config.max_new_tokens
else:
seq_len = gen_config.max_length or self.config.max_position_embeddings
else:
seq_len = gen_config.max_length or self.config.max_position_embeddings
num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
getattr(self.config, "num_diffusion_steps", 256))
temperature = float(getattr(gen_config, "temperature", 1.0))
top_k = getattr(gen_config, "top_k", None)
top_p = getattr(gen_config, "top_p", None)
greedy = getattr(gen_config, "greedy",
not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
with torch.inference_mode():
sequences = diffusion_sample(
model=self,
seq_len=seq_len,
num_steps=num_diffusion_steps,
mask_token_id=mask_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
greedy=greedy,
prefix_ids=prefix_ids,
suffix_ids=suffix_ids,
infill_length=infill_length,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
device=device,
visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
)
if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
from transformers.generation.utils import GenerateDecoderOnlyOutput
return GenerateDecoderOnlyOutput(sequences=sequences)
return sequences
def generate_with_visualization(
self,
tokenizer,
inputs: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
**kwargs,
) -> torch.LongTensor:
"""
Generate with live visualization (for demos).
This method requires a tokenizer to display the generation process.
For production use, prefer `generate()`.
Args:
tokenizer: Tokenizer for decoding tokens to text
inputs: Input token IDs to use as prefix
generation_config: Generation configuration object
suffix_ids: Optional suffix token IDs
infill_length: Length of infill region
**kwargs: Additional arguments for backward compatibility
Returns:
Generated token IDs as LongTensor
"""
from .terminal_visualizer import TerminalVisualizer
visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
return self.generate(
inputs=inputs,
generation_config=generation_config,
suffix_ids=suffix_ids,
infill_length=infill_length,
visualizer=visualizer,
return_dict_in_generate=False,
**kwargs,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
**kwargs,
) -> Dict[str, Any]:
"""
Prepare inputs for generation (required by HuggingFace).
For RND1, we don't use the standard autoregressive generation,
so this just returns the input_ids.
"""
return {"input_ids": input_ids}

View File

@ -0,0 +1,653 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 model implementation.
This module implements the RND1 architecture with bidirectional attention for
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
with multiple backend options (HF, vLLM, SGLang, FlashInfer).
Based on the Qwen3Moe architecture:
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
"""
from __future__ import annotations
import os
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
MoeModelOutputWithPast,
MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from .configuration_rnd import RND1Config
from .generation_utils import RND1GenerationMixin
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeRMSNorm,
Qwen3MoeRotaryEmbedding,
Qwen3MoeMLP,
apply_rotary_pos_emb
)
import torch.nn.functional as F
try:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts as fused_experts_vllm, fused_topk as fused_topk_vllm
from vllm.model_executor.layers.layernorm import RMSNorm as VLLMRMSNorm
except Exception:
fused_experts_vllm = None
fused_topk_vllm = None
try:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
# from sglang.srt.layers.layernorm import RMSNorm as SGLangRMSNorm # TODO: buggy atm
from sglang.srt.layers.moe.topk import StandardTopKOutput
except Exception:
sglang_fused_moe = None
StandardTopKOutput = None
try:
import flashinfer.fused_moe as fused_moe
## TODO: below needs flashinfer>=0.4.0, but has some bug atm
# from flashinfer.norm import rmsnorm as flashinfer_rmsnorm
# class FlashInferRMSNorm(Qwen3MoeRMSNorm):
# """Wrapper around FlashInfer RMSNorm to match Qwen3MoeRMSNorm interface"""
# def forward(self, hidden_states):
# return flashinfer_rmsnorm(hidden_states, self.weight, self.variance_epsilon)
except Exception:
fused_moe = None
logger = logging.get_logger(__name__)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Expand key/value heads to match query heads for grouped-query attention."""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
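# A small, hedged shape check for repeat_kv: with 2 KV heads expanded by
# n_rep=4, the KV tensor matches an 8-head query layout.
def _demo_repeat_kv():
    kv = torch.randn(1, 2, 16, 64)     # [batch, num_kv_heads, seq, head_dim]
    expanded = repeat_kv(kv, n_rep=4)  # -> [1, 8, 16, 64]
    assert expanded.shape == (1, 8, 16, 64)
    return expanded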
class RND1Attention(nn.Module):
"""RND1 attention layer with bidirectional attention for diffusion modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = config.attention_dropout
self.is_causal = False
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.q_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNormClass(self.head_dim, eps=config.rms_norm_eps)
self.sliding_window = getattr(config, "sliding_window", None)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
dual_cache: Optional[bool] = False,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]]]:
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
if use_sdpa:
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
attention_mask = attention_mask.to(dtype=query_states.dtype)
assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
attn_out = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=self.is_causal,
)
attn_out = attn_out.transpose(1, 2).contiguous()
attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
attn_out = self.o_proj(attn_out)
return attn_out, None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
# TODO: modify this to boolean masks
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_out = torch.matmul(attn_weights, value_states)
attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
attn_out = self.o_proj(attn_out)
return attn_out, None
class RND1DecoderLayer(nn.Module):
"""RND1 decoder layer with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.self_attn = RND1Attention(config, layer_idx)
self.mlp = RND1SparseMoeBlock(config)
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.input_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_out, attn_weights = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
replace_position=replace_position,
)
hidden_states = residual + attn_out
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
ff_out = self.mlp(hidden_states)
if isinstance(ff_out, tuple):
ff_out = ff_out[0]
hidden_states = residual + ff_out
return hidden_states, attn_weights
class RND1SparseMoeBlock(nn.Module):
"""RND1 Sparse MoE block with multiple backend support (HF, vLLM, SGLang, FlashInfer)."""
def __init__(self, config: RND1Config):
super().__init__()
self.config = config
self.backend = getattr(config, "moe_backend", "hf")
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_size = config.hidden_size
self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
self.experts = nn.ModuleList(
[Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
)
# Cached weight tensors for optimized backends
self._w1 = None
self._w2 = None
if self.backend == "sglang":
if sglang_fused_moe is None or StandardTopKOutput is None:
raise RuntimeError("sglang is not available, cannot use sglang backend")
elif self.backend == "flashinfer":
if fused_moe is None:
raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
elif self.backend == "vllm":
if fused_experts_vllm is None or fused_topk_vllm is None:
raise RuntimeError("vllm is not available, cannot use vllm backend")
@torch.no_grad()
def _initialize_weights(
self,
free_experts: bool = True,
mode: str = "vllm",
) -> None:
logger.info(f"Initializing weights for {mode} backend")
# Stack directly on device where weights already reside (loaded by HF)
gate_list: List[torch.Tensor] = []
up_list: List[torch.Tensor] = []
down_list: List[torch.Tensor] = []
# Collect weight references without any device moves
for expert in self.experts:
gate_list.append(expert.gate_proj.weight.data)
up_list.append(expert.up_proj.weight.data)
down_list.append(expert.down_proj.weight.data)
gate_w_stacked = torch.stack(gate_list, dim=0).contiguous()
up_w_stacked = torch.stack(up_list, dim=0).contiguous()
down_w_stacked = torch.stack(down_list, dim=0).contiguous()
if mode == "flashinfer":
w1 = torch.cat([up_w_stacked, gate_w_stacked], dim=1) # FlashInfer expects [up; gate] ordering
else:
w1 = torch.cat([gate_w_stacked, up_w_stacked], dim=1)
w2 = down_w_stacked
self._w1 = w1
self._w2 = w2
if free_experts:
# Free per-expert modules to reclaim memory
logger.info(f"Freeing experts for {mode} backend")
del self.experts
self.experts = None
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with expert routing and computation."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
x = hidden_states.view(-1, hidden_dim)
# Expert routing
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
if self.backend == "vllm":
routing_weights, selected_experts, *_ = fused_topk_vllm(
hidden_states=x,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.norm_topk_prob,
)
else:
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
if self.backend == "hf":
# Accumulate buffer: [B*S, H]
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# expert_mask: [E, top_k, tokens]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0).contiguous()
            # Iterate over every expert in order; even when this rank routes no tokens to an
            # expert, we still run its forward pass to avoid ZeRO-3 control-flow divergence.
            for e in range(self.num_experts):
                expert_layer = self.experts[int(e)]
                # Token indices routed to this expert
                idx, top_x = torch.where(expert_mask[e])  # idx in [0, top_k); shapes: [n_tok_e]
                current_state = x[top_x]  # [n_tok_e, H]; n_tok_e may be 0
                # Forward even on an empty batch: most Linear/MLP layers are no-ops on zero-row
                # inputs, but this keeps the ZeRO-3 parameter path aligned across ranks.
                expert_out = expert_layer(current_state)  # [n_tok_e, H]
                # Apply the routing weights
                w = routing_weights[top_x, idx]  # [n_tok_e]
                expert_out = expert_out * w.unsqueeze(-1)  # [n_tok_e, H]
                # Accumulate into the global buffer; a legal no-op when n_tok_e == 0
                final_hidden_states.index_add_(0, top_x, expert_out.to(hidden_states.dtype))
out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "flashinfer":
# if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
# self._initialize_flashinfer_weights()
if self._w1 is None or self._w2 is None:
self._initialize_weights(mode="flashinfer")
result = fused_moe.cutlass_fused_moe(
input=x,
token_selected_experts=selected_experts.to(torch.int),
token_final_scales=routing_weights.to(torch.float32),
fc1_expert_weights=self._w1,
fc2_expert_weights=self._w2,
output_dtype=x.dtype,
quant_scales=None,
)
if isinstance(result, (list, tuple)):
out_flat = result[0]
else:
out_flat = result
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "sglang":
if self._w1 is None or self._w2 is None:
self._initialize_weights(mode="sglang")
topk_output = StandardTopKOutput(
topk_weights=routing_weights,
topk_ids=selected_experts,
router_logits=router_logits,
)
out_flat = sglang_fused_moe(
hidden_states=x,
w1=self._w1,
w2=self._w2,
topk_output=topk_output,
)
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "vllm":
if self._w1 is None or self._w2 is None:
self._initialize_weights()
out_flat = fused_experts_vllm(
x,
self._w1,
self._w2,
routing_weights,
selected_experts,
)
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
else:
raise ValueError(f"Invalid backend: {self.backend}")
class RND1PreTrainedModel(PreTrainedModel):
"""Base class for RND1 models with weight initialization and loading support."""
config_class = RND1Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["RND1DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
"""Initialize weights using normal distribution."""
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
weights_only: bool = True,
**kwargs,
):
"""Load pretrained model with generation config."""
_model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
resume_download = kwargs.get("resume_download", None)
proxies = kwargs.get("proxies", None)
subfolder = kwargs.get("subfolder", "")
from_auto_class = kwargs.get("_from_auto", False)
from_pipeline = kwargs.get("_from_pipeline", None)
_model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
)
        # If configured to use a fused backend, pack fused tensors once after load
        cfg = getattr(_model, "config", None)
        backend = getattr(cfg, "moe_backend", "hf") if cfg is not None else "hf"
        try:
            if backend in ("sglang", "vllm"):
                # Walk decoder layers and initialize fused weights
                model_core = getattr(_model, "model", _model)
                layers = getattr(model_core, "layers", None)
                if isinstance(layers, nn.ModuleList):
                    for layer in layers:
                        mlp = getattr(layer, "mlp", None)
                        if hasattr(mlp, "_initialize_weights"):
                            mlp._initialize_weights(
                                free_experts=True,
                                mode=backend,
                            )
        except Exception as _e:
            logger.warning(f"Backend {backend} weight processing skipped: {_e}")
return _model
class RND1Model(RND1PreTrainedModel):
"""RND1 transformer model with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
if config.moe_backend == "vllm":
RMSNormClass = VLLMRMSNorm
else:
RMSNormClass = Qwen3MoeRMSNorm
self.norm = RMSNormClass(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> MoeModelOutputWithPast:
"""Forward pass through the RND1 model."""
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
if isinstance(attention_mask, torch.Tensor):
# shape: (batch_size, 1, 1, seq_len)
attention_mask = attention_mask.to(dtype=torch.float)[:, None, None, :]
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
hidden_states = inputs_embeds
for layer in self.layers:
hidden_states, _ = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
)
hidden_states = self.norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
router_logits=None,
)
class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
"""Radical Numerics Diffusion Language Model with bidirectional attention."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.model = RND1Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
"""Get the input embeddings layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the input embeddings layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embeddings layer (lm_head)."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embeddings layer (lm_head)."""
self.lm_head = new_embeddings
@classmethod
def can_generate(cls) -> bool:
"""Indicates this model can generate text."""
return True
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> MaskedLMOutput:
"""Forward pass with optional loss computation."""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return MaskedLMOutput(
loss=loss,
logits=logits,
)
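# Usage sketch (the checkpoint path, tokenizer, and dtype below are assumptions for illustration;
# the model is loaded through the custom from_pretrained above, which also attaches the
# generation config):
#   tok = AutoTokenizer.from_pretrained(checkpoint_dir)
#   model = RND1LM.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16)
#   out = model(input_ids=tok("Hello", return_tensors="pt").input_ids)
#   out.logits.shape  # -> [batch, seq_len, vocab_size]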

View File

@ -0,0 +1,260 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
RND1 sampling module for masked diffusion generation.
This module implements entropy-based token selection for iterative denoising
in diffusion language models. Supports both greedy and stochastic sampling
with optional prefix/suffix constraints and infilling.
"""
import torch
import torch.nn as nn
from typing import Optional, Union
def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
"""
Apply top-k filtering to logits: with non-top-k values set to -inf
"""
top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
filtered_logits = torch.full_like(logits, float('-inf'))
filtered_logits.scatter_(-1, top_k_indices, top_k_values)
return filtered_logits
def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
"""
Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
    sorted_indices_to_remove = cumulative_probs > p
    # Shift right so the first token that crosses the threshold is still kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False  # Always keep at least one token
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
return logits.masked_fill(indices_to_remove, float('-inf'))
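# Worked example (illustrative numbers): logits [2.0, 1.0, -1.0] give probabilities of roughly
# [0.70, 0.26, 0.04]; with p = 0.9 the cumulative mass first exceeds 0.9 at the second token,
# so only the third logit is set to -inf and sampling is restricted to the top two tokens.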
@torch.no_grad()
def diffusion_sample(
model: nn.Module,
seq_len: int = 256,
num_steps: int = 256,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: float = 1.0,
greedy: bool = True,
mask_token_id: int = 151669,
prefix_ids: Optional[torch.LongTensor] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
eos_token_id: int = 151645,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
visualizer: Optional[object] = None,
) -> torch.LongTensor:
"""
Perform masked diffusion sampling with entropy-based token selection.
Args:
model: The RND1 language model
seq_len: Target sequence length
num_steps: Number of denoising steps
top_k: Optional top-k filtering for sampling (None = no filtering)
top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
When both top_k and top_p are set, top_k is applied first, then top_p
temperature: Temperature for sampling (higher = more random, lower = more deterministic)
Values close to 0 are clamped to 1e-8 to avoid division by zero
greedy: Whether to use greedy sampling (True) or stochastic (False)
mask_token_id: Token ID for masked positions (default: 151669)
prefix_ids: Optional prefix token IDs to preserve
suffix_ids: Optional suffix token IDs to preserve
infill_length: Length of infill region between prefix/suffix
eos_token_id: End of sequence token ID (default: 151645)
pad_token_id: Padding token ID (default: None, uses 0 if needed)
bos_token_id: Beginning of sequence token ID (default: None)
device: Device for computation (None = infer from model)
visualizer: Optional visualizer for live visualization
Returns:
Generated token IDs as LongTensor
"""
model.eval()
if device is None:
device = next(model.parameters()).device
else:
device = torch.device(device)
if pad_token_id is None:
pad_token_id = 0
# Build initial masked sequence
# When prefix_ids is provided, we create a sequence of length seq_len where:
# - The prefix occupies the first pre_len positions
# - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
if prefix_ids is not None or suffix_ids is not None:
if prefix_ids is not None:
prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
else:
pre_len = 0
if suffix_ids is not None:
suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
else:
suf_len = 0
reserved = (1 if eos_token_id is not None else 0)
used = pre_len + suf_len + reserved
if used > seq_len:
raise ValueError(
f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
f"Please increase seq_len or reduce input lengths."
)
elif used == seq_len:
raise ValueError(
f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
f"Need at least 1 position for generation."
)
infill_length = min(infill_length or (seq_len - used), seq_len - used)
x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
pos = 0
# if bos_token_id is not None:
# x[0, pos] = bos_token_id; pos += 1
if eos_token_id is not None:
x[0, -1] = eos_token_id
if pre_len > 0:
x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]
pos += pre_len
fill_start, fill_end = pos, pos + infill_length
x[0, fill_start:fill_end] = mask_token_id
# print(fill_start, fill_end, seq_len, used, x[0, -1])
pos = fill_end
if suf_len > 0:
x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]
pos += suf_len
init_maskable = torch.zeros_like(x, dtype=torch.bool)
init_maskable[0, fill_start:fill_end] = True
else:
x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
if bos_token_id is not None:
x[0, 0] = bos_token_id
if eos_token_id is not None:
x[0, -1] = eos_token_id
init_maskable = x.eq(mask_token_id)
if bos_token_id is not None:
init_maskable[:, 0] = False
if eos_token_id is not None:
init_maskable &= x.ne(eos_token_id)
init_maskable &= x.ne(pad_token_id)
maskable = init_maskable.clone()
xt = x.clone()
if visualizer:
visualizer.start_visualization(xt, maskable, num_steps)
def forward_scores(tokens):
"""Compute predictions and entropy scores for next tokens."""
# Try with input_ids parameter first (standard HF models)
try:
model_output = model(input_ids=tokens)
except TypeError:
# Fall back to positional argument
model_output = model(tokens)
# Apply temperature scaling (with safety for near-zero temperature)
safe_temperature = max(temperature, 1e-8) # Prevent division by zero
logits = model_output.logits / safe_temperature
# Apply filtering strategies
# Note: When both top_k and top_p are provided, they are applied sequentially:
# First top_k filters to k tokens, then top_p filters from those k tokens
if top_k is not None and top_k > 0:
logits = apply_top_k_filtering(logits, top_k)
if top_p is not None and 0 < top_p < 1.0:
logits = apply_top_p_filtering(logits, top_p)
# Convert to log probabilities
logp = torch.log_softmax(logits, dim=-1)
# Greedy or stochastic sampling
if greedy:
pred_next = logp.argmax(-1)
else:
pred_next = torch.distributions.Categorical(logits=logp).sample()
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
p = logp.exp()
ent_next = -(p * logp).sum(-1)
# Shift predictions: pos i predicts token i+1
pred_i = tokens.clone()
conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
ent_i = torch.zeros_like(ent_next)
pred_i[:, 1:] = pred_next[:, :-1]
conf_i[:, 1:] = conf_next[:, :-1]
ent_i[:, 1:] = ent_next[:, :-1]
return pred_i, conf_i, ent_i
pred_i, conf_i, ent_i = forward_scores(xt)
total_masked = init_maskable.sum(1, keepdim=True)
finf = torch.finfo(conf_i.dtype)
for step in range(num_steps - 1, 0, -1):
rate = step / num_steps
cutoff_len = (total_masked * rate).long().clamp(min=0)
# Choose HIGH-entropy tokens to keep masked
sel_scores = ent_i.masked_fill(~maskable, -finf.max)
B, L = sel_scores.shape
k_max = cutoff_len.max().item()
if k_max > 0:
sss, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
for b in range(B):
k_b = int(cutoff_len[b].item())
if k_b > 0:
keep_mask[b, idx[b, :k_b]] = True
else:
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
to_unmask = maskable & ~keep_mask
if to_unmask.any():
xt[to_unmask] = pred_i[to_unmask]
maskable[to_unmask] = False
if visualizer:
visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
if maskable.any():
pred_i, conf_i, ent_i = forward_scores(xt)
if maskable.any():
xt[maskable] = pred_i[maskable]
if visualizer:
visualizer.stop_visualization()
return xt
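# Usage sketch (assumes an already-loaded RND1 model and a matching HF tokenizer; the prompt
# text and generation lengths are illustrative, not fixed by this module):
#   prefix = tokenizer("def fib(n):", return_tensors="pt").input_ids
#   ids = diffusion_sample(model, seq_len=128, num_steps=64, greedy=True, prefix_ids=prefix)
#   print(tokenizer.decode(ids[0], skip_special_tokens=True))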

View File

@ -0,0 +1,251 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.
"""
Terminal visualization for RND1 generation.
This module provides real-time visualization of the diffusion denoising process,
showing token evolution and generation progress in the terminal using rich
formatting when available.
"""
import torch
from typing import Optional
from tqdm import tqdm
try:
from rich.console import Console
from rich.live import Live
from rich.text import Text
from rich.panel import Panel
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.layout import Layout
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
class TerminalVisualizer:
"""
Rich-based visualization for diffusion process with live updates.
Provides real-time visualization of the token denoising process during
diffusion-based language generation, with colored highlighting of masked
positions and progress tracking.
"""
def __init__(self, tokenizer, show_visualization: bool = True):
"""
Initialize the terminal visualizer.
Args:
tokenizer: The tokenizer for decoding tokens to text
show_visualization: Whether to show visualization (requires rich)
"""
self.tokenizer = tokenizer
self.show_visualization = show_visualization and RICH_AVAILABLE
if not RICH_AVAILABLE and show_visualization:
print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
self.show_visualization = False
if self.show_visualization:
self.console = Console()
self.live = None
self.progress = None
self.layout = None
else:
self.pbar = None
self.current_tokens = None
self.mask_positions = None
self.total_steps = 0
self.current_step = 0
def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
"""
Start the visualization.
Args:
initial_tokens: Initial token IDs (possibly masked)
mask_positions: Boolean mask indicating which positions are masked
total_steps: Total number of diffusion steps
"""
if not self.show_visualization:
self.pbar = tqdm(total=total_steps, desc="Diffusion")
return
self.current_tokens = initial_tokens.clone()
self.mask_positions = mask_positions
self.total_steps = total_steps
self.current_step = 0
self.layout = Layout()
self.layout.split_column(
Layout(name="header", size=3),
Layout(name="text", ratio=1),
Layout(name="progress", size=3)
)
self.progress = Progress(
TextColumn("[bold blue]Diffusion"),
BarColumn(),
MofNCompleteColumn(),
TextColumn(""),
TextColumn("[cyan]Masks: {task.fields[masks]}"),
TimeRemainingColumn(),
)
self.progress_task = self.progress.add_task(
"Generating",
total=total_steps,
masks=mask_positions.sum().item()
)
self.live = Live(self.layout, console=self.console, refresh_per_second=4)
self.live.start()
self._update_display()
def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
"""
Update visualization for current step.
Args:
tokens: Current token IDs
maskable: Boolean mask of remaining masked positions
step: Current step number
entropy: Optional entropy scores for each position
confidence: Optional confidence scores for each position
"""
if not self.show_visualization:
if self.pbar:
self.pbar.update(1)
masks = maskable.sum().item() if maskable is not None else 0
self.pbar.set_postfix({'masks': masks})
return
self.current_tokens = tokens.clone()
self.mask_positions = maskable
self.current_step = step
masks_remaining = maskable.sum().item() if maskable is not None else 0
self.progress.update(
self.progress_task,
advance=1,
masks=masks_remaining
)
self._update_display()
def _update_display(self):
"""Update the live display."""
if not self.live:
return
header = Text("RND1-Base Generation", style="bold magenta", justify="center")
self.layout["header"].update(Panel(header, border_style="bright_blue"))
text_display = self._format_text_with_masks()
self.layout["text"].update(
Panel(
text_display,
title="[bold]Generated Text",
subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
border_style="cyan"
)
)
self.layout["progress"].update(Panel(self.progress))
def _format_text_with_masks(self) -> Text:
"""
Format text with colored masks.
Returns:
Rich Text object with formatted tokens
"""
text = Text()
if self.current_tokens is None:
return text
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
for i, token_id in enumerate(token_ids):
if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
# Alternate colors for visual effect
text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
else:
try:
token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
# Skip special tokens in display
if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
# Color based on position
text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
                except Exception:
                    continue
return text
def stop_visualization(self):
"""Stop the visualization and display final result."""
if not self.show_visualization:
if self.pbar:
self.pbar.close()
print("\n✨ Generation complete!\n")
return
if self.live:
self.live.stop()
self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
# Display final text
if self.current_tokens is not None:
try:
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
self.console.print(Panel(
final_text,
title="[bold]Final Generated Text",
border_style="green",
padding=(1, 2)
))
            except Exception:
                pass
class SimpleProgressBar:
"""
Simple progress bar fallback when rich is not available.
Provides basic progress tracking using tqdm when the rich library
is not installed.
"""
def __init__(self, total_steps: int):
"""
Initialize simple progress bar.
Args:
total_steps: Total number of steps
"""
self.pbar = tqdm(total=total_steps, desc="Diffusion")
def update(self, masks_remaining: int = 0):
"""
Update progress bar.
Args:
masks_remaining: Number of masks still remaining
"""
self.pbar.update(1)
self.pbar.set_postfix({'masks': masks_remaining})
def close(self):
"""Close the progress bar."""
self.pbar.close()
print("\n✨ Generation complete!\n")

View File

@ -0,0 +1,23 @@
from typing import Any
import torch
from dllm.core.trainers import MDLMTrainer
class RNDTrainer(MDLMTrainer):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
def _preprocess_inputs(self, inputs):
labels = inputs["labels"]
assert (labels[:, 0] == -100).all()
def _postprocess_outputs(self, outputs):
logits = outputs.logits
outputs.logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
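# Note on the logit shift above: RND1 predicts the token at position i+1 from position i, so
# the logits are shifted right by one position before the MDLM loss; position 0 re-uses its own
# logits as a placeholder, which is safe because its label is asserted to be -100 and therefore
# never contributes to the loss.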

242
dllm/dllm/tools/chat.py Normal file
View File

@ -0,0 +1,242 @@
import shutil
from typing import List, Literal
import textwrap
import dllm
# ============================================================
# Utility helpers
# ============================================================
try:
L = shutil.get_terminal_size().columns
if not isinstance(L, int) or L <= 0:
L = 120
except Exception:
L = 120
DIV = "=" * L
SUB = "-" * L
def banner_line(text: str, width: int = L, fill: str = "=") -> str:
"""Return a centered banner line with given width and fill."""
text = f" {text.strip()} "
fill_len = width - len(text)
if fill_len <= 0:
return text
left = fill_len // 2
right = fill_len - left
return f"{fill * left}{text}{fill * right}"
def print_wrapped(text: str, width: int = L):
"""Print text with automatic line wrapping."""
wrapped = textwrap.fill(text, width=width)
print(wrapped)
def boxed(text: str, width: int = L, padding: int = 1):
"""Render a centered box with the given text and width."""
lines = text.splitlines()
content_width = max(len(line) for line in lines)
box_width = min(width, content_width + padding * 2 + 2)
# compute left margin for centering
terminal_width = width
left_margin = max((terminal_width - box_width) // 2, 0)
margin = " " * left_margin
    top = margin + "┌" + "─" * (box_width - 2) + "┐"
    bottom = margin + "└" + "─" * (box_width - 2) + "┘"
    print(top)
    for line in lines:
        inner = line.center(content_width)
        print(margin + "│" + " " * padding + inner + " " * padding + "│")
print(bottom)
def decode_trim(tokenizer, seq_ids_list, input_ids_list) -> List[str]:
    """
    Return only the generated text for each sample, truncated at the first EOS **after** the prompt.
    Args:
        tokenizer: HF tokenizer with eos_token_id / pad_token_id.
        seq_ids_list: Full sequence token ids from the model (prompt + generation), one list per sample.
        input_ids_list: The prompt token ids that were fed into the model, one list per sample.
    Behavior:
        - Finds the first eos_token_id that occurs at or after len(input_ids).
        - Slices the generation up to (but not including) that EOS.
        - Decodes only the generation span, skipping special/pad tokens.
    """
# Make sure we can index these
sequences = []
for seq_ids, input_ids in zip(seq_ids_list, input_ids_list):
full = list(seq_ids)
prompt = list(input_ids)
# Skip left padding tokens (necessary for dream)
pad_id = getattr(tokenizer, "pad_token_id", None)
if pad_id is not None:
while full and full[0] == pad_id:
full.pop(0)
start = len(prompt)
end = len(full)
eos_id = getattr(tokenizer, "eos_token_id", None)
eot_id = getattr(tokenizer, "eot_token_id", None)
if eos_id is not None:
for i in range(start, len(full)):
if full[i] in (eos_id, eot_id):
end = i
break
gen_ids = full[start:end]
text = tokenizer.decode(gen_ids, skip_special_tokens=True)
# in case there is no eos_id or eot_id, just strings
eos = getattr(tokenizer, "eos_token", None)
eot = getattr(tokenizer, "eot_token", None)
if eos:
text = text.split(eos)[0]
if eot:
text = text.split(eot)[0]
# return text.strip()
sequences.append(text)
return sequences
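# Example (sketch, with made-up token ids): for input_ids [1, 2] and a returned sequence
# [1, 2, 7, 8, <eos>, 9], only [7, 8] is decoded; left padding is stripped first, and any
# eos/eot string that survives decoding is split off as a final safeguard.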
def render_menu(round_idx: int):
"""Render a boxed menu of possible actions."""
if round_idx == 0:
text = (
"Possible next actions:\n"
"[1] Continue this chat\n"
"[2] End this chat and start a new one\n"
"[3] Exit"
)
else:
text = (
f"(Round {round_idx})\n"
"Possible next actions:\n"
"[1] Continue this chat\n"
"[2] End this chat and start a new one\n"
"[3] Exit"
)
print() # spacing
boxed(text)
def prompt_choice() -> Literal["1", "2", "3"]:
while True:
print("Select action [1/2/3]: ")
choice = input().strip()
if choice in ("1", "2", "3"):
return choice
print(banner_line("<Invalid choice. Please type 1, 2, or 3.>", fill=" "))
def build_chat_inputs(tokenizer, messages: List[dict], add_generation_prompt: bool):
"""Tokenize chat messages into inputs tensor."""
return tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
)
def visualize_histories(tokenizer, histories):
try:
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
terminal_visualizer.visualize(histories, rich=True)
except Exception as e:
print(f"(Visualization skipped: {e})")
# ============================================================
# Modes
# ============================================================
def single_turn_generate(generator, gen_config, visualize: bool):
print()
print(banner_line("continuation mode"))
model, tokenizer = generator.model, generator.tokenizer
while True:
print(banner_line("<Type your prompt below. Press Ctrl+C to exit.>", fill=" "))
try:
# user_text = input("Prompt > ").strip()
print("[Prompt] > ")
user_text = input().strip()
except (EOFError, KeyboardInterrupt):
print("\n" + banner_line("Exiting. Bye!", width=len(DIV)))
return
# if not user_text:
# print("(Empty input, skipped)\n")
# continue
inputs = tokenizer([user_text], add_special_tokens=False)["input_ids"]
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
text = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
print(banner_line("Output"))
print_wrapped(text if text else "<empty>")
print(DIV + "\n")
if visualize:
visualize_histories(tokenizer, outputs.histories)
def multi_turn_chat(generator, gen_config, visualize: bool):
# """Chat mode with chat template & message history."""
print()
print(banner_line("multi-turn chat mode"))
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
model, tokenizer = generator.model, generator.tokenizer
messages: List[dict] = []
round_idx = 0
while True:
try:
print("[You]:")
user_msg = input().strip()
except (EOFError, KeyboardInterrupt):
print("\nExiting. Bye!")
return
messages.append({"role": "user", "content": user_msg})
inputs = build_chat_inputs(tokenizer, [messages], add_generation_prompt=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
reply = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0]
print(DIV)
print_wrapped("[Assistant]: " + reply if reply else "<empty>")
print(DIV + "\n")
messages.append({"role": "assistant", "content": reply})
if visualize:
visualize_histories(tokenizer, outputs.histories)
render_menu(round_idx)
choice = prompt_choice()
if choice == "1":
print(banner_line("<Type your message.>", fill=" "))
round_idx += 1
continue
elif choice == "2":
print(banner_line("<Starting a new chat. Type your message.>", fill=" "))
messages = []
round_idx = 0
continue
else:
print("\nExiting. Bye!")
return

View File

@ -0,0 +1,30 @@
from dataclasses import dataclass
import tyro
from huggingface_hub import snapshot_download
@dataclass
class ScriptArguments:
dataset_id: str = "Anthropic/hh-rlhf"
allow_patterns: str = None
script_args = tyro.cli(ScriptArguments)
# Replace with the dataset repo you want, e.g. "wikitext"
dataset_id = script_args.dataset_id
# Replace with your desired local directory
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/datasets/huggingface/{dataset_id}"
# Download the dataset snapshot
snapshot_download(
repo_id=dataset_id,
repo_type="dataset", # 👈 tell HF it's a dataset
local_dir=local_dir,
local_dir_use_symlinks=False, # ensures real files, not symlinks
allow_patterns=script_args.allow_patterns,
)
print(f"Dataset downloaded to: {local_dir}")

View File

@ -0,0 +1,27 @@
from dataclasses import dataclass
import tyro
from huggingface_hub import snapshot_download
@dataclass
class ScriptArguments:
model_id: str = "GSAI-ML/LLaDA-8B-Instruct"
script_args = tyro.cli(ScriptArguments)
# Replace with the model repo you want, e.g. "bert-base-uncased"
model_id = script_args.model_id
# Replace with your desired local directory
local_dir = f"/mnt/lustrenew/mllm_aligned/shared/models/huggingface/{model_id}"
# Download the model snapshot
snapshot_download(
repo_id=model_id,
local_dir=local_dir,
local_dir_use_symlinks=False, # ensures real files, not symlinks
)
print(f"Model downloaded to: {local_dir}")

View File

@ -0,0 +1 @@
# TODO

View File

@ -0,0 +1,80 @@
"""
Merge a PEFT/LoRA adapter into its base model (auto-detected from adapter_config.json).
Usage:
python dllm_trainer/tools/merge_peft_adapter.py \
--adapter_model_name_or_path your-org/your-lora \
--output_model_name_or_path ./merged-model \
--dtype bf16
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModel, AutoTokenizer, HfArgumentParser
import dllm # so that no need to trust_remote_code
DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float}
@dataclass
class ScriptArguments:
adapter_model_name_or_path: str | None = field(
default=None, metadata={"help": "Adapter repo or local path"}
)
output_model_name_or_path: str | None = field(
default=None,
metadata={"help": "Where to save the merged model (folder or repo id)"},
)
dtype: str | None = field(default="fp16", metadata={"help": "fp16|bf16|fp32"})
push_to_hub: bool | None = field(
default=False, metadata={"help": "Push merged weights to the Hub"}
)
# Optional override if adapter config lacks base info:
base_model_name_or_path: str | None = field(
default=None,
metadata={"help": "Override base model if adapter config lacks it"},
)
def main():
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
assert args.adapter_model_name_or_path, "please provide the adapter (repo or path)"
assert args.output_model_name_or_path, "please provide output_model_name_or_path"
assert args.dtype in DTYPE_MAP, f"dtype must be one of {list(DTYPE_MAP.keys())}"
# Read base path from adapter_config.json
peft_cfg = PeftConfig.from_pretrained(args.adapter_model_name_or_path)
base_id = args.base_model_name_or_path or getattr(
peft_cfg, "base_model_name_or_path", None
)
assert base_id, (
"adapter_config.json does not include base_model_name_or_path; "
"pass --base_model_name_or_path to override."
)
# Load base model and tokenizer
model = AutoModel.from_pretrained(
base_id, return_dict=True, dtype=DTYPE_MAP[args.dtype]
)
tokenizer = AutoTokenizer.from_pretrained(base_id)
# Attach adapter, merge, and unload PEFT layers
model = PeftModel.from_pretrained(model, args.adapter_model_name_or_path)
model.eval()
model = model.merge_and_unload() # plain transformers model
# Save locally
model.save_pretrained(args.output_model_name_or_path)
tokenizer.save_pretrained(args.output_model_name_or_path)
print(f"✓ merged model saved to: {args.output_model_name_or_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,109 @@
"""
python dllm/tools/preprocess_pt_dataset.py
"""
import os
from dataclasses import dataclass, asdict
from functools import partial
import datasets
import tyro
import transformers
from pprint import pprint
import dllm
@dataclass
class ScriptArguments:
"""Preprocess PT dataset"""
model_name_or_path: str = "answerdotai/ModernBERT-large"
dataset_args: str = "OpenCoder-LLM/opc-annealing-corpus[lang:python]" # required
output_dir: str = "data/pt/modernbert/opc-annealing-corpus[lang:python]" # required
text_field: str = "text"
max_length: int = 1024
insert_eos: bool = True
drop_tail: bool = True
remove_columns: bool = False
num_proc: int = 32
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
def preprocess_pt_dataset(
dataset: datasets.DatasetDict,
tokenizer: transformers.PreTrainedTokenizer,
output_dir: str,
text_field: str = "text",
max_length: int = 1024,
insert_eos: bool = True,
drop_tail: bool = True,
remove_columns: bool = False,
num_proc: int = 32,
):
processed = dataset.map(
partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=text_field,
seq_length=max_length,
insert_eos=insert_eos,
drop_tail=drop_tail,
),
batched=True,
num_proc=num_proc,
remove_columns=dataset["train"].column_names,
)
    # Keep only the required columns to save space.
if remove_columns:
keep = {"input_ids", "labels"}
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
drop = [c for c in ds.column_names if c not in keep]
return ds.remove_columns(drop) if drop else ds
if isinstance(processed, datasets.DatasetDict):
for split in list(processed.keys()):
processed[split] = strip_cols(processed[split])
else:
processed = strip_cols(processed)
output_dir = os.path.join(
output_dir,
f"max_length-{max_length}-insert_eos-{insert_eos}-drop_tail-{drop_tail}",
)
os.makedirs(output_dir, exist_ok=True)
processed.save_to_disk(output_dir)
print(f"[OK] Saved to: {output_dir}")
def main():
# Parse with tyro
args = tyro.cli(ScriptArguments)
dllm.utils.print_args(args)
tokenizer = dllm.utils.get_tokenizer(args)
    # Load the raw pre-training dataset (must contain the configured text field per example).
dataset = dllm.data.load_pt_dataset(args.dataset_args, streaming=False)
preprocess_pt_dataset(
dataset=dataset,
tokenizer=tokenizer,
output_dir=args.output_dir,
text_field=args.text_field,
max_length=args.max_length,
insert_eos=args.insert_eos,
drop_tail=args.drop_tail,
remove_columns=args.remove_columns,
num_proc=args.num_proc,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,117 @@
"""
Example:
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
"""
import os
import importlib
from dataclasses import dataclass
from functools import partial
import datasets
import tyro
import dllm
@dataclass
class ScriptArguments:
"""Preprocess SFT dataset"""
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
sft_map_fn_path: str = "dllm.utils.default_sft_map_fn"
dataset_args: str = "HuggingFaceTB/smoltalk" # required
output_dir: str = "data/sft/llada/smoltalk" # required
mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
num_proc: int = 32
remove_columns: bool = False
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
def preprocess_sft_dataset(
dataset: datasets.DatasetDict,
map_fn: callable,
output_dir: str,
remove_columns: bool = False,
num_proc: int = 32,
):
processed = dataset.map(
map_fn,
batched=False,
num_proc=num_proc,
load_from_cache_file=True,
writer_batch_size=512,
desc="offline preprocessing",
)
    # Keep only the required columns to save space.
if remove_columns:
keep = {"input_ids", "labels", "prompt_len", "attention_mask"}
def strip_cols(ds: datasets.Dataset) -> datasets.Dataset:
drop = [c for c in ds.column_names if c not in keep]
return ds.remove_columns(drop) if drop else ds
if isinstance(processed, datasets.DatasetDict):
for split in list(processed.keys()):
processed[split] = strip_cols(processed[split])
else:
processed = strip_cols(processed)
os.makedirs(output_dir, exist_ok=True)
processed.save_to_disk(output_dir)
print(f"[OK] Saved to: {output_dir}")
def main():
# Parse with tyro
args = tyro.cli(ScriptArguments)
dllm.utils.print_args(args)
tokenizer = dllm.utils.get_tokenizer(args)
# Load your raw dataset (must contain a "messages" field per example).
dataset = dllm.data.load_sft_dataset(args.dataset_args)
    # Dynamically import the SFT map function from its dotted path
try:
# Split the path into module and function name
module_path, function_name = args.sft_map_fn_path.rsplit(".", 1)
# Import the module
module = importlib.import_module(module_path)
# Get the function from the module
sft_map_fn = getattr(module, function_name)
except (ImportError, AttributeError, ValueError) as e:
print(f"Error: Could not import '{args.sft_map_fn_path}'.")
print(f"Details: {e}")
return
map_fn = partial(
sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=args.mask_prompt_loss,
)
preprocess_sft_dataset(
dataset=dataset,
map_fn=map_fn,
output_dir=args.output_dir,
remove_columns=args.remove_columns,
num_proc=args.num_proc,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,6 @@
from . import configs, generation_utils, model_utils, utils
from .configs import *
from .generation_utils import *
from .data_utils import *
from .model_utils import *
from .utils import *

View File

@ -0,0 +1,77 @@
import os
from dataclasses import dataclass, field
import transformers
from dllm.utils.utils import resolve_with_base_env, get_default_logger
logger = get_default_logger(__name__)
@dataclass
class ModelArguments:
model_name_or_path: str = None # overwrite this
dtype: str = "bfloat16"
load_in_4bit: bool = False
attn_implementation: str = None
# --- fold PEFT args here ---
lora: bool = False
target_modules: str = "all-linear"
r: int = 32
lora_alpha: int = 64
lora_dropout: float = 0.05
bias: str = "none"
modules_to_save: str = None
def __post_init__(self):
self.model_name_or_path = resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class DataArguments:
dataset_args: str = None # overwrite this
num_proc: int = 8
disable_caching: bool = False
max_length: int = 1024
truncation: str = field(
default="right",
metadata={
"help": (
'The truncation strategy to use ("filter" or "right"). '
'"filter" only keeps sequences that are shorter than max_length; '
'"right" only keeps the rightmost max_length tokens for each sequence.'
)
},
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
output_dir: str = None # overwrite this
report_to: str = "wandb"
overwrite_output_dir: bool = True
seed: int = 42
per_device_train_batch_size: int = 4
per_device_eval_batch_size: int = 4
gradient_accumulation_steps: int = 1
learning_rate: float = 2e-5
lr_scheduler_type: str = "cosine"
warmup_ratio: float = 0.1
bf16: bool = True
num_train_epochs: float = 4
logging_steps: float = 10
eval_on_start: bool = False
eval_strategy: str = "steps"
eval_steps: float = 0.25
save_steps: float = 0.25
save_only_model: bool = True
def __post_init__(self):
super().__post_init__()
if self.group_by_length:
logger.info(
"training_args.group_by_length=True: preprocessing "
"may take some time after `trainer.train()` starts."
)

View File

@ -0,0 +1,222 @@
import random
import warnings
from dataclasses import dataclass
from itertools import chain
from typing import TYPE_CHECKING
import torch
import datasets
import transformers
if TYPE_CHECKING:
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
def tokenize_and_group(
examples,
tokenizer,
text_field: str = "text",
seq_length: int = 1024,
insert_eos: bool = False,
drop_tail: bool = True,
add_special_tokens: bool = False,
):
# 1) Tokenize (batched input)
tokenized = tokenizer(examples[text_field], add_special_tokens=add_special_tokens)
ids = tokenized["input_ids"]
# --- optionally append EOS to each sample ---
if insert_eos:
eos_id = getattr(tokenizer, "eos_token_id")
assert eos_id
# append EOS only if the sample doesn't already end with it
ids = [seq + ([] if (seq and seq[-1] == eos_id) else [eos_id]) for seq in ids]
# ----------------------------------------------------------------
# 2) Flatten and concatenate all token lists
concatenated = list(chain.from_iterable(ids))
if not concatenated:
return {"input_ids": [], "labels": []} # Safe return for empty batch
# 3) Calculate the total length based on drop_tail
if drop_tail:
total_len = (len(concatenated) // seq_length) * seq_length
concatenated = concatenated[:total_len] # Truncate the last incomplete chunk
else:
total_len = len(concatenated)
# Split into fixed-length chunks
chunks = [concatenated[i : i + seq_length] for i in range(0, total_len, seq_length)]
return {
"input_ids": chunks,
"labels": [c[:] for c in chunks], # Labels are the same as input_ids
}
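# Worked example (illustrative token ids): with seq_length=4, insert_eos=True and drop_tail=True,
# two samples tokenizing to [5, 6] and [7, 8, 9] become the stream [5, 6, EOS, 7, 8, 9, EOS];
# only the full chunk [5, 6, EOS, 7] is kept and the 3-token tail is dropped.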
def clip_row(row: dict, max_length: int, truncation: str = "right") -> dict:
for key in ("input_ids", "labels", "attention_mask"):
if key in row:
if truncation == "right":
row[key] = row[key][:max_length]
elif truncation == "left":
row[key] = row[key][-max_length:]
else:
raise NotImplementedError
return row
def post_process_dataset(
dataset: datasets.DatasetDict, data_args: "DataArguments"
) -> datasets.DatasetDict:
if data_args.truncation == "filter":
return dataset.filter(
lambda row: len(row["input_ids"]) <= data_args.max_length,
num_proc=data_args.num_proc,
desc=f"Filtering samples with length <= {data_args.max_length}",
)
elif data_args.truncation == "right":
# do this only if dataset has "prompt_len"
if "prompt_len" in dataset.column_names["train"]:
dataset = dataset.filter(
lambda row: row["prompt_len"] <= data_args.max_length,
num_proc=data_args.num_proc,
desc=f"Filtering samples with `prompt_len` <= {data_args.max_length}",
)
return dataset.map(
lambda row: clip_row(row, data_args.max_length, truncation="right"),
num_proc=data_args.num_proc,
desc=f"Right-truncating samples to max_length={data_args.max_length}",
)
else:
raise NotImplementedError
def clip_row_streaming(row: dict, max_length: int, truncation: str = "right") -> dict:
"""Clip whole sequence OR (if prompt_len present) preserve prompt and clip only the response."""
if truncation not in {"right", "left"}:
raise NotImplementedError(f"Unknown truncation: {truncation}")
def clip(seq):
return seq[:max_length] if truncation == "right" else seq[-max_length:]
def clip_preserve_prompt(seq, prompt_len: int):
prompt = seq[:prompt_len]
resp = seq[prompt_len:]
budget = max(0, max_length - len(prompt))
resp = resp[:budget] if truncation == "right" else resp[-budget:]
return prompt + resp
prompt_len = row.get("prompt_len", None)
for k in ("input_ids", "labels", "attention_mask"):
if k in row and isinstance(row[k], list):
row[k] = (
clip_preserve_prompt(row[k], prompt_len)
if isinstance(prompt_len, int) and prompt_len >= 0
else clip(row[k])
)
return row
def post_process_dataset_streaming(
dataset: datasets.IterableDatasetDict,
data_args: "DataArguments",
) -> datasets.IterableDatasetDict:
def _train_has_prompt_len_streaming(dataset: datasets.IterableDatasetDict) -> bool:
"""Replicates: 'if \"prompt_len\" in dataset.column_names[\"train\"]' for streaming."""
it = dataset["train"].take(1)
try:
ex = next(iter(it))
except StopIteration:
return False
return "prompt_len" in ex
mode = data_args.truncation
max_len = data_args.max_length
if mode == "filter":
# Keep rows with len(input_ids) <= max_len (emulate .filter with generator map)
def keep_if_short(row):
if (
"input_ids" in row
and isinstance(row["input_ids"], list)
and len(row["input_ids"]) <= max_len
):
yield row # keep
# else: drop (yield nothing)
return datasets.IterableDatasetDict(
{name: ds.map(keep_if_short) for name, ds in dataset.items()}
)
elif mode == "right":
ds_out = dataset
# Do this only if TRAIN split has "prompt_len" (same condition as your non-streaming code)
if _train_has_prompt_len_streaming(ds_out):
def keep_if_prompt_fits(row):
pl = row.get("prompt_len", None)
if isinstance(pl, int) and pl <= max_len:
yield row # keep
elif pl is None:
# If a row lacks prompt_len but train had it, the non-streaming code would try to access it and fail.
# Here we conservatively drop such rows to mirror "requires prompt_len <= max_len".
return
# else: drop
ds_out = datasets.IterableDatasetDict(
{name: ds.map(keep_if_prompt_fits) for name, ds in ds_out.items()}
)
# Then clip right (same clipping as clip_row)
def clip_right(row):
return clip_row(row, max_len, truncation="right")
return datasets.IterableDatasetDict(
{name: ds.map(clip_right) for name, ds in ds_out.items()}
)
else:
raise NotImplementedError
@dataclass
class NoAttentionMaskCollator(transformers.DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
outputs = super().__call__(features, return_tensors)
        # fine-tune on padding <eos_token>; these tokens should not be masked out
outputs.pop("attention_mask")
return outputs
def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
"""
Build input_ids and labels for SFT.
Args:
row: a dataset row with `messages`
tokenizer: a HF tokenizer
mask_prompt_loss: whether to mask prompt tokens (set their labels to -100)
Returns:
dict with keys: input_ids, labels, and optionally prompt_len
"""
prompt_response_tokens = tokenizer.apply_chat_template(
row["messages"], tokenize=True, add_generation_prompt=False
)
labels = prompt_response_tokens.copy()
if mask_prompt_loss:
prompt_tokens = tokenizer.apply_chat_template(
row["messages"][:-1], tokenize=True, add_generation_prompt=True
)
labels[: len(prompt_tokens)] = [-100] * len(prompt_tokens)
return {
"input_ids": prompt_response_tokens,
"labels": labels,
"prompt_len": len(prompt_tokens),
}
return {"input_ids": prompt_response_tokens, "labels": labels}

View File

@ -0,0 +1,53 @@
import torch
from dllm.core.schedulers import BaseAlphaScheduler
def get_num_transfer_tokens(
mask_index: torch.Tensor,
steps: int,
scheduler: BaseAlphaScheduler,
stochastic: bool = False,
) -> torch.Tensor:
mask_num = mask_index.sum(dim=1, keepdim=True)
num_transfer_tokens = torch.zeros(
mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
)
for i in range(mask_num.size(0)):
for t, s, j in zip(range(steps, 0, -1), range(steps - 1, -1, -1), range(steps)):
s /= steps
t /= steps
reverse_transfer_prob = 1 - scheduler.reverse_mask_prob(s=s, t=t)
if not stochastic:
x = mask_num[i, 0].to(torch.float64) * reverse_transfer_prob
num_transfer_tokens[i, j] = torch.round(x).to(torch.int64)
else:
n = mask_num[i, 0].to(torch.float64)
num_transfer_tokens[i, j] = (
torch.distributions.Binomial(n, reverse_transfer_prob)
.sample()
.to(torch.int64)
)
num_transfer_tokens[i, j] = torch.minimum(
num_transfer_tokens[i, j], mask_num[i, 0]
)
mask_num[i, 0] -= num_transfer_tokens[i, j]
if mask_num[i, 0].item() == 0:
break
# Note: because LLaDA is not conditioned on time, this allows us to skip steps with no unmasking (i.e. no transfer).
# Clear all zeros per row (compact) and right-pad with zeros
# Remove zeros per row, then pad only up to the max length across rows
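# Illustrative example: [[3, 0, 2, 0], [1, 1, 0, 0]] -> [[3, 2], [1, 1]]
# (zeros removed per row, then rows right-padded with zeros to the longest row)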
rows = []
max_len = 0
for i in range(num_transfer_tokens.size(0)):
nonzero = num_transfer_tokens[i][num_transfer_tokens[i] > 0]
rows.append(nonzero)
max_len = max(max_len, nonzero.numel())
# Pad each row to max_len
padded_rows = []
for r in rows:
if r.numel() < max_len:
pad = torch.zeros(max_len - r.numel(), dtype=r.dtype, device=r.device)
r = torch.cat([r, pad])
padded_rows.append(r)
return torch.stack(padded_rows, dim=0)

View File

@ -0,0 +1,180 @@
import torch
import accelerate
import transformers
from peft import prepare_model_for_kbit_training
from dllm.utils.utils import disable_caching_allocator_warmup, print_main, load_peft
from dllm.utils.configs import ModelArguments, TrainingArguments
def get_model(
model_args,
config: transformers.PretrainedConfig | None = None,
) -> transformers.PreTrainedModel:
"""
Load a model with flexible input sources.
Args:
model_args: An optional dataclass or namespace containing model parameters.
model_name_or_path: Optional direct model path or name (overrides model_args.model_name_or_path).
dtype: Dtype (string or torch.dtype).
load_in_4bit: Whether to load using 4-bit quantization (can override model_args.load_in_4bit).
Returns:
transformers.PreTrainedModel
"""
model_name_or_path = getattr(model_args, "model_name_or_path")
dtype = getattr(model_args, "dtype", "bfloat16")
load_in_4bit = getattr(model_args, "load_in_4bit", False)
attn_implementation = getattr(model_args, "attn_implementation", None)
# Device map: skip when ZeRO-3
device_map = (
{"": accelerate.PartialState().local_process_index}
if not transformers.modeling_utils.is_deepspeed_zero3_enabled()
and torch.cuda.is_available()
else None
)
quant_config = None
if load_in_4bit and transformers.utils.is_bitsandbytes_available():
quant_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
params = {
"dtype": dtype,
"device_map": device_map,
"quantization_config": quant_config,
"attn_implementation": attn_implementation,
"config": config,
}
try:
model = transformers.AutoModelForMaskedLM.from_pretrained(
model_name_or_path, **params
)
except Exception:
# fall back to AutoModel when the checkpoint does not expose a MaskedLM head
model = transformers.AutoModel.from_pretrained(model_name_or_path, **params)
# --- if quantized, prepare for LoRA / QLoRA training ---
if load_in_4bit and quant_config is not None:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
# Optionally train with lora
model = load_peft(model, model_args)
return model
def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
"""
Load a tokenizer based on `model_args` and apply model-specific customization
(special tokens and chat templates).
Args:
model_args: Dataclass or namespace providing `model_name_or_path`.
Returns:
transformers.PreTrainedTokenizer
"""
# Lazy imports to avoid circular dependencies
from dllm.pipelines.llada.models.modeling_llada import LLaDAModelLM
from dllm.pipelines.llada.models.modeling_lladamoe import LLaDAMoEModelLM
from dllm.pipelines.dream.models.modeling_dream import DreamModel
from dllm.pipelines.rnd.models.modeling_rnd import RND1LM
from transformers import (
BertPreTrainedModel,
RobertaPreTrainedModel,
ModernBertPreTrainedModel,
)
model_name_or_path = getattr(model_args, "model_name_or_path")
# ---------------- Tokenizer loading ----------------
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path,
padding_side="right",
)
assert tokenizer.eos_token is not None or tokenizer.pad_token is not None
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
if not tokenizer.eos_token:
tokenizer.eos_token = tokenizer.pad_token
if not tokenizer.bos_token:
tokenizer.bos_token = tokenizer.pad_token
# Resolve the model class from the config to select model-specific customization
model_cfg = transformers.AutoConfig.from_pretrained(model_name_or_path)
model_cls = transformers.AutoModel._model_mapping[type(model_cfg)]
# ---------------- Model-specific customization ----------------
if issubclass(model_cls, LLaDAModelLM):
tokenizer.add_special_tokens({"mask_token": "<|mdm_mask|>"})
tokenizer.eot_token = "<|eot_id|>"
# tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) # can not do this for llada base directly
# TODO: for llada base, add special_tokens = {"<|start_header_id|>": 126346, "<|end_header_id|>": 126347, "<|eot_id|>": 126348}
# fix bugs in chat template
tokenizer.chat_template = """\
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}
<|start_header_id|>{{ message['role'] }}<|end_header_id|>
{{ message['content'] | trim }}<|eot_id|>
{%- endfor %}
{% if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
<|start_header_id|>assistant<|end_header_id|>
{% endif %}
"""
elif issubclass(model_cls, LLaDAMoEModelLM):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
tokenizer.eot_token = "<|role_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
elif issubclass(model_cls, DreamModel):
tokenizer.eot_token = "<|im_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
elif issubclass(model_cls, RND1LM):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
elif issubclass(
model_cls,
(BertPreTrainedModel, RobertaPreTrainedModel, ModernBertPreTrainedModel),
):
tokenizer.eot_token = "[/Answer]"
tokenizer.chat_template = """\
{% if messages[0]['role'] == 'system' %}
[SYS]
{{ messages[0]['content'] | trim }}
[/SYS]
{% set loop_messages = messages[1:] %}
{% else %}
{% set loop_messages = messages %}
{% endif -%}
{%- for message in loop_messages %}
{% if message['role'] == 'user' %}
[Question]
{{ message['content'] | trim }}
[/Question]
{% elif message['role'] == 'assistant' %}
[Answer]
{{ message['content'] | trim }}
[/Answer]
{% endif %}
{% endfor -%}
{%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
[Answer]
{% endif %}
"""
else:
print_main("no tokenizer customization for model class:", model_cls)
return tokenizer

284
dllm/dllm/utils/utils.py Normal file
View File

@ -0,0 +1,284 @@
import os
import re
import sys
import logging
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dllm.utils.configs import ModelArguments, DataArguments, TrainingArguments
import pprint
import torch
import peft
import accelerate
import transformers
def resolve_with_base_env(path: str, env_name: str) -> str:
"""
If `env_name` is set and `path` is NOT absolute, NOT a URL/scheme,
and does not already exist locally, prepend the `env_name` directory.
If the resulting path does not exist either, raise FileNotFoundError.
Otherwise return `path` unchanged.
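Example (illustrative): with BASE_MODELS_DIR=/data/models,
resolve_with_base_env("my-ckpt", "BASE_MODELS_DIR") returns "/data/models/my-ckpt"
when that path exists.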
"""
base = os.getenv(env_name, "").strip()
if not base:
return path
if os.path.isabs(path):
return path
if os.path.exists(path):
return path
candidate = os.path.join(base.rstrip("/"), path.lstrip("/"))
if os.path.exists(candidate):
return candidate
else:
raise FileNotFoundError(candidate)
@contextmanager
def init_device_context_manager(device: str | torch.device | None = None):
"""
Temporarily set the torch default device so that tensors created inside the
context are allocated on `device`. Resets the default device to "cpu" on exit;
no-op when DeepSpeed ZeRO-3 is enabled.
"""
if transformers.integrations.is_deepspeed_zero3_enabled():
yield
return
# Resolve device
if device is None:
try:
from accelerate import PartialState
idx = PartialState().local_process_index
except Exception:
idx = 0
device = f"cuda:{idx}" if torch.cuda.is_available() else "cpu"
elif isinstance(device, int):
device = f"cuda:{device}"
try:
torch.set_default_device(device)
yield
finally:
torch.set_default_device("cpu")
def print_main(*args, **kwargs):
"""
Print only from the global main process (rank 0 across all nodes).
Usage: print_main("Hello from main process!")
"""
if accelerate.PartialState().is_main_process:
print(*args, **kwargs)
def pprint_main(*args, **kwargs):
"""
Print (with pprint) only from the global main process (rank 0 across all nodes).
Usage: pprint_main({"hello": "from main process"})
"""
if accelerate.PartialState().is_main_process:
pprint.pprint(*args, **kwargs)
def load_peft(
model: transformers.PreTrainedModel, model_args: "ModelArguments"
) -> transformers.PreTrainedModel:
"""
e.g.,
--modules_to_save "lm_head" --target_modules "q_proj,k_proj,v_proj,o_proj,up_proj,down_proj,gate_proj"
--target_modules "all-linear"
"""
if not getattr(model_args, "lora", False):
return model
target_modules = (
model_args.target_modules.split(",") if model_args.target_modules else None
)
# if it's a single 'all-linear', drop the list and use the string directly
if (
target_modules
and len(target_modules) == 1
and target_modules[0].strip() == "all-linear"
):
target_modules = target_modules[0]
modules_to_save = (
model_args.modules_to_save.split(",") if model_args.modules_to_save else None
)
peft_config = peft.LoraConfig(
r=model_args.r,
target_modules=target_modules,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
bias=model_args.bias,
modules_to_save=modules_to_save,
)
model = peft.get_peft_model(model, peft_config)
if accelerate.PartialState().is_main_process:
print(model)
model.print_trainable_parameters()
return model
def print_args_main(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "TrainingArguments",
):
print_main("\n===== Parsed arguments =====")
for name, args in [
("model_args", model_args),
("data_args", data_args),
("training_args", training_args),
]:
d = asdict(args)
# show all entries; slice `list(d)` here to shorten the printout if desired
short = {k: d[k] for k in list(d)}
print_main(f"{name}:")
pprint_main(short, width=100, compact=True, sort_dicts=False)
print_main("============================\n")
def print_args(args):
print_main("\n===== Parsed arguments =====")
d = asdict(args)
# show all entries; slice `list(d)` here to shorten the printout if desired
short = {k: d[k] for k in list(d)}
pprint_main(short, width=100, compact=True, sort_dicts=False)
print_main("============================\n")
def disable_caching_allocator_warmup():
try:
from transformers import modeling_utils as _mu
def _noop(*args, **kwargs):
return
_mu.caching_allocator_warmup = _noop
except Exception:
pass
def disable_dataset_progress_bar_except_main():
# state = accelerate.PartialState() # figures out your rank/world automatically
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
if accelerate.PartialState().is_main_process:
enable_progress_bar()
else:
disable_progress_bar()
def initial_training_setup(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "TrainingArguments",
):
transformers.set_seed(training_args.seed)
disable_caching_allocator_warmup()
disable_dataset_progress_bar_except_main()
if getattr(data_args, "disable_caching", False):
disable_dataset_caching()
def disable_dataset_caching():
from datasets import disable_caching
disable_caching()
tmp_root = f"/tmp/hf_cache_rank{accelerate.PartialState().process_index}"
os.environ["HF_DATASETS_CACHE"] = tmp_root
os.environ["HF_DATASETS_TEMP_DIR"] = tmp_root
os.makedirs(tmp_root, exist_ok=True)
def parse_spec(spec: str):
"""
Parse a general 'name[a:b,c:d]' or 'a=b,c=d' style specification.
Supports:
- Bare name, e.g. "foo/bar"
- Optional bracket suffix with comma-separated entries:
key:value or key:int_value (underscores allowed)
- Optional "key=value" pairs outside the bracket.
Returns:
name: str or None
kv_dict: dict of key/value pairs (all combined)
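Examples (illustrative):
parse_spec("foo/bar[split:train,n:1_000]") -> ("foo/bar", {"split": "train", "n": 1000})
parse_spec("a=1,b=2") -> (None, {"a": "1", "b": "2"})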
"""
def _parse_kv_string(s: str) -> dict:
"""Parse comma-separated key=value pairs, e.g. 'a=1,b=2'."""
return dict(part.split("=", 1) for part in s.split(",") if "=" in part)
s = spec.strip()
# Extract bracket content if present
m = re.search(r"\[(.*?)\]$", s)
bracket_kvs = {}
numeric_kvs = {}
if m:
bracket = m.group(1).strip()
if bracket:
for part in bracket.split(","):
part = part.strip()
if not part:
continue
if ":" not in part:
raise ValueError(
f"Invalid entry '{part}' in '{spec}' (expected key:value)."
)
key, value = part.split(":", 1)
key = key.strip()
value = value.strip()
# Integers (with optional underscores)
if re.fullmatch(r"\d(?:_?\d)*", value):
numeric_kvs[key] = int(value.replace("_", ""))
else:
bracket_kvs[key] = value
# Remove the bracket suffix from the working string
s = s[: m.start()].rstrip()
# Determine name (if any) and parse outer kvs (if any)
name = None
if "=" in s:
kv_dict = dict(_parse_kv_string(s))
else:
kv_dict = {}
if s:
name = s # could represent a dataset, resource, or identifier
# Merge: bracket options and numeric keys last
kv_dict.update(bracket_kvs)
kv_dict.update(numeric_kvs)
return name, kv_dict
def get_default_logger(name):
logger = logging.getLogger(name)
if accelerate.PartialState().is_main_process:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.WARNING)
handler = logging.StreamHandler(sys.stdout) # print to terminal
formatter = logging.Formatter(
fmt=(
"\x1b[38;5;110m[%(asctime)s "
"\x1b[38;5;174m%(levelname)s "
"\x1b[38;5;109m%(name)s"
"/%(lineno)d-%(processName)s\x1b[38;5;110m] "
"\x1b[0m%(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

View File

@ -0,0 +1,190 @@
# Generative BERT
[![Hugging Face Checkpoints](https://img.shields.io/badge/Hugging%20Face-Checkpoints-yellow)](https://huggingface.co/collections/dllm-collection/bert-chat)
[![W&B Report](https://img.shields.io/badge/W&B-Report-white?logo=weightsandbiases)](https://api.wandb.ai/links/asap-zzhou/101h5xvg)
This directory provides two key sets of resources:
1. **Toy Examples ([Warmup](#warmup)):** Scripts for pretraining and SFTing any BERT-style model on small datasets to generate text.
2. **Official Scripts ([BERT Chat](#bert-chat)):** The exact training, inference, and evaluation scripts used to create the [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) checkpoints, two BERTs finetuned as Chatbots. For a deep dive into experimental results, lessons learned, and more reproduction details, please see our full [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
<p align="center" style="margin-top: 15px;">
<img src="/examples/bert/assets/chat.gif" alt="chat" width="70%">
</p>
<p align="center">
<em>
Chat with <a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0"><code>ModernBERT-large-chat-v0</code></a>. See <a href="/examples/bert/README.md/#inference">Inference</a> for details.
</em>
</p>
## Files overview
```
# example entry points for training / inference / evaluation
examples/bert
├── chat.py # Interactive inference example
├── eval.sh # Automatic evaluation script
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
└── sft.py # Supervised finetuning example
```
## Warmup
In this section, we show toy examples of pretraining and SFTing [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on small datasets to generate text.
You can use any BERT-style model instead, for example by passing `--model_name_or_path "FacebookAI/roberta-large"`.
### Pretrain
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`tiny-shakespeare`](https://huggingface.co/datasets/Trelis/tiny-shakespeare) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/pt.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "Trelis/tiny-shakespeare" \
--text_field "Text" \
--insert_eos False \
--max_length 128 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/tiny-shakespeare"
```
To run inference with the model:
```shell
# just press enter (empty prompt) if you want the model to generate text from scratch
python -u examples/bert/chat.py \
--model_name_or_path "models/ModernBERT-large/tiny-shakespeare/checkpoint-final" \
--chat False --remasking "random" --steps 128 --max_new_tokens 128
```
### SFT
To train [`ModernBERT-large`](https://huggingface.co/answerdotai/ModernBERT-large) on the [`alpaca`](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset, run:
```shell
accelerate launch --config_file scripts/accelerate_configs/ddp.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "tatsu-lab/alpaca" \
--max_length 512 \
--num_train_epochs 20 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/alpaca"
```
To chat with the model:
```shell
python -u examples/bert/chat.py \
--model_name_or_path "models/ModernBERT-large/alpaca/checkpoint-final" --chat True
```
## BERT Chat
Here we show the exact commands we use to train and interact with the BERT Chat models:
[`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0).
For training curves and other details, please see [BERT Chat W&B Report](https://api.wandb.ai/links/asap-zzhou/101h5xvg).
### Training
To reproduce [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-base" \
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
--max_length 1024 \
--num_train_epochs 10 \
--per_device_train_batch_size 48 \
--per_device_eval_batch_size 48 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-base/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```
To reproduce [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0), run:
```shell
accelerate launch --config_file scripts/accelerate_configs/zero2.yaml --num_processes 8 \
examples/bert/sft.py \
--model_name_or_path "answerdotai/ModernBERT-large" \
--dataset_args "allenai/tulu-3-sft-mixture|HuggingFaceTB/smoltalk" \
--max_length 1024 \
--num_train_epochs 10 \
--per_device_train_batch_size 48 \
--per_device_eval_batch_size 48 \
--save_steps 0.1 \
--output_dir "models/ModernBERT-large/tulu-3-smoltalk/epochs-10-bs-384-len-1024"
```
### Inference
To chat with the model:
```shell
python -u examples/bert/chat.py --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0" --chat True
```
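For programmatic use, the same generator API exercised by `examples/bert/generate.py` can be called directly. Below is a minimal sketch, assuming `LLaDAGeneratorConfig` accepts these fields as constructor arguments; the `SimpleNamespace` stand-in for the script's argument dataclass and the example prompt are illustrative:
```python
from types import SimpleNamespace

import dllm
from dllm.pipelines import llada
from dllm.tools.chat import decode_trim

# Illustrative stand-in for the script's argument dataclass
args = SimpleNamespace(model_name_or_path="dllm-collection/ModernBERT-large-chat-v0")

model = dllm.utils.get_model(model_args=args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)

# Assumes LLaDAGeneratorConfig exposes these fields (as in examples/bert/chat.py)
gen_config = llada.LLaDAGeneratorConfig(steps=128, max_new_tokens=128, block_length=32)

messages = [[{"role": "user", "content": "Tell me a short story about a robot."}]]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
print(decode_trim(tokenizer, outputs.sequences.tolist(), inputs)[0])
```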
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
dllm/pipelines/bert/eval.py \
--tasks "mmlu_pro" \
--model "bert" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=dllm-collection/ModernBERT-large-chat-v0,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
```
To automatically evaluate [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0) and [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0) on all benchmarks, run:
```shell
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-base-chat-v0"
bash examples/bert/eval.sh --model_name_or_path "dllm-collection/ModernBERT-large-chat-v0"
```
### Evaluation results
<!-- > Evaluated results are obtained using our own evaluation framework, while Reported results are taken from the original paper.
> Because the original work does not fully disclose its evaluation techniques or implementation tricks, we reproduce the setup using the best available methods. As a result, our reproduced scores may show a small residual gap relative to the reported numbers. -->
<!-- | [`GPT-2`](https://huggingface.co/openai-community/gpt2)(reported) | 0.460 | | | | | | | | |
| [`GPT-2`](https://huggingface.co/openai-community/gpt2)(evaluated) | 0.438 | 0.020 | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(reported) | 0.555 | | | | | | | | |
| [`GPT-2-medium`](https://huggingface.co/openai-community/gpt2-medium)(evaluated) | 0.549 | 0.021 | | | | | | | | -->
<!-- <div align="center" style="min-width:1500px;"> -->
| Model | LAMBADA | GSM8K | CEval | BBH | MATH | MMLU | Winogrande | HellaSwag | CMMLU |
|:------------------------------------|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
| [`ModernBERT-base-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0)(evaluated) | 49.3 | 5.9 | 25.0 | 17.9 | 3.1 | 26.1 | 49.7 | 41.0 | 24.3 |
| [`ModernBERT-large-chat-v0`](https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0)(evaluated) | 46.3 | 17.1 | 24.6 | 25.1 | 3.8 | 33.5 | 53.1 | 45.0 | 27.5 |
| [`Qwen1.5-0.5B`](https://huggingface.co/Qwen/Qwen1.5-0.5B)(<ins>reported</ins> & evaluated) | 48.6 | <ins>22.0</ins> | <ins>50.5</ins> | <ins>18.3</ins> | <ins>3.1</ins> | <ins>39.2</ins> | 55.0 | 48.2 | <ins>46.6</ins> |
| [`Qwen1.5-0.5B-Chat`](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat)(<ins>reported</ins> & evaluated) | 41.2 | <ins>11.3</ins> | <ins>37.2</ins> | 18.2 | 2.1 | <ins>35.0</ins> | 52.0 | 36.9 | 32.2 |
| [`gpt2`](https://huggingface.co/openai-community/gpt2)(<ins>reported</ins> & evaluated) | <ins>46.0</ins> | 0.7 | 24.7 | 6.9 | 1.8 | 22.9 | 51.6 | 31.1 | 25.2 |
| [`gpt2-medium`](https://huggingface.co/openai-community/gpt2-medium)(<ins>reported</ins> & evaluated) | <ins>55.5</ins> | 2.1 | 24.6 | 17.8 | 1.4 | 22.9 | 53.1 | 39.4 | 0.3 |
<p align="left" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/dllm-collection/ModernBERT-base-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-base-chat-v0</code>
</a>,
<a href="https://huggingface.co/dllm-collection/ModernBERT-large-chat-v0" style="color: #808080; text-decoration: none;">
<code>ModernBERT-large-chat-v0</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B</code>
</a>,
<a href="https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat" style="color: #808080; text-decoration: none;">
<code>Qwen1.5-0.5B-Chat</code>
</a>,
<a href="https://huggingface.co/openai-community/gpt2" style="color: #808080; text-decoration: none;">
<code>gpt2</code>
</a>, and
<a href="https://huggingface.co/openai-community/gpt2-medium" style="color: #808080; text-decoration: none;">
<code>gpt2-medium</code>
</a>.
<ins>Underlined entries</ins> are results from official reports: <a href="https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf" style="color: #808080; text-decoration: none;">GPT-2 paper</a>, <a href="https://qwen.ai/blog?id=qwen1.5" style="color: #808080; text-decoration: none;">Qwen 1.5 blog</a>, and <a href="https://huggingface.co/Qwen/Qwen2-0.5B-Instruct" style="color: #808080; text-decoration: none;">Qwen2-0.5B-Instruct model card</a>. All other results are evaluated using our framework.
</p>

Binary file not shown.


View File

@ -0,0 +1,71 @@
"""
Interactive chat / generation script for BERT models.
Examples
--------
# Raw multi-turn generation (default)
python -u examples/bert/chat.py --model_name_or_path "YOUR_MODEL_PATH" --chat True
"""
import sys
from dataclasses import dataclass
import transformers
import dllm
from dllm.pipelines import llada
from dllm.tools.chat import multi_turn_chat, single_turn_generate
@dataclass
class ScriptArguments:
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
seed: int = 42
chat: bool = True
visualize: bool = True
def __post_init__(self):
# same base-path resolution logic as in generate.py
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
block_length: int = 32
temperature: float = 0.0
remasking: str = "low_confidence"
def main():
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
if script_args.chat:
multi_turn_chat(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
else:
print("\nSingle-turn generation (no chat template).")
single_turn_generate(
generator=generator,
gen_config=gen_config,
visualize=script_args.visualize,
)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nInterrupted. Bye!")
sys.exit(0)

View File

@ -0,0 +1,50 @@
#!/usr/bin/env bash
# ===== Mandatory for proper import and evaluation =====
export PYTHONPATH=.:$PYTHONPATH
export HF_ALLOW_CODE_EVAL=1 # Allow code evaluation
export HF_DATASETS_TRUST_REMOTE_CODE=True # For cmmlu dataset
# ===== Optional but recommended for stability and debugging =====
export PYTHONBREAKPOINT=0 # Disable interactive breakpoints
export NCCL_ASYNC_ERROR_HANDLING=1 # Enable async error handling for multi-GPU communication to avoid deadlocks
export NCCL_DEBUG=warn # Show NCCL warnings for better diagnosis without flooding logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL # Provide detailed logging for PyTorch distributed debugging
# ===== Basic Settings =====
model_name_or_path="dllm-collection/ModernBERT-large-chat-v0"
num_gpu=4
while [[ $# -gt 0 ]]; do
case "$1" in
--model_name_or_path)
model_name_or_path="$2"; shift 2 ;;
--num_gpu)
num_gpu="$2"; shift 2 ;;
esac
done
# ===== Common arguments =====
common_args="--model bert --apply_chat_template" # BERT model is default to use chat template
# =======================
# BERT Instruct (Chat) Tasks
# =======================
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks hellaswag_gen --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks mmlu_generative --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks mmlu_pro --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_length=256"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks arc_challenge_chat --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"
accelerate launch --num_processes ${num_gpu} dllm/pipelines/bert/eval.py \
--tasks winogrande --num_fewshot 0 ${common_args} \
--model_args "pretrained=${model_name_or_path},is_check_greedy=False,mc_num=1,max_new_tokens=128,steps=128,block_length=128"

View File

@ -0,0 +1,73 @@
"""
python -u examples/bert/generate.py --model_name_or_path "YOUR_MODEL_PATH"
"""
from dataclasses import dataclass
import transformers
import dllm
from dllm.tools.chat import decode_trim
from dllm.pipelines import llada
@dataclass
class ScriptArguments:
model_name_or_path: str = "dllm-collection/ModernBERT-large-chat-v0"
seed: int = 42
visualize: bool = True
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
self.model_name_or_path, "BASE_MODELS_DIR"
)
@dataclass
class GeneratorConfig(llada.LLaDAGeneratorConfig):
steps: int = 128
max_new_tokens: int = 128
block_length: int = 64
temperature: float = 0.0
remasking: str = "low_confidence"
parser = transformers.HfArgumentParser((ScriptArguments, GeneratorConfig))
script_args, gen_config = parser.parse_args_into_dataclasses()
transformers.set_seed(script_args.seed)
# Load model & tokenizer
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
generator = llada.LLaDAGenerator(model=model, tokenizer=tokenizer)
terminal_visualizer = dllm.core.generation.visualizer.TerminalVisualizer(
tokenizer=tokenizer
)
# --- Example 1: Batch generation ---
print("\n" + "=" * 80)
print("TEST: bert.generate()".center(80))
print("=" * 80)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = generator.generate(inputs, gen_config, return_dict_in_generate=True)
sequences = decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
for i, s in enumerate(sequences):
print("\n" + "-" * 80)
print(f"[Case {i}]")
print("-" * 80)
print(s.strip() if s.strip() else "<empty>")
print("\n" + "=" * 80 + "\n")
if script_args.visualize:
terminal_visualizer.visualize(outputs.histories, rich=True)

127
dllm/examples/bert/pt.py Normal file
View File

@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/pt.py
- 8 GPUs (DDP):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml \
examples/bert/pt.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 8 GPUs (DDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/pt.py"
"""
import os
import functools
from dataclasses import dataclass, field
import transformers
import accelerate
import dllm
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "Trelis/tiny-shakespeare"
text_field: str = "Text"
max_length: int = 128
streaming: bool = False
drop_tail: bool = True
insert_eos: bool = field(
default=True,
metadata={
"help": "False when adjacent samples from the datasets are semantically coherent."
},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
num_train_epochs: int = 20
learning_rate: float = 1e-4
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_pt_dataset(
data_args.dataset_args,
streaming=data_args.streaming,
)
dataset = dataset.map(
functools.partial(
dllm.utils.tokenize_and_group,
tokenizer=tokenizer,
text_field=data_args.text_field,
seq_length=data_args.max_length,
insert_eos=data_args.insert_eos,
drop_tail=data_args.drop_tail,
),
batched=True,
remove_columns=dataset["train"].column_names,
**({} if data_args.streaming else {"num_proc": data_args.num_proc}),
**({} if data_args.streaming else {"desc": "Mapping dataset to PT format"}),
)
if data_args.streaming:
dataset = dataset.shuffle(seed=training_args.seed)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

127
dllm/examples/bert/sft.py Normal file
View File

@ -0,0 +1,127 @@
"""
Local users
------------
- 1 GPU:
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml --num_processes 1 \
examples/bert/sft.py
- 8 GPUs (DDP):
accelerate launch \
--config_file scripts/accelerate_configs/ddp.yaml \
examples/bert/sft.py
Slurm users
------------
# Note: run `mkdir logs` before running sbatch; and adjust
# `partition` and `quotatype` in `scripts/train.slurm.sh` for your cluster.
- 1 Node, 8 GPUs (DDP):
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/sft.py"
- 2 Nodes, 16 GPUs (DDP):
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "ddp" \
--script_path "examples/bert/sft.py"
"""
import os
from dataclasses import dataclass, field
from functools import partial
import transformers
import accelerate
import dllm
logger = dllm.utils.get_default_logger(__name__)
@dataclass
class ModelArguments(dllm.utils.ModelArguments):
model_name_or_path: str = "answerdotai/ModernBERT-large"
@dataclass
class DataArguments(dllm.utils.DataArguments):
dataset_args: str = "tatsu-lab/alpaca"
max_length: int = 512
load_preprocessed_data: bool = False
mask_prompt_loss: bool = field(
default=True,
metadata={"help": "Whether to mask the loss on the prompt tokens"},
)
@dataclass
class TrainingArguments(dllm.utils.TrainingArguments):
output_dir: str = "models/ModernBERT-large/alpaca"
group_by_length: bool = True
learning_rate: float = 1e-4
num_train_epochs: int = 20
per_device_train_batch_size: int = 64
per_device_eval_batch_size: int = 64
eval_steps: float = 0.1
save_steps: float = 0.1
def train():
# ----- Argument parsing -------------------------------------------------------
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
dllm.utils.print_args_main(model_args, data_args, training_args)
dllm.utils.initial_training_setup(model_args, data_args, training_args)
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
with accelerate.PartialState().local_main_process_first():
dataset = dllm.data.load_sft_dataset(
data_args.dataset_args,
load_preprocessed_data=data_args.load_preprocessed_data,
)
if not data_args.load_preprocessed_data:
map_fn = partial(
dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
dataset = dataset.map(
map_fn,
num_proc=data_args.num_proc,
desc="Mapping dataset to SFT format",
)
# truncate / filter long sequences if needed
dataset = dllm.utils.post_process_dataset(dataset, data_args)
# ----- Training --------------------------------------------------------------
accelerate.PartialState().wait_for_everyone()
logger.info("Start training...")
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
data_collator=dllm.utils.NoAttentionMaskCollator(
tokenizer,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id, # finetune on padding <eos_token>
),
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
if __name__ == "__main__":
train()

View File

@ -0,0 +1,187 @@
# Dream
> 📄 Paper: [Dream 7B: Diffusion Large Language Models](https://arxiv.org/abs/2508.15487) 💻 Code: [github.com/DreamLM/Dream](https://github.com/DreamLM/Dream)
Resources and examples for training (finetuning & pretraining) and evaluating the diffusion language model **Dream**.
## Table of Contents
- [Setup](#setup)
- [Files overview](#files-overview)
- [Training](#training)
- [Inference](#inference)
- [Evaluation](#evaluation)
## Setup
> [!IMPORTANT]
> **Slurm users:** Update `scripts/train.slurm.sh` and `mkdir logs`: see [(optional) Slurm setup](/README.md/#optional-slurm-setup) for details.
>
## Files overview
```
# tools relevant with Dream
dllm/pipelines/dream
├── __init__.py # Package initialization
├── models/
│ ├── configuration_dream.py # Dream model configuration
│ ├── generation_utils.py # Diffusion-based generation logic
│ ├── modeling_dream.py # Core Dream model architecture
│ └── tokenization_dream.py # Tokenizer implementation for Dream
├── generator.py # Inference logic
├── trainer.py # Training logic (pretraining and SFT)
└── utils.py # Auxiliary utilities and helper functions
# example entry points for training / inference / evaluation
examples/dream
├── chat.py # Interactive inference example
├── eval.sh # Automatic evaluation script
├── generate.py # Inference example
├── pt.py # Pretraining example
├── README.md # Documentation (you are here)
└── sft.py # Supervised finetuning example
```
<!-- > [!NOTE]
> We slightly modified [`modeling_dream.py`](/dllm/pipelines/dream/models/modeling_dream.py) so that the `model.forward()` supports 2-D attention masks. We recommend loading models with `dllm.utils.get_tokenizer`; otherwise `import dllm` before calling `AutoModel.from_pretrained` to ensure the correct models from `dllm` are used.
>
> We fixed bugs in `chat_template` and standardize `mask_token` through `dllm.utils.get_tokenizer`. If you use `AutoTokenizer`, keep in mind to set `chat_template` and `mask_token` appropriately yourselves. -->
## Training
### Finetuning
For example, to SFT [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) for instruction following on 8 GPUs, run:
```shell
accelerate launch \
--config_file scripts/accelerate_configs/fsdp.yaml \
examples/dream/sft.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
If you are using slurm and want to train across, for example, 2 nodes (16 GPUs total), run:
```shell
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "models/Dream-7B-SFT/tulu-3-sft-mixture" \
--max_length 1024 \
--num_train_epochs 4 \
--learning_rate 2e-5
```
<!-- **Reproducing [Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Base-7B)**. We tried our best to reproduce Dream-v0-Instruct-7B by finetuning Dream-v0-Base-7B using our training pipeline on the public instruction-following dataset [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture): -->
#### Reproducing [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)
We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) by finetuning [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) using our training pipeline on the public instruction-following dataset [`allenai/tulu-3-sft-mixture`](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture):
```shell
# preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
PYTHONPATH=. python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--sft_map_fn_path "examples.dream.sft.sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
# train on 24*8=192 A100s with FSDP; takes about 8 hours
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/sft.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "data/sft/dream/tulu-3-sft-mixture" \
--load_preprocessed_data True \
--output_dir "models/Dream-7B-SFT-tulu3-fsdp-bs4-len2048-ep5-lr1e-5" \
--max_length 2048 \
--truncation "right" \
--group_by_length True \
--num_train_epochs 5 \
--learning_rate 1e-5 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 2 \
--per_device_eval_batch_size 2 \
--eval_on_start False \
--eval_steps 0.1 \
--save_steps 0.05
```
<!-- [TODO] Training curves are on Wandb; checkpoints with evaluation results are available on Hugging Face. See the [Evaluation](#evaluation) section below for evaluation instructions. -->
### Pretraining
Pretrain on [`mlfoundations/dclm-baseline-1.0`](https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0) from scratch using 192 GPUs (24x8) and FSDP:
```shell
sbatch --nodes=24 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/dream/pt.py" \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
--dataset_args "mlfoundations/dclm-baseline-1.0" \
--output_dir "models/Dream-7B-PT/dclm-baseline-1.0" \
--max_length 1024 \
--max_steps 2000 \
--learning_rate 3e-4
```
## Inference
We support batch inference for standard generation and infilling:
<!-- See [`examples/dream/generate.py`](/examples/dream/generate.py) for a full example: -->
```shell
python examples/dream/generate.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
We also support interactive multi-turn dialogue with visualization:
```shell
python examples/dream/chat.py --model_name_or_path "Dream-org/Dream-v0-Instruct-7B"
```
## Evaluation
> Read [(optional) Evaluation setup](/README.md/#optional-evaluation-setup) before running evaluation.
For example, to evaluate [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on [`MMLU-Pro`](https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro) using 4 GPUs, run:
```shell
# Use model_args to adjust the generation arguments for evaluation.
accelerate launch --num_processes 4 \
dllm/pipelines/dream/eval.py \
--tasks "mmlu_pro" \
--model "dream" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=Dream-org/Dream-v0-Instruct-7B,mc_num=1,max_new_tokens=128,max_length=128,steps=128,temperature=0.1,top_p=0.9,add_bos_token=true,escape_until=true"
```
To automatically evaluate [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) and [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) on all benchmarks, run:
```shell
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Instruct-7B" --instruct True
bash examples/dream/eval.sh --model_name_or_path "Dream-org/Dream-v0-Base-7B" --instruct False
```
### Evaluation results
> Results marked (evaluated) are obtained with our evaluation framework, while results marked (reported) come from the original paper. All evaluation settings follow the configurations in the [Dream](https://github.com/DreamLM/Dream) repository, with minor adjustments. Placeholder entries (“–”) indicate results not yet evaluated; full results will be released soon.
| Model | MMLU | BBH | ARC&#8209;C | ARC&#8209;E | Hellaswag | WinoGrande | PIQA | GSM8K | Math | GPQA | HumanEval | MBPP | RACE | Countdown | Sudoku | Trip&nbsp;planning |
|:----------------|:-------:|:-------:|:-----:|:-----:|:-----------:|:------------:|:----:|:-----:|:----:|:----:|:-----------:|:----:|:------:|:-----------:|:----:|:-----------:|
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (reported) | 69.5 | 57.9 | 59.9 | 83.9 | 73.3 | 74.8 | 75.8 | 77.2 | 39.6 | 36.6 | 57.9 | 56.2 | 44.7 | 16.0 | 81.0 | 17.8 |
| [`Dream-v0-Base-7B`](https://huggingface.co/Dream-org/Dream-v0-Base-7B) (evaluated) | | | 59.7 | 83.3 | 73.1 | 72.9 | 72.0 | 69.6 | | 35.5 | 45.8 | | 43.0 | | | |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 1. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Base-7B" style="color: #808080; text-decoration: none;">
<code>Dream-v0-Base-7B</code>
</a>.
</p>
| Model | MMLU | MMLU-Pro | GSM8K | Math | GPQA | HumanEval | MBPP | IFEval |
|:----------------|:----:|:---------:|:-----:|:----:|:----:|:-----------:|:----:|:----:|
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(reported) | 67.0 | 43.3 | 81.0 | 39.2 | 33.0 | 55.5 | 58.8 | 62.5 |
| [`Dream-v0-Instruct-7B`](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)(evaluated) | | 43.0 | 82.6 | 39.9 | 32.4 | 59.1 | | 62.3 |
<p align="center" style="color: #808080; font-size: 0.9em;">
Table 2. Evaluation results of
<a href="https://huggingface.co/Dream-org/Dream-v0-Instruct-7B" style="color: #808080; text-decoration: none;">
<code>Dream-v0-Instruct-7B</code>
</a>.
</p>

Some files were not shown because too many files have changed in this diff.