1127 update to latest
This commit is contained in:
@ -1,8 +1,18 @@
|
||||
from math import ceil
|
||||
from typing import Optional, Union, Literal
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn.modules.module import Module
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, hidden_size, dropout):
|
||||
@ -49,6 +59,64 @@ class extendedMLP(nn.Module):
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class extendedMLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
|
||||
super().__init__()
|
||||
self.input_size = in_size
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
if num_layers == 1:
|
||||
# Only one layer
|
||||
self.layers.append(nn.Linear(in_size, out_size))
|
||||
return
|
||||
elif num_layers > 1:
|
||||
# First layer
|
||||
self.layers.append(nn.Linear(in_size, hidden_size))
|
||||
self.layers.append(nn.Dropout(dropout))
|
||||
self.layers.append(nn.ReLU())
|
||||
# Intermediate layers
|
||||
if num_layers > 2:
|
||||
for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
|
||||
self.layers.append(nn.Linear(hidden_size, hidden_size))
|
||||
self.layers.append(nn.Dropout(dropout))
|
||||
self.layers.append(nn.ReLU())
|
||||
# Last layer
|
||||
self.layers.append(nn.Linear(hidden_size, out_size))
|
||||
else:
|
||||
raise ValueError("num_layers should be a positive integer")
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
|
||||
super().__init__()
|
||||
self.input_size = in_size
|
||||
|
||||
if num_layers == 1:
|
||||
# 单层情况,直接线性映射
|
||||
self.ffn = nn.Linear(in_size, out_size)
|
||||
elif num_layers == 2:
|
||||
# 两层时使用 SwiGLU
|
||||
self.w1 = nn.Linear(in_size, 2 * hidden_size) # 前半主分支,后半门控分支
|
||||
self.w2 = nn.Linear(hidden_size, out_size)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
raise ValueError("SwiGLU FFN 仅支持 num_layers=1 或 2")
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, "ffn"):
|
||||
return self.ffn(x)
|
||||
else:
|
||||
x_proj = self.w1(x)
|
||||
x_main, x_gate = x_proj.chunk(2, dim=-1) # 一分为二
|
||||
x = F.silu(x_main) * x_gate # SwiGLU: silu(a) * b
|
||||
x = self.dropout(x)
|
||||
x = self.w2(x)
|
||||
return x
|
||||
|
||||
class multiMLP(nn.Module):
|
||||
def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
|
||||
super().__init__()
|
||||
@ -209,6 +277,28 @@ class TransformerLayer(nn.Module):
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
|
||||
return output_dict
|
||||
|
||||
class TransformerLayerV2(nn.Module):
|
||||
def __init__(self, dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
self.self_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
|
||||
self.cross_attn_block = ResidualLayerNormModule(MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
|
||||
self.residual_FF = ResidualLayerNormModule(SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4*dim, dropout=dropout))
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, input_dict):
|
||||
'''
|
||||
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
|
||||
'''
|
||||
# self attention
|
||||
attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')
|
||||
|
||||
input_dict['input_seq'] = attn_output
|
||||
# cross attention
|
||||
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
|
||||
attn_output = self.residual_FF.forward_mlp(attn_output)
|
||||
attn_output = self.dropout(attn_output)
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
|
||||
return output_dict
|
||||
class FeatureEnricher(nn.Module):
|
||||
def __init__(self, dim, num_heads, dropout):
|
||||
super().__init__()
|
||||
@ -225,4 +315,483 @@ class FeatureEnricher(nn.Module):
|
||||
attn_output = self.residual_FF.forward_mlp(attn_output)
|
||||
attn_output = self.dropout(attn_output)
|
||||
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
|
||||
return output_dict
|
||||
return output_dict
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
r"""Allows the model to jointly attend to information from different representation subspaces.
|
||||
|
||||
.. note::
|
||||
See `this tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
|
||||
for an in depth discussion of the performant building blocks PyTorch offers for building your own
|
||||
transformer layers.
|
||||
|
||||
Method described in the paper:
|
||||
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Multi-Head Attention is defined as:
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O
|
||||
|
||||
where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
||||
|
||||
``nn.MultiheadAttention`` will use the optimized implementations of
|
||||
``scaled_dot_product_attention()`` when possible.
|
||||
|
||||
In addition to support for the new ``scaled_dot_product_attention()``
|
||||
function, for speeding up Inference, MHA will use
|
||||
fastpath inference with support for Nested Tensors, iff:
|
||||
|
||||
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
|
||||
- inputs are batched (3D) with ``batch_first==True``
|
||||
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
||||
- training is disabled (using ``.eval()``)
|
||||
- ``add_bias_kv`` is ``False``
|
||||
- ``add_zero_attn`` is ``False``
|
||||
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
||||
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
||||
nor ``attn_mask`` is passed
|
||||
- autocast is disabled
|
||||
|
||||
If the optimized inference fastpath implementation is in use, a
|
||||
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
||||
``query``/``key``/``value`` to represent padding more efficiently than using a
|
||||
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
||||
will be returned, and an additional speedup proportional to the fraction of the input
|
||||
that is padding can be expected.
|
||||
|
||||
Args:
|
||||
embed_dim: Total dimension of the model.
|
||||
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
||||
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
||||
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
||||
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
||||
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
||||
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
||||
Default: ``False``.
|
||||
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
||||
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
.. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
|
||||
https://arxiv.org/abs/2205.14135
|
||||
|
||||
"""
|
||||
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False,
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
if embed_dim <= 0 or num_heads <= 0:
|
||||
raise ValueError(
|
||||
f"embed_dim and num_heads must be greater than 0,"
|
||||
f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
|
||||
)
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
if self._qkv_same_embed_dim:
|
||||
xavier_uniform_(self.in_proj_weight)
|
||||
else:
|
||||
xavier_uniform_(self.q_proj_weight)
|
||||
xavier_uniform_(self.k_proj_weight)
|
||||
xavier_uniform_(self.v_proj_weight)
|
||||
|
||||
if self.in_proj_bias is not None:
|
||||
constant_(self.in_proj_bias, 0.0)
|
||||
constant_(self.out_proj.bias, 0.0)
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
xavier_normal_(self.bias_v)
|
||||
|
||||
def __setstate__(self, state):
|
||||
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||
if "_qkv_same_embed_dim" not in state:
|
||||
state["_qkv_same_embed_dim"] = True
|
||||
|
||||
super().__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
) -> tuple[Tensor, Optional[Tensor]]:
|
||||
r"""Compute attention outputs using query, key, and value embeddings.
|
||||
|
||||
Supports optional parameters for padding, masks and attention weights.
|
||||
|
||||
Args:
|
||||
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
||||
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
||||
Queries are compared against key-value pairs to produce the output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
||||
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
||||
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
||||
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
||||
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
||||
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
||||
Binary and float masks are supported.
|
||||
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
||||
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
||||
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
||||
Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
|
||||
and achieve the best performance for MHA.
|
||||
Default: ``True``.
|
||||
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
||||
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
||||
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
||||
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
||||
Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
||||
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
||||
the attention weight.
|
||||
If both attn_mask and key_padding_mask are supplied, their types should match.
|
||||
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
||||
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
||||
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
||||
is_causal: If specified, applies a causal mask as attention mask.
|
||||
Default: ``False``.
|
||||
Warning:
|
||||
``is_causal`` provides a hint that ``attn_mask`` is the
|
||||
causal mask. Providing incorrect hints can result in
|
||||
incorrect execution, including forward and backward
|
||||
compatibility.
|
||||
|
||||
Outputs:
|
||||
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
||||
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
||||
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
||||
embedding dimension ``embed_dim``.
|
||||
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
||||
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
||||
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
||||
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
||||
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
||||
|
||||
.. note::
|
||||
`batch_first` argument is ignored for unbatched inputs.
|
||||
""" # noqa: B950
|
||||
why_not_fast_path = ""
|
||||
if (
|
||||
(attn_mask is not None and torch.is_floating_point(attn_mask))
|
||||
or (key_padding_mask is not None)
|
||||
and torch.is_floating_point(key_padding_mask)
|
||||
):
|
||||
why_not_fast_path = "floating-point masks are not supported for fast path."
|
||||
|
||||
is_batched = query.dim() == 3
|
||||
|
||||
key_padding_mask = F._canonical_mask(
|
||||
mask=key_padding_mask,
|
||||
mask_name="key_padding_mask",
|
||||
other_type=F._none_or_dtype(attn_mask),
|
||||
other_name="attn_mask",
|
||||
target_type=query.dtype,
|
||||
)
|
||||
|
||||
attn_mask = F._canonical_mask(
|
||||
mask=attn_mask,
|
||||
mask_name="attn_mask",
|
||||
other_type=None,
|
||||
other_name="",
|
||||
target_type=query.dtype,
|
||||
check_other=False,
|
||||
)
|
||||
|
||||
is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
|
||||
|
||||
if not is_fastpath_enabled:
|
||||
why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
|
||||
elif not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif self.in_proj_weight is None:
|
||||
why_not_fast_path = "in_proj_weight was None"
|
||||
elif query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif (self.num_heads % 2) != 0:
|
||||
why_not_fast_path = "self.num_heads is not even"
|
||||
elif not self.batch_first:
|
||||
why_not_fast_path = "batch_first was not True"
|
||||
elif self.bias_k is not None:
|
||||
why_not_fast_path = "self.bias_k was not None"
|
||||
elif self.bias_v is not None:
|
||||
why_not_fast_path = "self.bias_v was not None"
|
||||
elif self.add_zero_attn:
|
||||
why_not_fast_path = "add_zero_attn was enabled"
|
||||
elif not self._qkv_same_embed_dim:
|
||||
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
||||
elif query.is_nested and (
|
||||
key_padding_mask is not None or attn_mask is not None
|
||||
):
|
||||
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
|
||||
is not supported with NestedTensor input"
|
||||
elif torch.is_autocast_enabled():
|
||||
why_not_fast_path = "autocast is enabled"
|
||||
|
||||
if not why_not_fast_path:
|
||||
tensor_args = (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
)
|
||||
# We have to use list comprehensions below because TorchScript does not support
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif _is_make_fx_tracing():
|
||||
why_not_fast_path = "we are running make_fx tracing"
|
||||
elif not all(_check_arg_device(x) for x in tensor_args):
|
||||
why_not_fast_path = (
|
||||
"some Tensor argument's device is neither one of "
|
||||
f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any(
|
||||
_arg_requires_grad(x) for x in tensor_args
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
if not why_not_fast_path:
|
||||
merged_mask, mask_type = self.merge_masks(
|
||||
attn_mask, key_padding_mask, query
|
||||
)
|
||||
|
||||
if self.in_proj_bias is not None and self.in_proj_weight is not None:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
merged_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
mask_type,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
assert not any_nested, (
|
||||
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
||||
+ f"The fast path was not hit because {why_not_fast_path}"
|
||||
)
|
||||
|
||||
if self.batch_first and is_batched:
|
||||
# make sure that the transpose op does not affect the "is" property
|
||||
if key is value:
|
||||
if query is key:
|
||||
query = key = value = query.transpose(1, 0)
|
||||
else:
|
||||
query, key = (x.transpose(1, 0) for x in (query, key))
|
||||
value = key
|
||||
else:
|
||||
query, key, value = (x.transpose(1, 0) for x in (query, key, value))
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
use_separate_proj_weight=True,
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
else:
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
def merge_masks(
|
||||
self,
|
||||
attn_mask: Optional[Tensor],
|
||||
key_padding_mask: Optional[Tensor],
|
||||
query: Tensor,
|
||||
) -> tuple[Optional[Tensor], Optional[int]]:
|
||||
r"""Determine mask type and combine masks if necessary.
|
||||
|
||||
If only one mask is provided, that mask
|
||||
and the corresponding mask type will be returned. If both masks are provided, they will be both
|
||||
expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
|
||||
and mask type 2 will be returned
|
||||
Args:
|
||||
attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
|
||||
key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
|
||||
query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
|
||||
Returns:
|
||||
merged_mask: merged mask
|
||||
mask_type: merged mask type (0, 1, or 2)
|
||||
"""
|
||||
mask_type: Optional[int] = None
|
||||
merged_mask: Optional[Tensor] = None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
mask_type = 1
|
||||
merged_mask = key_padding_mask
|
||||
|
||||
if attn_mask is not None:
|
||||
# In this branch query can't be a nested tensor, so it has a shape
|
||||
batch_size, seq_len, _ = query.shape
|
||||
mask_type = 2
|
||||
|
||||
# Always expands attn_mask to 4D
|
||||
if attn_mask.dim() == 3:
|
||||
attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
|
||||
else: # attn_mask.dim() == 2:
|
||||
attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
|
||||
batch_size, self.num_heads, -1, -1
|
||||
)
|
||||
merged_mask = attn_mask_expanded
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask_expanded = key_padding_mask.view(
|
||||
batch_size, 1, 1, seq_len
|
||||
).expand(-1, self.num_heads, -1, -1)
|
||||
merged_mask = attn_mask_expanded + key_padding_mask_expanded
|
||||
|
||||
# no attn_mask and no key_padding_mask, returns None, None
|
||||
return merged_mask, mask_type
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user