0925 use custom x_transformers for easy development
This commit is contained in:
556
Amadeus/custom_attend.py
Normal file
@@ -0,0 +1,556 @@
from __future__ import annotations

from functools import partial
from typing import Tuple, Callable

import torch
from torch.nn import Module, Parameter
from torch import cat, nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass

from einops import rearrange, repeat, pack, unpack

# constants

@dataclass
class Intermediates:
    qk_similarities: Tensor | None = None
    pre_softmax_attn: Tensor | None = None
    post_softmax_attn: Tensor | None = None
    values: Tensor | None = None
    cached_kv: tuple[Tensor, Tensor] | None = None
    layer_type: str | None = None
    hybrid_hidden: Tensor | None = None

    def to_tuple(self):
        return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def at_most_one_of(*bools):
    return sum([*map(int, bools)]) <= 1

def compact(arr):
    return [*filter(exists, arr)]

@torch.jit.script
def softclamp(t: Tensor, value: float):
    return (t / value).tanh() * value
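
# hedged usage sketch (added for illustration, not part of the original file):
# softclamp squashes large logits smoothly toward +/- value, unlike a hard clamp

def _softclamp_example():
    t = torch.tensor([0., 25., 100., 1000.])
    return softclamp(t, 50.)   # roughly [0., 23.1, 48.2, 50.0] - saturates near +/- 50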

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)
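
# hedged usage sketch (added for illustration, not part of the original file):
# `once` wraps a single-argument function so it only ever fires on the first call

def _print_once_example():
    print_once('flash attention enabled')   # printed
    print_once('flash attention enabled')   # ignored on every subsequent call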

# selective attention
# https://arxiv.org/abs/2410.02703 - section 3.3
# it is a technique to allow each token to prevent itself from being attended to by future tokens
# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)

def selective_attn(
    sim,
    sim_head_gate = None,
    no_mask_sos = True
):
    i, j, device = *sim.shape[-2:], sim.device
    sim_head_gate = default(sim_head_gate, sim[:, 0])

    gate = F.relu(sim_head_gate) # only positive

    if no_mask_sos:
        gate = gate.clone()
        gate[..., -i] = 0.

    eye = torch.eye(i, device = device)

    if j > i:
        eye = F.pad(eye, (j - i, 0), value = 1.)

    gate = (1. - eye) * gate
    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
    gate = gate.cumsum(dim = -2)

    return sim - rearrange(gate, 'b i j -> b 1 i j')
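
# hedged usage sketch (added for illustration, not part of the original file):
# selective_attn operates on pre-softmax logits of shape (batch, heads, i, j);
# the shapes below are arbitrary assumptions

def _selective_attn_example():
    sim = torch.randn(2, 8, 16, 16)   # causal attention logits
    gated = selective_attn(sim)       # gate defaults to head 0 of `sim`
    return gated.shape                # unchanged: (2, 8, 16, 16)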

# alternative distance functions

def qk_l2_dist_squared(q, k):
    if k.ndim == 3:
        k = repeat(k, 'b j d -> b h j d', h = q.shape[1])

    q, packed_shape = pack_one(q, '* i d')
    k, _ = pack_one(k, '* j d')

    l2_dist_squared = torch.cdist(q, k) ** 2
    return unpack_one(l2_dist_squared, packed_shape, '* i j')
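
# hedged usage sketch (added for illustration, not part of the original file):
# a single-head key tensor (b, j, d) is broadcast across the query heads; shapes are assumptions

def _qk_l2_dist_example():
    q = torch.randn(2, 8, 16, 64)   # (batch, heads, i, dim_head)
    k = torch.randn(2, 16, 64)      # (batch, j, dim_head) - shared across heads
    return qk_l2_dist_squared(q, k).shape   # (2, 8, 16, 16)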

# one-hot straight through softmax

def one_hot_straight_through(logits, temperature = 1.):
    one_hot_indices = logits.argmax(dim = -1, keepdim = True)
    one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)

    soft_attn = (logits / temperature).softmax(dim = -1)
    return one_hot + soft_attn - soft_attn.detach()
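
# hedged usage sketch (added for illustration, not part of the original file):
# the forward value is the hard one-hot choice, while gradients flow through the softmax term

def _one_hot_straight_through_example():
    logits = torch.randn(2, 8, 16, 16, requires_grad = True)
    attn = one_hot_straight_through(logits)
    return attn.sum(dim = -1)   # exactly 1. per row in the forward pass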

# sparse topk attention - only keep topk attn logits for softmax
# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`

def sparse_topk_attn(
    logits,
    sparse_topk,
    temperature = 1.,
    straight_through = False
):
    orig_logits = logits

    mask_value = -torch.finfo(logits.dtype).max
    top_values, _ = logits.topk(sparse_topk, dim = -1)
    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
    topk_attn = logits.softmax(dim = -1)

    if not straight_through:
        return topk_attn

    soft_attn = (orig_logits / temperature).softmax(dim = -1)
    return topk_attn.detach() + soft_attn - soft_attn.detach()
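
# hedged usage sketch (added for illustration, not part of the original file):
# only the top-k logits per row survive the softmax; straight_through keeps a dense
# softmax gradient behind the sparse forward pass. shapes are assumptions

def _sparse_topk_attn_example():
    logits = torch.randn(2, 8, 16, 16)
    hard = sparse_topk_attn(logits, sparse_topk = 4)
    relaxed = sparse_topk_attn(logits, sparse_topk = 4, straight_through = True)
    return hard.shape, relaxed.shape   # both (2, 8, 16, 16)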

# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)

def create_causal_mask(i, j, device):
    return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

def onnx_create_causal_mask(i, j, device):
    r = torch.arange(i, device = device)
    causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
    causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
    return causal_mask
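
# hedged usage sketch (added for illustration, not part of the original file):
# both constructors mark future (disallowed) positions with True; with a kv cache the key
# length j can exceed the query length i, hence the left padding. values below are assumptions

def _causal_mask_example():
    i, j, device = 4, 6, torch.device('cpu')
    assert torch.equal(create_causal_mask(i, j, device), onnx_create_causal_mask(i, j, device))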

# main class

class Attend(Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        pre_talking_heads = False,
        post_talking_heads = False,
        pre_scale_post_talking_heads = False,
        sparse_topk = None,
        sparse_topk_straight_through = False, # https://arxiv.org/abs/2505.22074
        scale = None,
        qk_norm = False,
        l2_distance = False,
        sigmoid = False,
        custom_attn_fn: Callable | None = None,
        flash = False,
        softclamp_logits = False,
        logit_softclamp_value = 50.,
        add_zero_kv = False,
        head_learned_sink = False,
        selective = False,
        hard = False,
        cope = None,
        onnxable = False,
        sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )
    ):
        super().__init__()
        self.scale = scale

        # causal related

        self.causal = causal
        self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

        # attention type

        is_sparse_topk_attn = exists(sparse_topk)

        assert not (flash and sigmoid), 'sigmoid attention not available for flash'
        assert not (flash and hard), 'hard attention not available for flash'
        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'

        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)

        if exists(custom_attn_fn):
            self.attn_fn = custom_attn_fn
        elif sigmoid:
            self.attn_fn = F.sigmoid
        elif hard:
            self.attn_fn = one_hot_straight_through
        elif is_sparse_topk_attn:
            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
        else:
            softmax_fn = partial(F.softmax, dim = -1)
            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn

        # dropouts

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads

        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'

        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None

        if exists(self.pre_softmax_talking_heads):
            nn.init.dirac_(self.pre_softmax_talking_heads.weight)

        if exists(self.post_softmax_talking_heads):
            nn.init.dirac_(self.post_softmax_talking_heads.weight)

        if exists(self.pre_scale_post_talking_heads):
            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)

        # selective attention

        assert not (flash and selective), 'selective attention cannot work on flash attention'
        assert not (selective and not causal), 'selective attention is designed for autoregressive'
        self.selective = selective

        # l2 distance attention

        self.l2_distance = l2_distance

        # add a key / value token composed of zeros
        # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html

        self.add_zero_kv = add_zero_kv

        # learned sink concatted pre-softmax, working solution from gpt-oss

        assert not (head_learned_sink and flash), f'not supported for flash attention yet'

        self.head_learned_sink = head_learned_sink
        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None

        # soft clamp attention logit value

        if softclamp_logits:
            assert not flash, 'flash attention not compatible with logit softclamp value yet'
            assert logit_softclamp_value > 0.

        self.softclamp_logits = softclamp_logits
        self.logit_softclamp_value = logit_softclamp_value

        # contextual positional encoding

        self.cope = cope

        # flash attention

        self.flash = flash

        torch_version = version.parse(torch.__version__)
        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # torch 2.3 uses new backend and context manager

        if self.flash:
            if torch_version >= version.parse('2.3'):
                from torch.nn.attention import SDPBackend

                str_to_backend = dict(
                    enable_flash = SDPBackend.FLASH_ATTENTION,
                    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
                    enable_math = SDPBackend.MATH,
                    enable_cudnn = SDPBackend.CUDNN_ATTENTION
                )

                sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]

                self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
            else:
                self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # handle maybe l2 distance
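        # added note (reading of the algebra below, not original commentary): q and k are
        # augmented so that q' . k' = 2 q.k - ||q||^2 - ||k||^2 = -||q - k||^2, which lets the
        # standard dot-product sdpa kernel compute the negated squared l2 distance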

        if self.l2_distance:
            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
            k = F.pad(k, (0, 1), value = -1.)
            k = cat((k, k_norm_sq), dim = -1)

            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
            q = cat((2 * q, q_norm_sq), dim = -1)
            q = F.pad(q, (0, 1), value = -1.)

        # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
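        # added note (assumption, not original commentary): scaled_dot_product_attention applies
        # its own 1 / sqrt(dim_head) scaling internally, so q is pre-multiplied by
        # (self.scale / default_scale) to make the effective scale equal self.scale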

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        causal = self.causal

        # in the case of kv caching with one token (q_len == 1), just turn off causal masking
        # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there

        if q_len == 1 and causal:
            causal = False

        # expand key padding mask

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

        # handle kv cache - this should be bypassable in updated flash attention 2

        if k_len > q_len and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            if not exists(mask):
                mask = ~causal_mask
            else:
                mask = mask & ~causal_mask
            causal = False

        # manually handle causal mask, if another mask was given

        if exists(mask) and causal:
            causal_mask = self.create_causal_mask(q_len, k_len, device = device)
            mask = mask & ~causal_mask
            causal = False

        # protect against an entire row being masked out

        row_is_entirely_masked = None

        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)

        # handle alibi positional bias
        # convert from bool to float

        if exists(attn_bias):
            attn_bias = attn_bias.expand(batch, heads, -1, -1)

            # if mask given, the mask would already contain the causal mask from above logic
            # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
            elif causal:
                causal_mask = self.create_causal_mask(q_len, k_len, device = device)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            # scaled_dot_product_attention handles attn_mask either as bool or additive bias
            # make it an additive bias here

            mask = attn_bias

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with self.sdp_context_manager():
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        # for a row that is entirely masked out, should zero out the output of that row token

        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, Intermediates()

    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        causal = self.causal

        # handle key padding mask

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')

        # handle kv cached decoding

        if n == 1 and causal:
            causal = False

        # handle grouped multi-query attention
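        # added note (assumption, not original commentary): a single kv head is squeezed to
        # (b, n, d) and broadcast in the einsum below, while an intermediate kv head count is
        # repeated up to the full number of query heads (grouped-query attention)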

        if kv_heads == 1:
            k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
        elif kv_heads < heads:
            k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))

        # handle zero kv, as means for allowing network to attend to nothing

        if self.add_zero_kv:
            k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))

            if exists(mask):
                mask = F.pad(mask, (1, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (1, 0), value = 0.)

        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if not self.l2_distance:
            sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
        else:
            sim = -qk_l2_dist_squared(q, k)

        sim = sim * scale

        if exists(prev_attn):
            sim = sim + prev_attn

        qk_similarities = sim.clone()

        if exists(self.pre_scale_post_talking_heads):
            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)

        if exists(self.pre_softmax_talking_heads):
            sim = sim + self.pre_softmax_talking_heads(sim)

        if exists(attn_bias):
            sim = sim + attn_bias

        if self.softclamp_logits:
            sim = softclamp(sim, self.logit_softclamp_value)

        i, j, dtype = *sim.shape[-2:], sim.dtype

        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            sim = sim.masked_fill(~mask, mask_value)

        if causal:
            causal_mask = self.create_causal_mask(i, j, device = device)
            sim = sim.masked_fill(causal_mask, mask_value)

        row_is_entirely_masked = None

        if exists(mask):
            row_is_entirely_masked = ~mask.any(dim = -1)

        if exists(self.cope):
            sim = sim + self.cope(q, sim)

        if self.selective:
            sim = selective_attn(sim)

        if self.head_learned_sink:
            # add learned attention sink
            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
            sim = cat((attn_sink, sim), dim = -1)

        pre_softmax_attn = sim

        attn = self.attn_fn(sim)

        attn = attn.type(dtype)

        post_softmax_attn = attn

        if self.head_learned_sink:
            # remove attention sink
            attn = attn[..., 1:]

        attn = self.attn_dropout(attn)

        if exists(self.post_softmax_talking_heads):
            attn = self.post_softmax_talking_heads(attn)

        if exists(self.pre_scale_post_talking_heads):
            attn = attn * pre_to_post_scale

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )

        if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
            out = out.masked_fill(row_is_entirely_masked[..., None], 0.)

        return out, intermediates
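
# hedged usage sketch (added for illustration, not part of the original file):
# a minimal non-flash, causal call into Attend; all shapes and hyperparameters are assumptions

def _attend_example():
    attend = Attend(causal = True, heads = 8, dropout = 0.1)
    q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, dim_head)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)
    out, intermediates = attend(q, k, v)
    return out.shape                # (2, 8, 16, 64)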