from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.activation import _is_make_fx_tracing, _check_arg_device, _arg_requires_grad


class MLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout):
        super().__init__()
        self.out_size = out_size
        self.layer = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x):
        return self.layer(x)


class extendedMLP(nn.Module):
    def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
        super().__init__()
        self.input_size = in_size
        self.layers = nn.ModuleList()
        if num_layers == 1:
            # Only one layer
            self.layers.append(nn.Linear(in_size, out_size))
            return
        elif num_layers > 1:
            # First layer
            self.layers.append(nn.Linear(in_size, hidden_size))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
            # Intermediate layers
            if num_layers > 2:
                for _ in range(num_layers - 2):  # -2 because we're manually adding the first and last layers
                    self.layers.append(nn.Linear(hidden_size, hidden_size))
                    self.layers.append(nn.Dropout(dropout))
                    self.layers.append(nn.ReLU())
            # Last layer
            self.layers.append(nn.Linear(hidden_size, out_size))
        else:
            raise ValueError("num_layers should be a positive integer")

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(self, in_size, out_size, num_layers=2, hidden_size=2048, dropout=0.1):
        super().__init__()
        self.input_size = in_size
        if num_layers == 1:
            # Single-layer case: plain linear projection
            self.ffn = nn.Linear(in_size, out_size)
        elif num_layers == 2:
            # Two layers: use SwiGLU
            self.w1 = nn.Linear(in_size, 2 * hidden_size)  # first half is the main branch, second half the gate
            self.w2 = nn.Linear(hidden_size, out_size)
            self.dropout = nn.Dropout(dropout)
        else:
            raise ValueError("SwiGLUFFN only supports num_layers=1 or 2")

    def forward(self, x):
        if hasattr(self, "ffn"):
            return self.ffn(x)
        else:
            x_proj = self.w1(x)
            x_main, x_gate = x_proj.chunk(2, dim=-1)  # split into the two branches
            x = F.silu(x_main) * x_gate  # SwiGLU: silu(a) * b
            x = self.dropout(x)
            x = self.w2(x)
            return x


class multiMLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
        super().__init__()
        self.out_size = out_size
        self.pred_order = pred_order  # forward() indexes into this, so it must be stored
        self.layer = nn.ModuleList([MLP(in_size, out_size, hidden_size, dropout) for _ in pred_order])

    def forward(self, x, choice):
        '''
        x: B x T x d_model
        choice: token type from self.pred_order (str or list of str)
        '''
        if isinstance(choice, str):
            idx = self.pred_order.index(choice)
            return self.layer[idx](x)
        else:
            raise ValueError("multiMLP doesn't support parallel prediction")
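
# Illustrative sketch, not part of the original module: a quick shape check showing
# that extendedMLP and SwiGLUFFN are interchangeable at the interface level. The
# sizes below are arbitrary assumptions chosen for the demo.
def _demo_ffn_shapes():
    x = torch.randn(2, 16, 512)  # B x T x d_model
    mlp = extendedMLP(in_size=512, out_size=512, num_layers=2, hidden_size=2048, dropout=0.1)
    swiglu = SwiGLUFFN(in_size=512, out_size=512, num_layers=2, hidden_size=2048, dropout=0.1)
    # both map B x T x 512 back to B x T x 512
    assert mlp(x).shape == swiglu(x).shape == x.shape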
class ResidualLayerNormModule(nn.Module):
    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule
        if submodule.__class__.__name__ == 'MultiheadAttention':
            self.layer_norm = nn.LayerNorm(self.submodule.embed_dim)
        else:
            self.layer_norm = nn.LayerNorm(self.submodule.input_size)

    def forward_attention(self, q, k, v, attn_mask, type):
        # `type` only labels the call site ('self' / 'cross' / 'feature_enrichment'); it is unused here
        attn_output, _ = self.submodule(q, k, v, attn_mask=attn_mask, need_weights=False, average_attn_weights=False)
        return self.layer_norm(attn_output + q)

    def forward_mlp(self, x):
        return self.layer_norm(self.submodule(x) + x)


class MultiProj_hidden2logit(nn.Module):
    def __init__(self, dim, vocab_sizes):
        super().__init__()
        self.layers = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
        })

    def forward(self, hidden_vec, feature):
        logit = self.layers[f"layer_{feature}"](hidden_vec)
        return logit


class MultiProj_catvec2hidden(nn.Module):
    def __init__(self, config, par_pred_keys, seq_pred_keys):
        '''
        This class is used in SQstyleEachEmbStrategy
        par_pred_keys: list of independent features (these tokens are predicted in parallel)
        seq_pred_keys: list of sequential features (these tokens are predicted sequentially)
        '''
        super().__init__()
        net_param = config.nn_params
        self.d_model = net_param.model.d_model
        independent_emb_size = 0
        for key in par_pred_keys:
            independent_emb_size += net_param.emb[key]
        self.layers = nn.ModuleDict({
            'layer_independent': nn.Linear(self.d_model + independent_emb_size, self.d_model),
            **{f"layer_{key}": nn.Linear(self.d_model + net_param.emb[key], self.d_model) for key in seq_pred_keys}
        })
        self.par_pred_keys = par_pred_keys
        self.seq_pred_keys = seq_pred_keys
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def forward(self, x, choice):
        '''
        x: B x T x (d_model + emb_size)
        choice: key type (str or list of str)
        '''
        if isinstance(choice, str):  # single key
            assert choice in self.seq_pred_keys
            output = self.layers[f"layer_{choice}"](x)
            return self.relu(self.dropout(output))
        else:  # multiple keys, predicted in parallel
            assert choice == self.par_pred_keys  # the order of choice should match self.par_pred_keys
            output = self.layers['layer_independent'](x)
            return self.relu(self.dropout(output))


def mask_tensor(tensor, mask_rate=0.15):
    # Get the size of the tensor
    batch_size, seq_len, dim = tensor.size()
    # Calculate the total number of elements and the number to mask
    total_elements = batch_size * seq_len
    num_to_mask = int(total_elements * mask_rate)
    # Create a 1D binary mask where 1 indicates that the element will be masked.
    # Start by creating a tensor of zeros with length equal to the total number of elements.
    mask = torch.zeros(total_elements, device=tensor.device)
    # Set `num_to_mask` random indices to 1 (masking); the indices are created on
    # the same device so the scatter below also works for GPU inputs
    indices_to_mask = torch.randperm(total_elements, device=tensor.device)[:num_to_mask]
    mask[indices_to_mask] = 1
    # Reshape the mask to match the original tensor's shape
    mask = mask.reshape(batch_size, seq_len)
    mask = mask.unsqueeze(2)  # B x T x 1
    masked_tensor = tensor * (mask == 0).float()  # B x T x d_model
    return masked_tensor


def generate_causality_mask_on_window(size, window_size):
    mask = torch.zeros((size, size))
    for i in range(size):
        mask[i, i + window_size:] = 1
    return mask.bool()  # boolean mask; True means the position is masked out


# considers BOS token and mask margin
def generate_CA_mask(tgt_len, memory_len, mask_margin=0):
    mask = torch.triu(torch.ones((tgt_len, memory_len)), diagonal=mask_margin + 1)
    return mask.bool()  # boolean mask; True means the position is masked out


def generate_SA_mask(tgt_len):
    mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1)
    return mask.bool()


def generate_none_causality_mask(tgt_len, memory_len):
    mask = torch.zeros((tgt_len, memory_len))
    return mask.bool()
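
# Illustrative sketch with assumed sizes: what the mask helpers produce. True
# entries are *blocked* positions, matching nn.MultiheadAttention's boolean
# attn_mask convention used throughout this file.
def _demo_masks():
    print(generate_SA_mask(4))                    # upper-triangular causal self-attention mask
    print(generate_CA_mask(4, 6, mask_margin=1))  # causal cross-attention mask shifted by the margin
    x = torch.randn(2, 8, 16)
    masked = mask_tensor(x, mask_rate=0.25)       # exactly int(16 * 0.25) = 4 timesteps zeroed
    print((masked.abs().sum(-1) == 0).float().mean())  # fraction of zeroed timesteps, 0.25 here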
class DecoderLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(
            nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(
            extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['memory'], input_dict['memory'],
            input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict


class TransformerLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.self_attn_block = ResidualLayerNormModule(
            nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.cross_attn_block = ResidualLayerNormModule(
            nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(
            extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # self attention (note: reuses memory_mask, which assumes tgt_len == memory_len)
        attn_output = self.self_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'],
            input_dict['memory_mask'], type='self')
        input_dict['input_seq'] = attn_output
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['memory'], input_dict['memory'],
            input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict


class TransformerLayerV2(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        # V2 uses the local MultiheadAttention subclass defined later in this file
        # and a SwiGLU feed-forward block in place of the ReLU MLP
        self.self_attn_block = ResidualLayerNormModule(
            MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.cross_attn_block = ResidualLayerNormModule(
            MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(
            SwiGLUFFN(in_size=dim, out_size=dim, num_layers=2, hidden_size=4 * dim, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # self attention (note: reuses memory_mask, which assumes tgt_len == memory_len)
        attn_output = self.self_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'],
            input_dict['memory_mask'], type='self')
        input_dict['input_seq'] = attn_output
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['memory'], input_dict['memory'],
            input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict
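
# Illustrative sketch with assumed dimensions: because the layers above pass a
# dict through forward(), a decoder stack can be expressed with nn.Sequential.
# memory_mask uses the boolean "True = blocked" convention from the helpers above.
def _demo_transformer_stack():
    dim, heads, B, T = 64, 4, 2, 10
    stack = nn.Sequential(*[TransformerLayer(dim, heads, dropout=0.1) for _ in range(2)])
    input_dict = {
        'input_seq': torch.randn(B, T, dim),
        'memory': torch.randn(B, T, dim),
        'memory_mask': generate_CA_mask(T, T),  # tgt_len == memory_len, as the layers assume
    }
    out = stack(input_dict)
    print(out['input_seq'].shape)  # torch.Size([2, 10, 64])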
class FeatureEnricher(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(
            nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(
            extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory}
        '''
        # cross attention over the memory, with no attention mask
        attn_output = self.cross_attn_block.forward_attention(
            input_dict['input_seq'], input_dict['memory'], input_dict['memory'],
            None, type='feature_enrichment')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
        return output_dict
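
# Illustrative sketch with assumed sizes: FeatureEnricher's dict protocol differs
# from the decoder layers above in that it carries no memory_mask, so every query
# may attend to every memory position.
def _demo_feature_enricher():
    dim = 64
    enricher = FeatureEnricher(dim=dim, num_heads=4, dropout=0.1)
    input_dict = {'input_seq': torch.randn(2, 10, dim), 'memory': torch.randn(2, 20, dim)}
    out = enricher(input_dict)
    print(out['input_seq'].shape)  # torch.Size([2, 10, 64])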
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information from different representation subspaces.

    .. note::
        See `this tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
        for an in depth discussion of the performant building blocks PyTorch offers for building your own
        transformer layers.

    Method described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O

    where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    ``nn.MultiheadAttention`` will use the optimized implementations of
    ``scaled_dot_product_attention()`` when possible.

    In addition to support for the new ``scaled_dot_product_attention()``
    function, for speeding up Inference, MHA will use
    fastpath inference with support for Nested Tensors, iff:

    - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
    - inputs are batched (3D) with ``batch_first==True``
    - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
      argument ``requires_grad``
    - training is disabled (using ``.eval()``)
    - ``add_bias_kv`` is ``False``
    - ``add_zero_attn`` is ``False``
    - ``kdim`` and ``vdim`` are equal to ``embed_dim``
    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
      nor ``attn_mask`` is passed
    - autocast is disabled

    If the optimized inference fastpath implementation is in use, a
    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
    ``query``/``key``/``value`` to represent padding more efficiently than using a
    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
    will be returned, and an additional speedup proportional to the fraction of the input
    that is padding can be expected.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        bias: If specified, adds bias to input / output projection layers. Default: ``True``.
        add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
        add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
            Default: ``False``.
        kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
        vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Examples::

        >>> # xdoctest: +SKIP
        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
         https://arxiv.org/abs/2205.14135
    """

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs)
            )
            self.k_proj_weight = Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs)
            )
            self.v_proj_weight = Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super().__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> tuple[Tensor, Optional[Tensor]]:
        r"""Compute attention outputs using query, key, and value embeddings.

        Supports optional parameters for padding, masks and attention weights.

        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when
                ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the
                target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding
                dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
                See "Attention Is All You Need" for more details.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when
                ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the
                source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding
                dimension ``kdim``. See "Attention Is All You Need" for more details.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the
                source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding
                dimension ``vdim``. See "Attention Is All You Need" for more details.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding").
                For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported.
                For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored
                for the purpose of attention. For a float mask, it will be directly added to the corresponding
                ``key`` value.
            need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
                and achieve the best performance for MHA.
                Default: ``True``.
            attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
                :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
                :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask
                will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
                in the batch. Binary and float masks are supported. For a binary mask, a ``True`` value indicates
                that the corresponding position is not allowed to attend. For a float mask, the mask values will
                be added to the attention weight.
                If both attn_mask and key_padding_mask are supplied, their types should match.
            average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
                heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has
                an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
            is_causal: If specified, applies a causal mask as attention mask.
                Default: ``False``.
                Warning:
                ``is_causal`` provides a hint that ``attn_mask`` is the causal mask. Providing incorrect hints can
                result in incorrect execution, including forward and backward compatibility.

        Outputs:
            - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
              returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
              :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
              :math:`S` is the source sequence length.
              If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or
              :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """  # noqa: B950
        why_not_fast_path = ""
        if (
            (attn_mask is not None and torch.is_floating_point(attn_mask))
            or (key_padding_mask is not None)
            and torch.is_floating_point(key_padding_mask)
        ):
            why_not_fast_path = "floating-point masks are not supported for fast path."

        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype,
        )

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()

        if not is_fastpath_enabled:
            why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
        elif not is_batched:
            why_not_fast_path = (
                f"input not batched; expected query.dim() of 3 but got {query.dim()}"
            )
        elif query is not key or key is not value:
            # When lifting this restriction, don't forget to either
            # enforce that the dtypes all match or test cases where
            # they don't!
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is None:
            why_not_fast_path = "in_proj_weight was None"
        elif query.dtype != self.in_proj_weight.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif (self.num_heads % 2) != 0:
            why_not_fast_path = "self.num_heads is not even"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif query.is_nested and (
            key_padding_mask is not None or attn_mask is not None
        ):
            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
                                 is not supported with NestedTensor input"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif _is_make_fx_tracing():
                why_not_fast_path = "we are running make_fx tracing"
            elif not all(_check_arg_device(x) for x in tensor_args):
                why_not_fast_path = (
                    "some Tensor argument's device is neither one of "
                    f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
                )
            elif torch.is_grad_enabled() and any(
                _arg_requires_grad(x) for x in tensor_args
            ):
                why_not_fast_path = (
                    "grad is enabled and at least one of query or the "
                    "input/output projection weights or biases requires_grad"
                )
            if not why_not_fast_path:
                merged_mask, mask_type = self.merge_masks(
                    attn_mask, key_padding_mask, query
                )

                if self.in_proj_bias is not None and self.in_proj_weight is not None:
                    return torch._native_multi_head_attention(
                        query,
                        key,
                        value,
                        self.embed_dim,
                        self.num_heads,
                        self.in_proj_weight,
                        self.in_proj_bias,
                        self.out_proj.weight,
                        self.out_proj.bias,
                        merged_mask,
                        need_weights,
                        average_attn_weights,
                        mask_type,
                    )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, (
            "MultiheadAttention does not support NestedTensor outside of its fast path. "
            + f"The fast path was not hit because {why_not_fast_path}"
        )

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

    def merge_masks(
        self,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        query: Tensor,
    ) -> tuple[Optional[Tensor], Optional[int]]:
        r"""Determine mask type and combine masks if necessary.

        If only one mask is provided, that mask and the corresponding mask type will be returned.
        If both masks are provided, they will both be expanded to shape
        ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``,
        and mask type 2 will be returned.

        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``

        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (0, 1, or 2)
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
                    batch_size, self.num_heads, -1, -1
                )
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(
                    batch_size, 1, 1, seq_len
                ).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type
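
# Illustrative sketch, not part of the original module: this local
# MultiheadAttention is a vendored copy of torch.nn.MultiheadAttention, so it is
# exercised the same way. The sizes are arbitrary assumptions; with .eval(),
# torch.no_grad(), batch_first=True, an even num_heads, a boolean mask, and
# query is key is value, the fastpath conditions checked in forward() can hold.
def _demo_local_mha():
    mha = MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()
    x = torch.randn(2, 10, 64)
    with torch.no_grad():
        out, _ = mha(x, x, x, attn_mask=generate_SA_mask(10), need_weights=False)
    print(out.shape)  # torch.Size([2, 10, 64])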