1127 update to latest
This commit is contained in:
@ -1,8 +1,18 @@
|
||||
from math import ceil
|
||||
from typing import Optional, Union, Literal
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.modules.module import Module
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, hidden_size, dropout):
|
||||
@ -49,6 +59,64 @@ class extendedMLP(nn.Module):
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class extendedMLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
|
||||
super().__init__()
|
||||
self.input_size = in_size
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
if num_layers == 1:
|
||||
# Only one layer
|
||||
self.layers.append(nn.Linear(in_size, out_size))
|
||||
return
|
||||
elif num_layers > 1:
|
||||
# First layer
|
||||
self.layers.append(nn.Linear(in_size, hidden_size))
|
||||
self.layers.append(nn.Dropout(dropout))
|
||||
self.layers.append(nn.ReLU())
|
||||
# Intermediate layers
|
||||
if num_layers > 2:
|
||||
for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
|
||||
self.layers.append(nn.Linear(hidden_size, hidden_size))
|
||||
self.layers.append(nn.Dropout(dropout))
|
||||
self.layers.append(nn.ReLU())
|
||||
# Last layer
|
||||
self.layers.append(nn.Linear(hidden_size, out_size))
|
||||
else:
|
||||
raise ValueError("num_layers should be a positive integer")
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
|
||||
super().__init__()
|
||||
self.input_size = in_size
|
||||
|
||||
if num_layers == 1:
|
||||
# 单层情况,直接线性映射
|
||||
self.ffn = nn.Linear(in_size, out_size)
|
||||
elif num_layers == 2:
|
||||
# 两层时使用 SwiGLU
|
||||
self.w1 = nn.Linear(in_size, 2 * hidden_size) # 前半主分支,后半门控分支
|
||||
self.w2 = nn.Linear(hidden_size, out_size)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
raise ValueError("SwiGLU FFN 仅支持 num_layers=1 或 2")
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, "ffn"):
|
||||
return self.ffn(x)
|
||||
else:
|
||||
x_proj = self.w1(x)
|
||||
x_main, x_gate = x_proj.chunk(2, dim=-1) # 一分为二
|
||||
x = F.silu(x_main) * x_gate # SwiGLU: silu(a) * b
|
||||
x = self.dropout(x)
|
||||
x = self.w2(x)
|
||||
return x
|
||||
|
||||
class multiMLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
|
||||
super().__init__()
|
||||
@ -209,6 +277,28 @@ class TransformerLayer(nn.Module):
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
|
||||
return output_dict
|
||||
|
||||
class TransformerLayerV2(nn.Module):
|
||||
def __init__(self, dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
self.self_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
|
||||
self.cross_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
|
||||
self.residual_FF = ResidualLayerNormModule(SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4*dim, dropout=dropout))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, input_dict):
|
||||
'''
|
||||
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
|
||||
'''
|
||||
# self attention
|
||||
attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')
|
||||
|
||||
input_dict['input_seq'] = attn_output
|
||||
# cross attention
|
||||
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
|
||||
attn_output = self.residual_FF.forward_mlp(attn_output)
|
||||
attn_output = self.dropout(attn_output)
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
|
||||
return output_dict
|
||||
class FeatureEnricher(nn.Module):
|
||||
def __init__(self, dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
@ -225,4 +315,483 @@ class FeatureEnricher(nn.Module):
|
||||
attn_output = self.residual_FF.forward_mlp(attn_output)
|
||||
attn_output = self.dropout(attn_output)
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
|
||||
return output_dict
|
||||
return output_dict
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
r"""Allows the model to jointly attend to information from different representation subspaces.
|
||||
|
||||
.. note::
|
||||
See `this tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
|
||||
for an in depth discussion of the performant building blocks PyTorch offers for building your own
|
||||
transformer layers.
|
||||
|
||||
Method described in the paper:
|
||||
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Multi-Head Attention is defined as:
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O
|
||||
|
||||
where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
||||
|
||||
``nn.MultiheadAttention`` will use the optimized implementations of
|
||||
``scaled_dot_product_attention()`` when possible.
|
||||
|
||||
In addition to support for the new ``scaled_dot_product_attention()``
|
||||
function, for speeding up Inference, MHA will use
|
||||
fastpath inference with support for Nested Tensors, iff:
|
||||
|
||||
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
|
||||
- inputs are batched (3D) with ``batch_first==True``
|
||||
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
||||
- training is disabled (using ``.eval()``)
|
||||
- ``add_bias_kv`` is ``False``
|
||||
- ``add_zero_attn`` is ``False``
|
||||
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
||||
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
||||
nor ``attn_mask`` is passed
|
||||
- autocast is disabled
|
||||
|
||||
If the optimized inference fastpath implementation is in use, a
|
||||
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
||||
``query``/``key``/``value`` to represent padding more efficiently than using a
|
||||
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
||||
will be returned, and an additional speedup proportional to the fraction of the input
|
||||
that is padding can be expected.
|
||||
|
||||
Args:
|
||||
embed_dim: Total dimension of the model.
|
||||
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
||||
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
||||
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
||||
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
||||
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
||||
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
||||
Default: ``False``.
|
||||
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
||||
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
||||
https://arxiv.org/abs/2205.14135
|
||||
|
||||
"""
|
||||
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
if embed_dim <= 0 or num_heads <= 0:
|
||||
raise ValueError(
|
||||
f"embed_dim and num_heads must be greater than 0,"
|
||||
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
|
||||
)
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
if self._qkv_same_embed_dim:
|
||||
xavier_uniform_(self.in_proj_weight)
|
||||
else:
|
||||
xavier_uniform_(self.q_proj_weight)
|
||||
xavier_uniform_(self.k_proj_weight)
|
||||
xavier_uniform_(self.v_proj_weight)
|
||||
|
||||
if self.in_proj_bias is not None:
|
||||
constant_(self.in_proj_bias, 0.0)
|
||||
constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
xavier_normal_(self.bias_v)
|
||||
|
||||
def __setstate__(self, state):
|
||||
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||
if "_qkv_same_embed_dim" not in state:
|
||||
state["_qkv_same_embed_dim"] = True
|
||||
|
||||
super().__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
) -> tuple[Tensor, Optional[Tensor]]:
|
||||
r"""Compute attention outputs using query, key, and value embeddings.
|
||||
|
||||
Supports optional parameters for padding, masks and attention weights.
|
||||
|
||||
Args:
|
||||
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
||||
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
||||
Queries are compared against key-value pairs to produce the output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
||||
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
||||
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
||||
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
||||
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
||||
Binary and float masks are supported.
|
||||
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
||||
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
||||
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
||||
Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
|
||||
and achieve the best performance for MHA.
|
||||
Default: ``True``.
|
||||
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
||||
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
||||
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
||||
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
||||
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
||||
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
||||
the attention weight.
|
||||
If both attn_mask and key_padding_mask are supplied, their types should match.
|
||||
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
||||
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
||||
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
||||
is_causal: If specified, applies a causal mask as attention mask.
|
||||
Default: ``False``.
|
||||
Warning:
|
||||
``is_causal`` provides a hint that ``attn_mask`` is the
|
||||
causal mask. Providing incorrect hints can result in
|
||||
incorrect execution, including forward and backward
|
||||
compatibility.
|
||||
|
||||
Outputs:
|
||||
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
||||
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
||||
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
||||
embedding dimension ``embed_dim``.
|
||||
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
||||
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
||||
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
||||
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
||||
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
||||
|
||||
.. note::
|
||||
`batch_first` argument is ignored for unbatched inputs.
|
||||
""" # noqa: B950
|
||||
why_not_fast_path = ""
|
||||
if (
|
||||
(attn_mask is not None and torch.is_floating_point(attn_mask))
|
||||
or (key_padding_mask is not None)
|
||||
and torch.is_floating_point(key_padding_mask)
|
||||
):
|
||||
why_not_fast_path = "floating-point masks are not supported for fast path."
|
||||
|
||||
is_batched = query.dim() == 3
|
||||
|
||||
key_padding_mask = F._canonical_mask(
|
||||
mask=key_padding_mask,
|
||||
mask_name="key_padding_mask",
|
||||
other_type=F._none_or_dtype(attn_mask),
|
||||
other_name="attn_mask",
|
||||
target_type=query.dtype,
|
||||
)
|
||||
|
||||
attn_mask = F._canonical_mask(
|
||||
mask=attn_mask,
|
||||
mask_name="attn_mask",
|
||||
other_type=None,
|
||||
other_name="",
|
||||
target_type=query.dtype,
|
||||
check_other=False,
|
||||
)
|
||||
|
||||
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
||||
|
||||
if not is_fastpath_enabled:
|
||||
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
||||
elif not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif self.in_proj_weight is None:
|
||||
why_not_fast_path = "in_proj_weight was None"
|
||||
elif query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif (self.num_heads % 2) != 0:
|
||||
why_not_fast_path = "self.num_heads is not even"
|
||||
elif not self.batch_first:
|
||||
why_not_fast_path = "batch_first was not True"
|
||||
elif self.bias_k is not None:
|
||||
why_not_fast_path = "self.bias_k was not None"
|
||||
elif self.bias_v is not None:
|
||||
why_not_fast_path = "self.bias_v was not None"
|
||||
elif self.add_zero_attn:
|
||||
why_not_fast_path = "add_zero_attn was enabled"
|
||||
elif not self._qkv_same_embed_dim:
|
||||
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
||||
elif query.is_nested and (
|
||||
key_padding_mask is not None or attn_mask is not None
|
||||
):
|
||||
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
|
||||
is not supported with NestedTensor input"
|
||||
elif torch.is_autocast_enabled():
|
||||
why_not_fast_path = "autocast is enabled"
|
||||
|
||||
if not why_not_fast_path:
|
||||
tensor_args = (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
)
|
||||
# We have to use list comprehensions below because TorchScript does not support
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif _is_make_fx_tracing():
|
||||
why_not_fast_path = "we are running make_fx tracing"
|
||||
elif not all(_check_arg_device(x) for x in tensor_args):
|
||||
why_not_fast_path = (
|
||||
"some Tensor argument's device is neither one of "
|
||||
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any(
|
||||
_arg_requires_grad(x) for x in tensor_args
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
if not why_not_fast_path:
|
||||
merged_mask, mask_type = self.merge_masks(
|
||||
attn_mask, key_padding_mask, query
|
||||
)
|
||||
|
||||
if self.in_proj_bias is not None and self.in_proj_weight is not None:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
merged_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
mask_type,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
assert not any_nested, (
|
||||
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
||||
+ f"The fast path was not hit because {why_not_fast_path}"
|
||||
)
|
||||
|
||||
if self.batch_first and is_batched:
|
||||
# make sure that the transpose op does not affect the "is" property
|
||||
if key is value:
|
||||
if query is key:
|
||||
query = key = value = query.transpose(1, 0)
|
||||
else:
|
||||
query, key = (x.transpose(1, 0) for x in (query, key))
|
||||
value = key
|
||||
else:
|
||||
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
else:
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
def merge_masks(
|
||||
self,
|
||||
attn_mask: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor],
|
||||
query: Tensor,
|
||||
) -> tuple[Optional[Tensor], Optional[int]]:
|
||||
r"""Determine mask type and combine masks if necessary.
|
||||
|
||||
If only one mask is provided, that mask
|
||||
and the corresponding mask type will be returned. If both masks are provided, they will be both
|
||||
expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
|
||||
and mask type 2 will be returned
|
||||
Args:
|
||||
attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
|
||||
key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
|
||||
query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
|
||||
Returns:
|
||||
merged_mask: merged mask
|
||||
mask_type: merged mask type (0, 1, or 2)
|
||||
"""
|
||||
mask_type: Optional[int] = None
|
||||
merged_mask: Optional[Tensor] = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
mask_type = 1
|
||||
merged_mask = key_padding_mask
|
||||
|
||||
if attn_mask is not None:
|
||||
# In this branch query can't be a nested tensor, so it has a shape
|
||||
batch_size, seq_len, _ = query.shape
|
||||
mask_type = 2
|
||||
|
||||
# Always expands attn_mask to 4D
|
||||
if attn_mask.dim() == 3:
|
||||
attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
|
||||
else: # attn_mask.dim() == 2:
|
||||
attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
|
||||
batch_size, self.num_heads, -1, -1
|
||||
)
|
||||
merged_mask = attn_mask_expanded
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask_expanded = key_padding_mask.view(
|
||||
batch_size, 1, 1, seq_len
|
||||
).expand(-1, self.num_heads, -1, -1)
|
||||
merged_mask = attn_mask_expanded + key_padding_mask_expanded
|
||||
|
||||
# no attn_mask and no key_padding_mask, returns None, None
|
||||
return merged_mask, mask_type
|
||||
|
||||
|
||||
|
||||
|
||||
@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass):
|
||||
causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
|
||||
self.register_buffer('causal_mask', causal_mask)
|
||||
|
||||
# self.transformer_decoder = Decoder(
|
||||
# dim = dim,
|
||||
# depth = sub_decoder_depth,
|
||||
# heads = heads,
|
||||
# attn_dropout = dropout,
|
||||
# ff_dropout = dropout,
|
||||
# attn_flash = True)
|
||||
self.transformer_decoder = Decoder(
|
||||
dim = dim,
|
||||
dim = dim,
|
||||
depth = sub_decoder_depth,
|
||||
heads = heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
attn_flash = True)
|
||||
attn_flash = True,
|
||||
use_rmsnorm=True,
|
||||
ff_swish = True, # set this to True
|
||||
ff_glu = True, # set to true to use for all feedforwards
|
||||
)
|
||||
# add final dropout
|
||||
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
|
||||
self._apply_xavier_init()
|
||||
@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 6,
|
||||
denoising_steps:int = 8,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
class DiffusionDecoder(SubDecoderClass):
|
||||
class DiffusionDecoderV2(SubDecoderClass):
|
||||
def __init__(
|
||||
self,
|
||||
prediction_order:list,
|
||||
@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 6,
|
||||
denoising_steps:int = 8,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
|
||||
self.input_norm = nn.LayerNorm(dim)
|
||||
|
||||
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
if sub_decoder_enricher_use:
|
||||
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
|
||||
@ -1138,14 +1149,21 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
self.register_buffer('causal_mask', causal_mask)
|
||||
self.register_buffer('causal_ca_mask', causal_ca_mask)
|
||||
|
||||
# get depth of the sub-decoder
|
||||
if sub_decoder_depth > 1:
|
||||
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
|
||||
self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
|
||||
else:
|
||||
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
|
||||
if sub_decoder_enricher_use:
|
||||
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
|
||||
self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
|
||||
vocab=vocab,
|
||||
sub_decoder_depth=1,
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dropout=dropout,
|
||||
sub_decoder_enricher_use=False)
|
||||
|
||||
# simplified version of the forward process in diffusion model
|
||||
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
|
||||
@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
# print("sampled_token_dict", sampled_token_dict)
|
||||
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
copy_input_dict = input_dict.copy()
|
||||
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||||
|
||||
@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
if aux_ar: # inference with auxiliary auto-regressive decoder
|
||||
aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
|
||||
# print("aux_ar_logits", aux_ar_logits)
|
||||
return aux_ar_logits, sampled_token_dict
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
# get aux ar decoder logits
|
||||
aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
|
||||
return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
|
||||
# return logits_dict, (masked_indices, p_mask)
|
||||
@ -3,6 +3,7 @@ import random
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
from typing import Union, List, Tuple, Dict
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -157,6 +158,8 @@ class IterTuneCompiler(IterableDataset):
|
||||
segment = self.augmentor(segment)
|
||||
# use input_ids replace tune_name
|
||||
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
|
||||
# print(segment.shape, mask.shape, tune_name.shape)
|
||||
# segment = segment[torch.randperm(segment.size(0))]
|
||||
yield segment, mask, tune_name, encoded_caption
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
defaults:
|
||||
# - nn_params: nb8_embSum_NMT
|
||||
# - nn_params: remi8
|
||||
- nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
|
||||
# - nn_params: oct8_embSum_diff_t2m_300M_pretrainingv3
|
||||
# - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
|
||||
- nn_params: oct8_embSum_har_t2m_600M_pretrainingv3
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
# - nn_params: nb8_embSum_subPararell
|
||||
@ -15,7 +17,7 @@ defaults:
|
||||
# - nn_params: remi8_main12_head_16_dim512
|
||||
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
|
||||
|
||||
dataset: Melody # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
captions_path: dataset/midicaps/train_set.json
|
||||
|
||||
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
|
||||
@ -23,28 +25,28 @@ captions_path: dataset/midicaps/train_set.json
|
||||
|
||||
use_ddp: True # True, False | distributed data parallel
|
||||
use_fp16: True # True, False | mixed precision training
|
||||
use_diff: True # True,use diffusion in subdecoder
|
||||
use_diff: False # True,use diffusion in subdecoder
|
||||
diff_steps: 8 # number of diffusion steps
|
||||
use_dispLoss: True
|
||||
use_dispLoss: False
|
||||
lambda_weight: 0.5
|
||||
tau: 0.5
|
||||
|
||||
train_params:
|
||||
device: cuda
|
||||
batch_size: 10
|
||||
batch_size: 9
|
||||
grad_clip: 1.0
|
||||
num_iter: 300000 # total number of iterations
|
||||
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
|
||||
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
|
||||
iterations_per_training_cycle: 10 # number of iterations for logging training loss
|
||||
iterations_per_validation_cycle: 3000 # number of iterations for validation process
|
||||
input_length: 3072 # input sequence length3072
|
||||
input_length: 2048 # input sequence length3072
|
||||
# you can use focal loss, it it's not used, set focal_gamma to 0
|
||||
focal_alpha: 1
|
||||
focal_gamma: 0
|
||||
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
|
||||
scheduler : cosinelr
|
||||
initial_lr: 0.00001
|
||||
initial_lr: 0.0003
|
||||
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
|
||||
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
|
||||
warmup_steps: 2000 #number of warmup steps
|
||||
|
||||
@ -5,7 +5,7 @@ model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: DiffusionDecoder
|
||||
model_dropout: 0.2
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
|
||||
@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: DiffusionDecoderV2
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 768
|
||||
num_layer: 16
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 1 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: DiffusionDecoderV2
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 1080
|
||||
num_layer: 13
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 1 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: DiffusionDecoderV2
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 1080
|
||||
num_layer: 13
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 1 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: SelfAttention
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 1080
|
||||
num_layer: 13
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 3 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
@ -0,0 +1,19 @@
|
||||
encoding_scheme: oct
|
||||
num_features: 8
|
||||
vocab_name: MusicTokenVocabOct
|
||||
model_name: AmadeusModel
|
||||
input_embedder_name: SummationEmbedder
|
||||
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||
sub_decoder_name: SelfAttention
|
||||
model_dropout: 0
|
||||
input_embedder:
|
||||
num_layer: 1
|
||||
num_head: 8
|
||||
main_decoder:
|
||||
dim_model: 1272
|
||||
num_layer: 20
|
||||
num_head: 12
|
||||
sub_decoder:
|
||||
decout_window_size: 1 # 1 means no previous decoding output added
|
||||
num_layer: 1
|
||||
feature_enricher_use: False
|
||||
454
Amadeus/toy_train.py
Normal file
454
Amadeus/toy_train.py
Normal file
@ -0,0 +1,454 @@
|
||||
from math import ceil
|
||||
from typing import Optional, Union, Literal
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers import Qwen2Config
|
||||
from transformers.activations import get_activation
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
x = x.float()
|
||||
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
|
||||
x = self.weight * x
|
||||
return x.to(dtype)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
intermediate_size: Optional[int] = None,
|
||||
dropout: float = 0.,
|
||||
hidden_act: Literal['silu', 'gelu'] = 'silu'):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = get_activation(hidden_act)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gate, hidden = self.up_proj(x).chunk(2, dim=-1)
|
||||
gate = self.act_fn(gate)
|
||||
|
||||
return self.down_proj(self.dropout(gate * hidden))
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
head_dim: int = 64,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_dropout: float = 0.,
|
||||
bias: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
|
||||
self.head_dim = head_dim
|
||||
self.scaling = head_dim ** -0.5
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
|
||||
bias=bias)
|
||||
self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
|
||||
bias=bias)
|
||||
self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
|
||||
bias=bias)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_attention_heads * self.head_dim, hidden_size,
|
||||
bias=bias)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
|
||||
|
||||
# To be compatible with huggingface
|
||||
self.config = Qwen2Config()
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
key_states: Optional[Tensor] = None,
|
||||
value_states: Optional[Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
q = self.q_proj(hidden_states)
|
||||
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)
|
||||
|
||||
if(key_states is None or value_states is None):
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
|
||||
value_states = v.view(hidden_shape).transpose(1, 2)
|
||||
|
||||
attn_output, attn_weights = self.attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DecoderLayerWithCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
|
||||
attn_head_dim: int = 64,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_dropout: float = 0.,
|
||||
attn_bias: bool = False,
|
||||
|
||||
mlp_intermediate_size: Optional[int] = None,
|
||||
mlp_dropout: float = 0.,
|
||||
mlp_act: Literal['gelu', 'silu'] = 'silu',
|
||||
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.self_attn = MultiHeadAttention(
|
||||
hidden_size,
|
||||
attn_head_dim,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
attn_bias,
|
||||
rms_norm_eps,
|
||||
attn_implementation)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(
|
||||
hidden_size,
|
||||
attn_head_dim,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
attn_bias,
|
||||
rms_norm_eps,
|
||||
attn_implementation)
|
||||
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
mlp_intermediate_size,
|
||||
mlp_dropout,
|
||||
mlp_act)
|
||||
|
||||
self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_dict,
|
||||
):
|
||||
'''
|
||||
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
|
||||
'''
|
||||
hidden_states = input_dict['input_seq']
|
||||
cross_attn_states = input_dict['memory']
|
||||
cross_attn_mask = input_dict['memory_mask']
|
||||
attention_mask = input_dict['memory_mask']
|
||||
output_attentions = False
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layernorm(hidden_states)
|
||||
hidden_states, dec_attn_weight = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Cross Attention
|
||||
if(cross_attn_states is not None):
|
||||
residual = hidden_states
|
||||
hidden_states = self.cross_attn_layernorm(hidden_states)
|
||||
cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
|
||||
key_states = self.cross_attn.k_proj(cross_attn_states)
|
||||
value_states = self.cross_attn.v_proj(cross_attn_states)
|
||||
|
||||
key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
|
||||
value_states = value_states.view(cross_attn_shape).transpose(1, 2)
|
||||
hidden_states, cross_attn_weight = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=cross_attn_mask,
|
||||
key_states=key_states,
|
||||
value_states=value_states
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Feed Forward
|
||||
residual = hidden_states
|
||||
hidden_states = self.ffn_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
|
||||
if(output_attentions):
|
||||
if(cross_attn_states is not None):
|
||||
# return (hidden_states, (dec_attn_weight, cross_attn_weight))
|
||||
output_dict['dec_attn_weight'] = dec_attn_weight
|
||||
output_dict['cross_attn_weight'] = cross_attn_weight
|
||||
return output_dict
|
||||
else:
|
||||
# return (hidden_states, dec_attn_weight)
|
||||
output_dict['dec_attn_weight'] = dec_attn_weight
|
||||
return output_dict
|
||||
else:
|
||||
return output_dict
|
||||
|
||||
def sinusoidal_positional_embedding(
|
||||
x: Tensor,
|
||||
n: float = 10000.0) -> Tensor:
|
||||
device, dtype = x.device, x.dtype
|
||||
T = x.shape[-2]
|
||||
D = x.shape[-1]
|
||||
|
||||
positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
|
||||
embeddings = torch.zeros(T, D, device=device, dtype=dtype)
|
||||
|
||||
denominators = torch.pow(n, 2*torch.arange(0, D//2, device=device, dtype=dtype)/D) # 10000^(2i/d_model), i is the index of embedding
|
||||
embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model))
|
||||
embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model))
|
||||
|
||||
return x + embeddings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
|
||||
# 创建一个简单的数据集来测试收敛性
|
||||
class SimpleSeq2SeqDataset(Dataset):
|
||||
def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
|
||||
self.num_samples = num_samples
|
||||
self.seq_len = seq_len
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# 模拟编码器输出 (memory)
|
||||
memory = torch.randn(self.seq_len, self.hidden_size)
|
||||
|
||||
# 模拟解码器输入和目标
|
||||
decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
|
||||
target = torch.randint(0, self.vocab_size, (self.seq_len,))
|
||||
|
||||
# 创建简单的注意力掩码
|
||||
memory_mask = torch.ones(self.seq_len, self.seq_len)
|
||||
|
||||
return {
|
||||
'memory': memory,
|
||||
'decoder_input': decoder_input,
|
||||
'target': target,
|
||||
'memory_mask': memory_mask
|
||||
}
|
||||
|
||||
# 创建简单的模型
|
||||
class SimpleDecoderModel(nn.Module):
|
||||
def __init__(self, vocab_size, hidden_size=256, num_layers=3):
|
||||
super(SimpleDecoderModel, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# 嵌入层
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_size)
|
||||
|
||||
# 位置编码
|
||||
self.pos_encoding = sinusoidal_positional_embedding
|
||||
|
||||
# 多个解码器层
|
||||
self.layers = nn.ModuleList([
|
||||
DecoderLayerWithCrossAttention(
|
||||
hidden_size=hidden_size,
|
||||
attn_head_dim=64,
|
||||
num_attention_heads=2,
|
||||
attention_dropout=0.1,
|
||||
attn_bias=False,
|
||||
mlp_intermediate_size=512,
|
||||
mlp_dropout=0.1,
|
||||
mlp_act='silu',
|
||||
rms_norm_eps=1e-6,
|
||||
attn_implementation="sdpa"
|
||||
) for _ in range(num_layers)
|
||||
])
|
||||
|
||||
# 输出层
|
||||
self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
|
||||
self.output_proj = nn.Linear(hidden_size, vocab_size)
|
||||
|
||||
def forward(self, decoder_input, memory, memory_mask):
|
||||
# 嵌入 + 位置编码
|
||||
x = self.embedding(decoder_input)
|
||||
x = self.pos_encoding(x)
|
||||
|
||||
# 通过多个解码器层
|
||||
input_dict = {
|
||||
'input_seq': x,
|
||||
'memory': memory,
|
||||
'memory_mask': memory_mask
|
||||
}
|
||||
|
||||
for layer in self.layers:
|
||||
output_dict = layer(input_dict)
|
||||
input_dict['input_seq'] = output_dict['input_seq']
|
||||
|
||||
# 最终输出
|
||||
hidden_states = self.output_norm(output_dict['input_seq'])
|
||||
logits = self.output_proj(hidden_states)
|
||||
|
||||
return logits
|
||||
|
||||
# 训练函数
|
||||
def train_decoder_model():
|
||||
# 超参数
|
||||
vocab_size = 100
|
||||
hidden_size = 256
|
||||
seq_len = 10
|
||||
batch_size = 32
|
||||
num_epochs = 10
|
||||
learning_rate = 0.001
|
||||
|
||||
# 设备设置
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 创建数据集和数据加载器
|
||||
dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
|
||||
vocab_size=vocab_size, hidden_size=hidden_size)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# 初始化模型
|
||||
model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
|
||||
model.to(device)
|
||||
|
||||
# 损失函数和优化器
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# 训练记录
|
||||
train_losses = []
|
||||
|
||||
print("开始训练...")
|
||||
model.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
epoch_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch in dataloader:
|
||||
# 移动到设备
|
||||
memory = batch['memory'].to(device)
|
||||
decoder_input = batch['decoder_input'].long().to(device)
|
||||
target = batch['target'].long().to(device)
|
||||
memory_mask = batch['memory_mask'].to(device)
|
||||
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
logits = model(decoder_input, memory, memory_mask)
|
||||
|
||||
# 计算损失 - 将logits和target重塑为(batch_size * seq_len, vocab_size)和(batch_size * seq_len)
|
||||
loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = epoch_loss / num_batches
|
||||
train_losses.append(avg_loss)
|
||||
|
||||
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
|
||||
|
||||
# 早期停止检查:如果损失明显下降,说明模型在收敛
|
||||
if epoch >= 2 and avg_loss < 4.0: # 交叉熵损失的合理阈值
|
||||
print("✅ 模型正常收敛!")
|
||||
return True
|
||||
|
||||
# 检查最终损失
|
||||
final_loss = train_losses[-1]
|
||||
if final_loss < 5.0:
|
||||
print("✅ 训练完成,模型正常收敛!")
|
||||
return True
|
||||
else:
|
||||
print("❌ 损失下降不明显,可能需要调整模型或超参数")
|
||||
return False
|
||||
|
||||
# 运行训练
|
||||
if __name__ == "__main__":
|
||||
# 先测试单个decoder层的前向传播
|
||||
print("测试DecoderLayerWithCrossAttention前向传播...")
|
||||
|
||||
# 创建测试数据
|
||||
batch_size, seq_len, hidden_size = 2, 64, 256
|
||||
decoder_input = torch.randn(seq_len, hidden_size)
|
||||
memory = torch.randn(seq_len, hidden_size)
|
||||
memory_mask = torch.ones(seq_len, seq_len)
|
||||
|
||||
# 测试单层
|
||||
decoder_layer = DecoderLayerWithCrossAttention(
|
||||
hidden_size=hidden_size,
|
||||
attn_head_dim=64,
|
||||
num_attention_heads=4
|
||||
)
|
||||
|
||||
input_dict = {
|
||||
'input_seq': decoder_input,
|
||||
'memory': memory,
|
||||
'memory_mask': memory_mask
|
||||
}
|
||||
|
||||
try:
|
||||
output = decoder_layer(input_dict)
|
||||
print("✅ DecoderLayer前向传播测试通过")
|
||||
print(f"输入形状: {decoder_input.shape}")
|
||||
print(f"输出形状: {output['input_seq'].shape}")
|
||||
|
||||
# 运行完整训练
|
||||
success = train_decoder_model()
|
||||
if success:
|
||||
print("\n🎉 测试完成!DecoderLayerWithCrossAttention能够正常训练和收敛")
|
||||
else:
|
||||
print("\n⚠️ 训练过程中可能存在问题,建议检查模型结构或超参数")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 前向传播测试失败: {e}")
|
||||
@ -228,19 +228,39 @@ class DiffusionLoss4CompoundToken():
|
||||
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def get_aux_ar_nll_loss(self, logits, target, mask):
|
||||
probs = logits.softmax(dim=-1)
|
||||
if probs.ndim == 3:
|
||||
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
|
||||
if target.ndim == 2:
|
||||
target = target.flatten(0, 1) # [batch_size*seq_len]
|
||||
# clamp min value to 1e-7 to avoid log(0)
|
||||
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
|
||||
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
|
||||
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
|
||||
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
|
||||
return loss
|
||||
|
||||
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
|
||||
train_loss_list = []
|
||||
log_loss_dict_normal = {}
|
||||
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
disp_loss = None
|
||||
aux_ar_logits = None
|
||||
# print(len(logits_dict))
|
||||
if len(logits_dict) == 2: # has aux ar loss
|
||||
logits_dict, aux_ar_logits = logits_dict
|
||||
if input_dict is not None:
|
||||
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
|
||||
feat = hidden_vec.mean(dim=1) #bs,dim
|
||||
disp_loss = dispersive_loss(feat, tau=tau) # scalar
|
||||
for idx, key in enumerate(self.feature_list):
|
||||
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
|
||||
if aux_ar_logits is not None:
|
||||
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
|
||||
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
|
||||
train_loss_list.append(training_loss)
|
||||
if valid:
|
||||
if key == 'type' or key == 'timesig':
|
||||
|
||||
@ -74,8 +74,8 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
# model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
self.model = model
|
||||
@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
|
||||
segment, mask, caption,encoded_caption = batch
|
||||
input_seq, target = segment[:, :-1], segment[:, 1:]
|
||||
total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
|
||||
try:
|
||||
aux_ar_logits, logits_dict = logits_dict
|
||||
except:
|
||||
logits_dict = logits_dict
|
||||
probs_dict = {key:torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
|
||||
num_nonmask_tokens = torch.sum(mask)
|
||||
input_seq = input_seq.to(self.device)
|
||||
|
||||
@ -729,11 +729,6 @@ class XtransformerNewPretrainingDecoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self._make_decoder_layer(dim, depth, heads, dropout)
|
||||
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
|
||||
# frozen text encoder
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _make_decoder_layer(self, dim, depth, heads, dropout):
|
||||
self.transformer_decoder = Decoder(
|
||||
dim = dim,
|
||||
|
||||
Reference in New Issue
Block a user