0925 use custom x_transformers for easy develop
This commit is contained in:
556
Amadeus/custom_attend.py
Normal file
556
Amadeus/custom_attend.py
Normal file
@ -0,0 +1,556 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Tuple, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import Module, Parameter
|
||||||
|
from torch import cat, nn, einsum, Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import wraps
|
||||||
|
from packaging import version
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from einops import rearrange, repeat, pack, unpack
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Intermediates:
|
||||||
|
qk_similarities: Tensor | None = None
|
||||||
|
pre_softmax_attn: Tensor | None = None
|
||||||
|
post_softmax_attn: Tensor | None = None
|
||||||
|
values: Tensor | None = None
|
||||||
|
cached_kv: tuple[Tensor, Tensor] | None = None
|
||||||
|
layer_type: str | None = None
|
||||||
|
hybrid_hidden: Tensor | None = None
|
||||||
|
|
||||||
|
def to_tuple(self):
|
||||||
|
return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
def at_most_one_of(*bools):
|
||||||
|
return sum([*map(int, bools)]) <= 1
|
||||||
|
|
||||||
|
def compact(arr):
|
||||||
|
return [*filter(exists, arr)]
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def softclamp(t: Tensor, value: float):
|
||||||
|
return (t / value).tanh() * value
|
||||||
|
|
||||||
|
def pack_one(t, pattern):
|
||||||
|
return pack([t], pattern)
|
||||||
|
|
||||||
|
def unpack_one(t, ps, pattern):
|
||||||
|
return unpack(t, ps, pattern)[0]
|
||||||
|
|
||||||
|
def once(fn):
|
||||||
|
called = False
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(x):
|
||||||
|
nonlocal called
|
||||||
|
if called:
|
||||||
|
return
|
||||||
|
called = True
|
||||||
|
return fn(x)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
print_once = once(print)
|
||||||
|
|
||||||
|
# selective attention
|
||||||
|
# https://arxiv.org/abs/2410.02703 - section 3.3
|
||||||
|
# it is a technique to allow each token to prevent itself from being attended to by future tokens
|
||||||
|
# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
|
||||||
|
|
||||||
|
def selective_attn(
|
||||||
|
sim,
|
||||||
|
sim_head_gate = None,
|
||||||
|
no_mask_sos = True
|
||||||
|
):
|
||||||
|
i, j, device = *sim.shape[-2:], sim.device
|
||||||
|
sim_head_gate = default(sim_head_gate, sim[:, 0])
|
||||||
|
|
||||||
|
gate = F.relu(sim_head_gate) # only positive
|
||||||
|
|
||||||
|
if no_mask_sos:
|
||||||
|
gate = gate.clone()
|
||||||
|
gate[..., -i] = 0.
|
||||||
|
|
||||||
|
eye = torch.eye(i, device = device)
|
||||||
|
|
||||||
|
if j > i:
|
||||||
|
eye = F.pad(eye, (j - i, 0), value = 1.)
|
||||||
|
|
||||||
|
gate = (1. - eye) * gate
|
||||||
|
gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
|
||||||
|
gate = gate.cumsum(dim = -2)
|
||||||
|
|
||||||
|
return sim - rearrange(gate, 'b i j -> b 1 i j')
|
||||||
|
|
||||||
|
# alternative distance functions
|
||||||
|
|
||||||
|
def qk_l2_dist_squared(q, k):
|
||||||
|
if k.ndim == 3:
|
||||||
|
k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
|
||||||
|
|
||||||
|
q, packed_shape = pack_one(q, '* i d')
|
||||||
|
k, _ = pack_one(k, '* j d')
|
||||||
|
|
||||||
|
l2_dist_squared = torch.cdist(q, k) ** 2
|
||||||
|
return unpack_one(l2_dist_squared, packed_shape, '* i j')
|
||||||
|
|
||||||
|
# one-hot straight through softmax
|
||||||
|
|
||||||
|
def one_hot_straight_through(logits, temperature = 1.):
|
||||||
|
one_hot_indices = logits.argmax(dim = -1, keepdim = True)
|
||||||
|
one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)
|
||||||
|
|
||||||
|
soft_attn = (logits / temperature).softmax(dim = -1)
|
||||||
|
return one_hot + soft_attn - soft_attn.detach()
|
||||||
|
|
||||||
|
# sparse topk attention - only keep topk attn logits for softmax
|
||||||
|
# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
|
||||||
|
|
||||||
|
def sparse_topk_attn(
|
||||||
|
logits,
|
||||||
|
sparse_topk,
|
||||||
|
temperature = 1.,
|
||||||
|
straight_through = False
|
||||||
|
):
|
||||||
|
orig_logits = logits
|
||||||
|
|
||||||
|
mask_value = -torch.finfo(logits.dtype).max
|
||||||
|
top_values, _ = logits.topk(sparse_topk, dim = -1)
|
||||||
|
sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
|
||||||
|
logits = logits.masked_fill(~sparse_topk_mask, mask_value)
|
||||||
|
topk_attn = logits.softmax(dim = -1)
|
||||||
|
|
||||||
|
if not straight_through:
|
||||||
|
return topk_attn
|
||||||
|
|
||||||
|
soft_attn = (orig_logits / temperature).softmax(dim = -1)
|
||||||
|
return topk_attn.detach() + soft_attn - soft_attn.detach()
|
||||||
|
|
||||||
|
# functions for creating causal mask
|
||||||
|
# need a special one for onnx cpu (no support for .triu)
|
||||||
|
|
||||||
|
def create_causal_mask(i, j, device):
|
||||||
|
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
||||||
|
|
||||||
|
def onnx_create_causal_mask(i, j, device):
|
||||||
|
r = torch.arange(i, device = device)
|
||||||
|
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
|
||||||
|
causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
# main class
|
||||||
|
|
||||||
|
class Attend(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dropout = 0.,
|
||||||
|
causal = False,
|
||||||
|
heads = None,
|
||||||
|
pre_talking_heads = False,
|
||||||
|
post_talking_heads = False,
|
||||||
|
pre_scale_post_talking_heads = False,
|
||||||
|
sparse_topk = None,
|
||||||
|
sparse_topk_straight_through = False, # https://arxiv.org/abs/2505.22074
|
||||||
|
scale = None,
|
||||||
|
qk_norm = False,
|
||||||
|
l2_distance = False,
|
||||||
|
sigmoid = False,
|
||||||
|
custom_attn_fn: Callable | None = None,
|
||||||
|
flash = False,
|
||||||
|
softclamp_logits = False,
|
||||||
|
logit_softclamp_value = 50.,
|
||||||
|
add_zero_kv = False,
|
||||||
|
head_learned_sink = False,
|
||||||
|
selective = False,
|
||||||
|
hard = False,
|
||||||
|
cope = None,
|
||||||
|
onnxable = False,
|
||||||
|
sdp_kwargs: dict = dict(
|
||||||
|
enable_flash = True,
|
||||||
|
enable_math = True,
|
||||||
|
enable_mem_efficient = True
|
||||||
|
)
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
# causal related
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
|
||||||
|
|
||||||
|
# attention type
|
||||||
|
|
||||||
|
is_sparse_topk_attn = exists(sparse_topk)
|
||||||
|
|
||||||
|
assert not (flash and sigmoid), 'sigmoid attention not available for flash'
|
||||||
|
assert not (flash and hard), 'hard attention not available for flash'
|
||||||
|
assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
|
||||||
|
|
||||||
|
assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
|
||||||
|
|
||||||
|
if exists(custom_attn_fn):
|
||||||
|
self.attn_fn = custom_attn_fn
|
||||||
|
elif sigmoid:
|
||||||
|
self.attn_fn = F.sigmoid
|
||||||
|
elif hard:
|
||||||
|
self.attn_fn = one_hot_straight_through
|
||||||
|
elif is_sparse_topk_attn:
|
||||||
|
self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
|
||||||
|
else:
|
||||||
|
softmax_fn = partial(F.softmax, dim = -1)
|
||||||
|
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
|
||||||
|
|
||||||
|
# dropouts
|
||||||
|
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attn_dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
# talking heads
|
||||||
|
|
||||||
|
assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
|
||||||
|
|
||||||
|
self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
|
||||||
|
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
|
||||||
|
self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
|
||||||
|
|
||||||
|
if exists(self.pre_softmax_talking_heads):
|
||||||
|
nn.init.dirac_(self.pre_softmax_talking_heads.weight)
|
||||||
|
|
||||||
|
if exists(self.post_softmax_talking_heads):
|
||||||
|
nn.init.dirac_(self.post_softmax_talking_heads.weight)
|
||||||
|
|
||||||
|
if exists(self.pre_scale_post_talking_heads):
|
||||||
|
# an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
|
||||||
|
nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
|
||||||
|
|
||||||
|
# selective attention
|
||||||
|
|
||||||
|
assert not (flash and selective), 'selective attention cannot work on flash attention'
|
||||||
|
assert not (selective and not causal), 'selective attention is designed for autoregressive'
|
||||||
|
self.selective = selective
|
||||||
|
|
||||||
|
# l2 distance attention
|
||||||
|
|
||||||
|
self.l2_distance = l2_distance
|
||||||
|
|
||||||
|
# add a key / value token composed of zeros
|
||||||
|
# in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
|
||||||
|
|
||||||
|
self.add_zero_kv = add_zero_kv
|
||||||
|
|
||||||
|
# learned sink concatted pre-softmax, working solution from gpt-oss
|
||||||
|
|
||||||
|
assert not (head_learned_sink and flash), f'not supported for flash attention yet'
|
||||||
|
|
||||||
|
self.head_learned_sink = head_learned_sink
|
||||||
|
self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
|
||||||
|
|
||||||
|
# soft clamp attention logit value
|
||||||
|
|
||||||
|
if softclamp_logits:
|
||||||
|
assert not flash, 'flash attention not compatible with logit softclamp value yet'
|
||||||
|
assert logit_softclamp_value > 0.
|
||||||
|
|
||||||
|
self.softclamp_logits = softclamp_logits
|
||||||
|
self.logit_softclamp_value = logit_softclamp_value
|
||||||
|
|
||||||
|
# contextual positional encoding
|
||||||
|
|
||||||
|
self.cope = cope
|
||||||
|
|
||||||
|
# flash attention
|
||||||
|
|
||||||
|
self.flash = flash
|
||||||
|
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
||||||
|
|
||||||
|
# torch 2.3 uses new backend and context manager
|
||||||
|
|
||||||
|
if self.flash:
|
||||||
|
if torch_version >= version.parse('2.3'):
|
||||||
|
from torch.nn.attention import SDPBackend
|
||||||
|
|
||||||
|
str_to_backend = dict(
|
||||||
|
enable_flash = SDPBackend.FLASH_ATTENTION,
|
||||||
|
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
|
||||||
|
enable_math = SDPBackend.MATH,
|
||||||
|
enable_cudnn = SDPBackend.CUDNN_ATTENTION
|
||||||
|
)
|
||||||
|
|
||||||
|
sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
|
||||||
|
|
||||||
|
self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
|
||||||
|
else:
|
||||||
|
self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
|
||||||
|
|
||||||
|
def flash_attn(
|
||||||
|
self,
|
||||||
|
q, k, v,
|
||||||
|
mask = None,
|
||||||
|
attn_bias = None
|
||||||
|
):
|
||||||
|
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
|
||||||
|
|
||||||
|
# Recommended for multi-query single-key-value attention by Tri Dao
|
||||||
|
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
||||||
|
|
||||||
|
if k.ndim == 3:
|
||||||
|
k = repeat(k, 'b ... -> b h ...', h = q.shape[1])
|
||||||
|
|
||||||
|
if v.ndim == 3:
|
||||||
|
v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
|
||||||
|
|
||||||
|
# handle maybe l2 distance
|
||||||
|
|
||||||
|
if self.l2_distance:
|
||||||
|
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
|
||||||
|
k = F.pad(k, (0, 1), value = -1.)
|
||||||
|
k = cat((k, k_norm_sq), dim = -1)
|
||||||
|
|
||||||
|
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
|
||||||
|
q = cat((2 * q, q_norm_sq), dim = -1)
|
||||||
|
q = F.pad(q, (0, 1), value = -1.)
|
||||||
|
|
||||||
|
# handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
|
||||||
|
|
||||||
|
if exists(self.scale):
|
||||||
|
default_scale = q.shape[-1] ** -0.5
|
||||||
|
q = q * (self.scale / default_scale)
|
||||||
|
|
||||||
|
# Check if mask exists and expand to compatible shape
|
||||||
|
# The mask is B L, so it would have to be expanded to B H N L
|
||||||
|
|
||||||
|
causal = self.causal
|
||||||
|
|
||||||
|
# in the case of kv caching with one token (q_len == 1), just turn off causal masking
|
||||||
|
# in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
|
||||||
|
|
||||||
|
if q_len == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# expand key padding mask
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
assert mask.ndim == 4
|
||||||
|
mask = mask.expand(batch, heads, q_len, k_len)
|
||||||
|
|
||||||
|
# handle kv cache - this should be bypassable in updated flash attention 2
|
||||||
|
|
||||||
|
if k_len > q_len and causal:
|
||||||
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
||||||
|
if not exists(mask):
|
||||||
|
mask = ~causal_mask
|
||||||
|
else:
|
||||||
|
mask = mask & ~causal_mask
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# manually handle causal mask, if another mask was given
|
||||||
|
|
||||||
|
if exists(mask) and causal:
|
||||||
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
||||||
|
mask = mask & ~causal_mask
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# protect against an entire row being masked out
|
||||||
|
|
||||||
|
row_is_entirely_masked = None
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
row_is_entirely_masked = ~mask.any(dim = -1)
|
||||||
|
|
||||||
|
# handle alibi positional bias
|
||||||
|
# convert from bool to float
|
||||||
|
|
||||||
|
if exists(attn_bias):
|
||||||
|
attn_bias = attn_bias.expand(batch, heads, -1, -1)
|
||||||
|
|
||||||
|
# if mask given, the mask would already contain the causal mask from above logic
|
||||||
|
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
|
||||||
|
|
||||||
|
mask_value = -torch.finfo(q.dtype).max
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
|
||||||
|
elif causal:
|
||||||
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
||||||
|
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# scaled_dot_product_attention handles attn_mask either as bool or additive bias
|
||||||
|
# make it an additive bias here
|
||||||
|
|
||||||
|
mask = attn_bias
|
||||||
|
|
||||||
|
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
||||||
|
|
||||||
|
with self.sdp_context_manager():
|
||||||
|
out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v,
|
||||||
|
attn_mask = mask,
|
||||||
|
dropout_p = self.dropout if self.training else 0.,
|
||||||
|
is_causal = causal
|
||||||
|
)
|
||||||
|
|
||||||
|
# for a row that is entirely masked out, should zero out the output of that row token
|
||||||
|
|
||||||
|
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
|
||||||
|
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
||||||
|
|
||||||
|
return out, Intermediates()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
q, k, v,
|
||||||
|
mask = None,
|
||||||
|
attn_bias = None,
|
||||||
|
prev_attn = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
|
||||||
|
|
||||||
|
scale = default(self.scale, q.shape[-1] ** -0.5)
|
||||||
|
|
||||||
|
causal = self.causal
|
||||||
|
|
||||||
|
# handle key padding mask
|
||||||
|
|
||||||
|
if exists(mask) and mask.ndim == 2:
|
||||||
|
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||||
|
|
||||||
|
# handle kv cached decoding
|
||||||
|
|
||||||
|
if n == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
# handle grouped multi-query attention
|
||||||
|
|
||||||
|
if kv_heads == 1:
|
||||||
|
k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
|
||||||
|
elif kv_heads < heads:
|
||||||
|
k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))
|
||||||
|
|
||||||
|
# handle zero kv, as means for allowing network to attend to nothing
|
||||||
|
|
||||||
|
if self.add_zero_kv:
|
||||||
|
k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = F.pad(mask, (1, 0), value = True)
|
||||||
|
|
||||||
|
if exists(attn_bias):
|
||||||
|
attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
|
||||||
|
|
||||||
|
if self.flash:
|
||||||
|
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
|
||||||
|
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
|
||||||
|
|
||||||
|
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
|
||||||
|
|
||||||
|
if not self.l2_distance:
|
||||||
|
sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
|
||||||
|
else:
|
||||||
|
sim = -qk_l2_dist_squared(q, k)
|
||||||
|
|
||||||
|
sim = sim * scale
|
||||||
|
|
||||||
|
if exists(prev_attn):
|
||||||
|
sim = sim + prev_attn
|
||||||
|
|
||||||
|
qk_similarities = sim.clone()
|
||||||
|
|
||||||
|
if exists(self.pre_scale_post_talking_heads):
|
||||||
|
pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
|
||||||
|
|
||||||
|
if exists(self.pre_softmax_talking_heads):
|
||||||
|
sim = sim + self.pre_softmax_talking_heads(sim)
|
||||||
|
|
||||||
|
if exists(attn_bias):
|
||||||
|
sim = sim + attn_bias
|
||||||
|
|
||||||
|
if self.softclamp_logits:
|
||||||
|
sim = softclamp(sim, self.logit_softclamp_value)
|
||||||
|
|
||||||
|
i, j, dtype = *sim.shape[-2:], sim.dtype
|
||||||
|
|
||||||
|
mask_value = -torch.finfo(sim.dtype).max
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
sim = sim.masked_fill(~mask, mask_value)
|
||||||
|
|
||||||
|
if causal:
|
||||||
|
causal_mask = self.create_causal_mask(i, j, device = device)
|
||||||
|
sim = sim.masked_fill(causal_mask, mask_value)
|
||||||
|
|
||||||
|
row_is_entirely_masked = None
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
row_is_entirely_masked = ~mask.any(dim = -1)
|
||||||
|
|
||||||
|
if exists(self.cope):
|
||||||
|
sim = sim + self.cope(q, sim)
|
||||||
|
|
||||||
|
if self.selective:
|
||||||
|
sim = selective_attn(sim)
|
||||||
|
|
||||||
|
if self.head_learned_sink:
|
||||||
|
# add learned attention sink
|
||||||
|
attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
|
||||||
|
sim = cat((attn_sink, sim), dim = -1)
|
||||||
|
|
||||||
|
pre_softmax_attn = sim
|
||||||
|
|
||||||
|
attn = self.attn_fn(sim)
|
||||||
|
|
||||||
|
attn = attn.type(dtype)
|
||||||
|
|
||||||
|
post_softmax_attn = attn
|
||||||
|
|
||||||
|
if self.head_learned_sink:
|
||||||
|
# remove attention sink
|
||||||
|
attn = attn[..., 1:]
|
||||||
|
|
||||||
|
attn = self.attn_dropout(attn)
|
||||||
|
|
||||||
|
if exists(self.post_softmax_talking_heads):
|
||||||
|
attn = self.post_softmax_talking_heads(attn)
|
||||||
|
|
||||||
|
if exists(self.pre_scale_post_talking_heads):
|
||||||
|
attn = attn * pre_to_post_scale
|
||||||
|
|
||||||
|
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
|
||||||
|
|
||||||
|
intermediates = Intermediates(
|
||||||
|
qk_similarities = qk_similarities,
|
||||||
|
pre_softmax_attn = pre_softmax_attn,
|
||||||
|
post_softmax_attn = post_softmax_attn
|
||||||
|
)
|
||||||
|
|
||||||
|
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
|
||||||
|
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
||||||
|
|
||||||
|
return out, intermediates
|
||||||
581
Amadeus/custom_wrapper.py
Normal file
581
Amadeus/custom_wrapper.py
Normal file
@ -0,0 +1,581 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from math import ceil, log
|
||||||
|
from typing import Tuple, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, tensor, Tensor
|
||||||
|
from torch.nn import Module
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
from einops import rearrange, repeat, pack, unpack
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
def identity(t, *args, **kwargs):
|
||||||
|
return t
|
||||||
|
|
||||||
|
def join(arr, delimiter = ', '):
|
||||||
|
return delimiter.join(arr)
|
||||||
|
|
||||||
|
def cast_tuple(t, length = 1):
|
||||||
|
return t if isinstance(t, tuple) else (t,) * length
|
||||||
|
|
||||||
|
def eval_decorator(fn):
|
||||||
|
def inner(self, *args, **kwargs):
|
||||||
|
was_training = self.training
|
||||||
|
self.eval()
|
||||||
|
out = fn(self, *args, **kwargs)
|
||||||
|
self.train(was_training)
|
||||||
|
return out
|
||||||
|
return inner
|
||||||
|
|
||||||
|
# gumbel topk
|
||||||
|
|
||||||
|
def log(t, eps = 1e-20):
|
||||||
|
return t.clamp(min = eps).log()
|
||||||
|
|
||||||
|
def gumbel_noise(t):
|
||||||
|
return -log(-log(torch.rand_like(t)))
|
||||||
|
|
||||||
|
# function for modifying all the cached key / values
|
||||||
|
|
||||||
|
def modify_cached_kv(cache, fn):
|
||||||
|
for inter in cache.attn_intermediates:
|
||||||
|
if inter.layer_type == 'a':
|
||||||
|
inter.cached_kv = [fn(t) for t in inter.cached_kv]
|
||||||
|
|
||||||
|
# for variable lengthed prefixes
|
||||||
|
|
||||||
|
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
|
||||||
|
if pad == (0, 0):
|
||||||
|
return t
|
||||||
|
|
||||||
|
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||||
|
zeros = ((0, 0) * dims_from_right)
|
||||||
|
return F.pad(t, (*zeros, *pad), value = value)
|
||||||
|
|
||||||
|
def align_right(t, lens, pad_id = 0):
|
||||||
|
batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
|
||||||
|
|
||||||
|
assert lens.ndim == 1 and lens.shape[0] == batch
|
||||||
|
assert lens.amax() <= seq_len
|
||||||
|
|
||||||
|
pad_lens = seq_len - lens
|
||||||
|
max_pad_len = pad_lens.amax()
|
||||||
|
|
||||||
|
batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
|
||||||
|
prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
|
||||||
|
|
||||||
|
t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
|
||||||
|
offset = max_pad_len - pad_lens
|
||||||
|
|
||||||
|
aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
|
||||||
|
return aligned
|
||||||
|
|
||||||
|
# nucleus
|
||||||
|
|
||||||
|
def top_p(logits, thres = 0.9):
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending = True)
|
||||||
|
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
|
||||||
|
|
||||||
|
sorted_indices_to_remove = cum_probs > thres
|
||||||
|
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
|
||||||
|
|
||||||
|
sorted_logits[sorted_indices_to_remove] = float('-inf')
|
||||||
|
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
|
||||||
|
|
||||||
|
# topk
|
||||||
|
|
||||||
|
def top_k(logits, frac_num_tokens = 0.1, k = None):
|
||||||
|
num_tokens = logits.shape[-1]
|
||||||
|
|
||||||
|
k = default(k, ceil(frac_num_tokens * num_tokens))
|
||||||
|
k = min(k, num_tokens)
|
||||||
|
|
||||||
|
val, ind = torch.topk(logits, k)
|
||||||
|
probs = torch.full_like(logits, float('-inf'))
|
||||||
|
probs.scatter_(1, ind, val)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
# top_a
|
||||||
|
|
||||||
|
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
|
||||||
|
probs = logits.softmax(dim = -1)
|
||||||
|
max_probs = probs.amax(dim = -1, keepdim = True)
|
||||||
|
limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
|
||||||
|
return torch.where(probs < limit, float('-inf'), logits)
|
||||||
|
|
||||||
|
# min_p
|
||||||
|
# https://arxiv.org/abs/2407.01082
|
||||||
|
|
||||||
|
def min_p(logits, min_p = 0.1):
|
||||||
|
probs = logits.softmax(dim = -1)
|
||||||
|
max_probs = probs.amax(dim = -1, keepdim = True)
|
||||||
|
limit = min_p * max_probs
|
||||||
|
return torch.where(probs < limit, float('-inf'), logits)
|
||||||
|
|
||||||
|
# filter logits functions dict[str -> Callable]
|
||||||
|
|
||||||
|
FILTER_LOGITS_FN = dict(
|
||||||
|
top_p = top_p,
|
||||||
|
top_k = top_k,
|
||||||
|
top_a = top_a,
|
||||||
|
min_p = min_p
|
||||||
|
)
|
||||||
|
|
||||||
|
# contrastive decoding function
|
||||||
|
|
||||||
|
def contrastive_decode_fn(
|
||||||
|
expert_logits,
|
||||||
|
amateur_logits,
|
||||||
|
alpha = 0.1,
|
||||||
|
beta = 0.5
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Appendix A Algorithm 2
|
||||||
|
https://arxiv.org/abs/2309.09117
|
||||||
|
"""
|
||||||
|
|
||||||
|
cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
|
||||||
|
diffs = (1 + beta) * expert_logits - beta * amateur_logits
|
||||||
|
contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
|
||||||
|
return contrastive_decode_logits
|
||||||
|
|
||||||
|
# autoregressive wrapper class
|
||||||
|
|
||||||
|
class AutoregressiveWrapper(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
net,
|
||||||
|
ignore_index = -100,
|
||||||
|
pad_value = 0,
|
||||||
|
mask_prob = 0.,
|
||||||
|
add_attn_z_loss = False,
|
||||||
|
next_embed_loss_weight = 0.1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_value = pad_value
|
||||||
|
self.ignore_index = ignore_index
|
||||||
|
|
||||||
|
self.net = net
|
||||||
|
self.max_seq_len = net.max_seq_len
|
||||||
|
|
||||||
|
# paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
|
||||||
|
assert mask_prob < 1.
|
||||||
|
self.mask_prob = mask_prob
|
||||||
|
|
||||||
|
# whether to add router z-loss
|
||||||
|
self.add_attn_z_loss = add_attn_z_loss
|
||||||
|
|
||||||
|
# whether to add a continuous loss
|
||||||
|
self.add_continuous_pred_head = net.add_continuous_pred_head
|
||||||
|
self.next_embed_loss_weight = next_embed_loss_weight
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@eval_decorator
|
||||||
|
def beam_search(
|
||||||
|
self,
|
||||||
|
prompts,
|
||||||
|
seq_len,
|
||||||
|
beams = 4,
|
||||||
|
return_beams_and_scores = False,
|
||||||
|
eos_token = None,
|
||||||
|
temperature = 1.,
|
||||||
|
stochastic = False,
|
||||||
|
prompt_lens: Tensor | None = None,
|
||||||
|
filter_logits_fn: str | Callable = identity,
|
||||||
|
restrict_to_max_seq_len = True,
|
||||||
|
filter_kwargs: dict = dict(),
|
||||||
|
cache_kv = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
assert not exists(eos_token), 'eos token not supported yet'
|
||||||
|
|
||||||
|
max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
|
||||||
|
|
||||||
|
prompts, packed_shape = pack([prompts], '* n')
|
||||||
|
|
||||||
|
batch, orig_seq_len = prompts.shape
|
||||||
|
|
||||||
|
# handle filter logits fn given as string
|
||||||
|
|
||||||
|
if isinstance(filter_logits_fn, str):
|
||||||
|
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
|
||||||
|
|
||||||
|
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
|
||||||
|
|
||||||
|
# handle variable lengthed prompts (prefixes)
|
||||||
|
|
||||||
|
seq_start_pos = None
|
||||||
|
if exists(prompt_lens):
|
||||||
|
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
|
||||||
|
seq_start_pos = orig_seq_len - prompt_lens
|
||||||
|
|
||||||
|
# output from which sampled tokens appended to
|
||||||
|
|
||||||
|
out = prompts
|
||||||
|
|
||||||
|
# kv caches
|
||||||
|
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
should_cache = cache_kv and self.net.can_cache_kv
|
||||||
|
|
||||||
|
# scores for the beams
|
||||||
|
|
||||||
|
scores = torch.zeros((batch,), device = device)
|
||||||
|
|
||||||
|
batch_arange = torch.arange(batch, device = device)
|
||||||
|
|
||||||
|
# sampling up to seq_len
|
||||||
|
|
||||||
|
for i in range(seq_len):
|
||||||
|
is_first = i == 0
|
||||||
|
|
||||||
|
if restrict_to_max_seq_len:
|
||||||
|
max_len_exceeded = out.shape[-1] > max_seq_len
|
||||||
|
|
||||||
|
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
|
||||||
|
|
||||||
|
x = out[:, -max_seq_len:]
|
||||||
|
|
||||||
|
if exists(cache):
|
||||||
|
modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
|
||||||
|
|
||||||
|
logits, new_cache = self.net(
|
||||||
|
x,
|
||||||
|
return_intermediates = True,
|
||||||
|
cache = cache,
|
||||||
|
seq_start_pos = seq_start_pos,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_cache:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
logits = logits[:, -1]
|
||||||
|
|
||||||
|
# to add to the scores
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim = -1)
|
||||||
|
|
||||||
|
# maybe filter by top_k, top_p (nucleus) for stochastic beam search
|
||||||
|
|
||||||
|
if stochastic and not greedy:
|
||||||
|
logits = filter_logits_fn(logits, **filter_kwargs)
|
||||||
|
logits = (logits / temperature) + gumbel_noise(logits)
|
||||||
|
|
||||||
|
# (gumbel) topk
|
||||||
|
|
||||||
|
samples = logits.topk(beams, dim = -1).indices
|
||||||
|
|
||||||
|
# get the scores for keeping track of beams
|
||||||
|
|
||||||
|
next_scores = log_probs.gather(-1, samples)
|
||||||
|
|
||||||
|
# expand beam times
|
||||||
|
|
||||||
|
scores = repeat(scores, 'b -> b beams', beams = beams)
|
||||||
|
scores = scores + next_scores
|
||||||
|
|
||||||
|
out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
|
||||||
|
samples = rearrange(samples, 'b beams -> (b beams) 1')
|
||||||
|
|
||||||
|
if should_cache and is_first:
|
||||||
|
modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
|
||||||
|
|
||||||
|
# concat sample
|
||||||
|
|
||||||
|
out = torch.cat((out, samples), dim=-1)
|
||||||
|
|
||||||
|
# sort by score and excise
|
||||||
|
# excise out the beams
|
||||||
|
|
||||||
|
scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
|
||||||
|
curr_num_beams = scores.shape[-1]
|
||||||
|
|
||||||
|
if curr_num_beams > beams:
|
||||||
|
scores, sort_indices = scores.sort(dim = -1, descending = True)
|
||||||
|
|
||||||
|
scores = scores[:, :beams]
|
||||||
|
top_beams_indices = sort_indices[:, :beams]
|
||||||
|
|
||||||
|
top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
|
||||||
|
|
||||||
|
flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
|
||||||
|
|
||||||
|
out = out[flattened_beam_indices]
|
||||||
|
|
||||||
|
scores = rearrange(scores, 'b beams -> (b beams)')
|
||||||
|
|
||||||
|
if not exists(eos_token):
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_eos_tokens = (out == eos_token)
|
||||||
|
|
||||||
|
if is_eos_tokens.any(dim = -1).all():
|
||||||
|
break
|
||||||
|
|
||||||
|
if exists(eos_token):
|
||||||
|
# mask out everything after the eos tokens
|
||||||
|
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
|
||||||
|
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
|
||||||
|
out = out.masked_fill(mask, self.pad_value)
|
||||||
|
|
||||||
|
# select out the top beam
|
||||||
|
|
||||||
|
out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
|
||||||
|
|
||||||
|
out = out[..., orig_seq_len:]
|
||||||
|
|
||||||
|
out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
|
||||||
|
|
||||||
|
if not return_beams_and_scores:
|
||||||
|
return out[..., 0, :]
|
||||||
|
|
||||||
|
scores = rearrange(scores, '(b beams) -> beams b', b = batch)
|
||||||
|
out = rearrange(out, 'b beams n -> beams b n')
|
||||||
|
|
||||||
|
return out, scores
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@eval_decorator
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: list[Tensor] | Tensor,
|
||||||
|
seq_len,
|
||||||
|
eos_token = None,
|
||||||
|
temperature = 1.,
|
||||||
|
prompt_lens: Tensor | None = None,
|
||||||
|
filter_logits_fn: str | Callable = top_k,
|
||||||
|
restrict_to_max_seq_len = True,
|
||||||
|
amateur_model: Module | Tuple[Module] | None = None,
|
||||||
|
filter_kwargs: dict = dict(),
|
||||||
|
contrastive_decode_kwargs: dict | Tuple[dict] = dict(
|
||||||
|
beta = 0.5,
|
||||||
|
alpha = 0.1
|
||||||
|
),
|
||||||
|
cache_kv = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
max_seq_len, greedy = self.max_seq_len, temperature == 0.
|
||||||
|
|
||||||
|
# handle prompts given as list of variable lengthed token ids
|
||||||
|
|
||||||
|
if isinstance(prompts, list):
|
||||||
|
assert len(prompts) > 0, 'prompts cannot be empty list'
|
||||||
|
assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
|
||||||
|
|
||||||
|
prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
|
||||||
|
|
||||||
|
prompts = pad_sequence(prompts, batch_first = True)
|
||||||
|
|
||||||
|
# pack maybe no batch
|
||||||
|
|
||||||
|
prompts, ps = pack([prompts], '* n')
|
||||||
|
|
||||||
|
b, t, device = *prompts.shape, prompts.device
|
||||||
|
|
||||||
|
# handle filter logits fn given as string
|
||||||
|
|
||||||
|
if isinstance(filter_logits_fn, str):
|
||||||
|
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
|
||||||
|
|
||||||
|
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
|
||||||
|
|
||||||
|
# handle variable lengthed prompts (prefixes)
|
||||||
|
|
||||||
|
seq_start_pos = None
|
||||||
|
if exists(prompt_lens):
|
||||||
|
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
|
||||||
|
seq_start_pos = t - prompt_lens
|
||||||
|
|
||||||
|
# output from which sampled tokens appended to
|
||||||
|
|
||||||
|
out = prompts
|
||||||
|
|
||||||
|
# kv caches
|
||||||
|
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
# if doing contrastive decoding, turn off filter automatically
|
||||||
|
|
||||||
|
if exists(amateur_model):
|
||||||
|
amateur_model = cast_tuple(amateur_model)
|
||||||
|
contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
|
||||||
|
|
||||||
|
assert len(amateur_model) == len(contrastive_decode_kwargs)
|
||||||
|
|
||||||
|
amateur_caches = [None] * len(amateur_model)
|
||||||
|
filter_logits_fn = identity
|
||||||
|
|
||||||
|
for i, module in enumerate(amateur_model):
|
||||||
|
if isinstance(module, AutoregressiveWrapper):
|
||||||
|
amateur_model[i] = module.net
|
||||||
|
|
||||||
|
module.eval()
|
||||||
|
|
||||||
|
# sampling up to seq_len
|
||||||
|
|
||||||
|
for _ in range(seq_len):
|
||||||
|
|
||||||
|
if restrict_to_max_seq_len:
|
||||||
|
max_len_exceeded = out.shape[-1] > max_seq_len
|
||||||
|
|
||||||
|
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
|
||||||
|
|
||||||
|
x = out[:, -max_seq_len:]
|
||||||
|
|
||||||
|
if exists(cache):
|
||||||
|
for inter in cache.attn_intermediates:
|
||||||
|
if inter.layer_type == 'a':
|
||||||
|
inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
|
||||||
|
|
||||||
|
logits, new_cache = self.net(
|
||||||
|
x,
|
||||||
|
return_intermediates = True,
|
||||||
|
cache = cache,
|
||||||
|
seq_start_pos = seq_start_pos,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if cache_kv and self.net.can_cache_kv:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
logits = logits[:, -1]
|
||||||
|
|
||||||
|
# handle contrastive decoding, Li et al.
|
||||||
|
# https://arxiv.org/abs/2210.15097
|
||||||
|
|
||||||
|
if exists(amateur_model):
|
||||||
|
for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
|
||||||
|
amateur_logits, next_amateur_cache = amateur(
|
||||||
|
x,
|
||||||
|
return_intermediates = True,
|
||||||
|
cache = amateur_cache,
|
||||||
|
seq_start_pos = seq_start_pos,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
amateur_logits = amateur_logits[:, -1]
|
||||||
|
|
||||||
|
assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
|
||||||
|
logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
|
||||||
|
|
||||||
|
if cache_kv and amateur.can_cache_kv:
|
||||||
|
amateur_caches[i] = next_amateur_cache
|
||||||
|
|
||||||
|
# filter by top_k, top_p (nucleus), top_a, or custom
|
||||||
|
|
||||||
|
if greedy:
|
||||||
|
sample = logits.argmax(dim = -1, keepdim = True)
|
||||||
|
else:
|
||||||
|
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
|
||||||
|
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
||||||
|
sample = torch.multinomial(probs, 1)
|
||||||
|
|
||||||
|
# concat sample
|
||||||
|
|
||||||
|
out = torch.cat((out, sample), dim=-1)
|
||||||
|
|
||||||
|
if not exists(eos_token):
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_eos_tokens = (out == eos_token)
|
||||||
|
|
||||||
|
if is_eos_tokens.any(dim = -1).all():
|
||||||
|
break
|
||||||
|
|
||||||
|
if exists(eos_token):
|
||||||
|
# mask out everything after the eos tokens
|
||||||
|
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
|
||||||
|
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
|
||||||
|
out = out.masked_fill(mask, self.pad_value)
|
||||||
|
|
||||||
|
out = out[:, t:]
|
||||||
|
|
||||||
|
out, = unpack(out, ps, '* n')
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
return_outputs = False,
|
||||||
|
prepend_embeds = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
|
||||||
|
|
||||||
|
inp, target = x, x[:, 1:]
|
||||||
|
inp = torch.where(inp == ignore_index, self.pad_value, inp)
|
||||||
|
|
||||||
|
if self.mask_prob > 0.:
|
||||||
|
rand = torch.randn(inp.shape, device = x.device)
|
||||||
|
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
|
||||||
|
num_mask = min(int(seq * self.mask_prob), seq - 1)
|
||||||
|
indices = rand.topk(num_mask, dim = -1).indices
|
||||||
|
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
|
||||||
|
kwargs.update(self_attn_kv_mask = mask)
|
||||||
|
|
||||||
|
out, cache = self.net(
|
||||||
|
inp,
|
||||||
|
return_intermediates = True,
|
||||||
|
return_attn_z_loss = add_attn_z_loss,
|
||||||
|
return_next_embed_pred = add_next_embed_loss,
|
||||||
|
prepend_embeds = prepend_embeds,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# destruct differently if doing continuous pred
|
||||||
|
|
||||||
|
if add_next_embed_loss:
|
||||||
|
logits, (next_embed_pred, init_embeds) = out
|
||||||
|
else:
|
||||||
|
logits = out
|
||||||
|
|
||||||
|
# if there are prepended embeds, excise it out
|
||||||
|
|
||||||
|
if exists(prepend_embeds):
|
||||||
|
prepend_len = prepend_embeds.shape[1]
|
||||||
|
logits = logits[:, prepend_len:]
|
||||||
|
|
||||||
|
# take all tokens but the last
|
||||||
|
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
|
||||||
|
# loss function
|
||||||
|
|
||||||
|
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
|
||||||
|
|
||||||
|
# cross entropy loss
|
||||||
|
|
||||||
|
loss = loss_fn(
|
||||||
|
rearrange(logits, 'b n c -> b c n'),
|
||||||
|
target,
|
||||||
|
ignore_index = ignore_index
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_attn_z_loss:
|
||||||
|
loss = loss + cache.attn_z_loss
|
||||||
|
|
||||||
|
if add_next_embed_loss:
|
||||||
|
mask = target != ignore_index
|
||||||
|
embed_pred = next_embed_pred[:, :-1]
|
||||||
|
cont_targets = init_embeds[:, 1:].detach()
|
||||||
|
|
||||||
|
cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
|
||||||
|
cont_loss = cont_loss[mask].mean()
|
||||||
|
|
||||||
|
loss = loss + cont_loss * self.next_embed_loss_weight
|
||||||
|
|
||||||
|
if not return_outputs:
|
||||||
|
return loss
|
||||||
|
|
||||||
|
return loss, (logits, cache)
|
||||||
3616
Amadeus/custom_x_transformers.py
Normal file
3616
Amadeus/custom_x_transformers.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,8 @@
|
|||||||
defaults:
|
defaults:
|
||||||
# - nn_params: nb8_embSum_NMT
|
# - nn_params: nb8_embSum_NMT
|
||||||
# - nn_params: remi8
|
# - nn_params: remi8
|
||||||
- nn_params: nb8_embSum_diff_t2m_150M_finetunning
|
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
|
||||||
# - nn_params: nb8_embSum_diff_t2m_150M_pretraining
|
- nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
|
||||||
# - nn_params: nb8_embSum_subPararell
|
# - nn_params: nb8_embSum_subPararell
|
||||||
# - nn_params: nb8_embSum_diff_t2m_150M
|
# - nn_params: nb8_embSum_diff_t2m_150M
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,19 @@
|
|||||||
|
encoding_scheme: nb
|
||||||
|
num_features: 8
|
||||||
|
vocab_name: MusicTokenVocabNB
|
||||||
|
model_name: AmadeusModel
|
||||||
|
input_embedder_name: SummationEmbedder
|
||||||
|
main_decoder_name: XtransformerNewPretrainingDecoder
|
||||||
|
sub_decoder_name: DiffusionDecoder
|
||||||
|
model_dropout: 0
|
||||||
|
input_embedder:
|
||||||
|
num_layer: 1
|
||||||
|
num_head: 8
|
||||||
|
main_decoder:
|
||||||
|
dim_model: 768
|
||||||
|
num_layer: 20
|
||||||
|
num_head: 12
|
||||||
|
sub_decoder:
|
||||||
|
decout_window_size: 1 # 1 means no previous decoding output added
|
||||||
|
num_layer: 1
|
||||||
|
feature_enricher_use: False
|
||||||
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
|
from Amadeus.custom_x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
|
||||||
from transformers import T5EncoderModel
|
from transformers import T5EncoderModel
|
||||||
from data_representation.vocab_utils import LangTokenVocab
|
from data_representation.vocab_utils import LangTokenVocab
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user