0925 use custom x_transformers for easy develop

This commit is contained in:
FelixChan
2025-09-25 16:04:22 +08:00
parent 6f03357342
commit d077e3210e
6 changed files with 4775 additions and 3 deletions

556
Amadeus/custom_attend.py Normal file
View File

@ -0,0 +1,556 @@
from __future__ import annotations
from functools import partial
from typing import Tuple, Callable
import torch
from torch.nn import Module, Parameter
from torch import cat, nn, einsum, Tensor
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass
from einops import rearrange, repeat, pack, unpack
# constants
@dataclass
class Intermediates:
qk_similarities: Tensor | None = None
pre_softmax_attn: Tensor | None = None
post_softmax_attn: Tensor | None = None
values: Tensor | None = None
cached_kv: tuple[Tensor, Tensor] | None = None
layer_type: str | None = None
hybrid_hidden: Tensor | None = None
def to_tuple(self):
return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def at_most_one_of(*bools):
return sum([*map(int, bools)]) <= 1
def compact(arr):
return [*filter(exists, arr)]
@torch.jit.script
def softclamp(t: Tensor, value: float):
return (t / value).tanh() * value
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# selective attention
# https://arxiv.org/abs/2410.02703 - section 3.3
# it is a technique to allow each token to prevent itself from being attended to by future tokens
# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
def selective_attn(
sim,
sim_head_gate = None,
no_mask_sos = True
):
i, j, device = *sim.shape[-2:], sim.device
sim_head_gate = default(sim_head_gate, sim[:, 0])
gate = F.relu(sim_head_gate) # only positive
if no_mask_sos:
gate = gate.clone()
gate[..., -i] = 0.
eye = torch.eye(i, device = device)
if j > i:
eye = F.pad(eye, (j - i, 0), value = 1.)
gate = (1. - eye) * gate
gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
gate = gate.cumsum(dim = -2)
return sim - rearrange(gate, 'b i j -> b 1 i j')
# alternative distance functions
def qk_l2_dist_squared(q, k):
if k.ndim == 3:
k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
q, packed_shape = pack_one(q, '* i d')
k, _ = pack_one(k, '* j d')
l2_dist_squared = torch.cdist(q, k) ** 2
return unpack_one(l2_dist_squared, packed_shape, '* i j')
# one-hot straight through softmax
def one_hot_straight_through(logits, temperature = 1.):
one_hot_indices = logits.argmax(dim = -1, keepdim = True)
one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)
soft_attn = (logits / temperature).softmax(dim = -1)
return one_hot + soft_attn - soft_attn.detach()
# sparse topk attention - only keep topk attn logits for softmax
# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
def sparse_topk_attn(
logits,
sparse_topk,
temperature = 1.,
straight_through = False
):
orig_logits = logits
mask_value = -torch.finfo(logits.dtype).max
top_values, _ = logits.topk(sparse_topk, dim = -1)
sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
logits = logits.masked_fill(~sparse_topk_mask, mask_value)
topk_attn = logits.softmax(dim = -1)
if not straight_through:
return topk_attn
soft_attn = (orig_logits / temperature).softmax(dim = -1)
return topk_attn.detach() + soft_attn - soft_attn.detach()
# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)
def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
def onnx_create_causal_mask(i, j, device):
r = torch.arange(i, device = device)
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
return causal_mask
# main class
class Attend(Module):
def __init__(
self,
*,
dropout = 0.,
causal = False,
heads = None,
pre_talking_heads = False,
post_talking_heads = False,
pre_scale_post_talking_heads = False,
sparse_topk = None,
sparse_topk_straight_through = False, # https://arxiv.org/abs/2505.22074
scale = None,
qk_norm = False,
l2_distance = False,
sigmoid = False,
custom_attn_fn: Callable | None = None,
flash = False,
softclamp_logits = False,
logit_softclamp_value = 50.,
add_zero_kv = False,
head_learned_sink = False,
selective = False,
hard = False,
cope = None,
onnxable = False,
sdp_kwargs: dict = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
):
super().__init__()
self.scale = scale
# causal related
self.causal = causal
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
# attention type
is_sparse_topk_attn = exists(sparse_topk)
assert not (flash and sigmoid), 'sigmoid attention not available for flash'
assert not (flash and hard), 'hard attention not available for flash'
assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
if exists(custom_attn_fn):
self.attn_fn = custom_attn_fn
elif sigmoid:
self.attn_fn = F.sigmoid
elif hard:
self.attn_fn = one_hot_straight_through
elif is_sparse_topk_attn:
self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
else:
softmax_fn = partial(F.softmax, dim = -1)
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
# dropouts
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
# talking heads
assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
if exists(self.pre_softmax_talking_heads):
nn.init.dirac_(self.pre_softmax_talking_heads.weight)
if exists(self.post_softmax_talking_heads):
nn.init.dirac_(self.post_softmax_talking_heads.weight)
if exists(self.pre_scale_post_talking_heads):
# an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
# selective attention
assert not (flash and selective), 'selective attention cannot work on flash attention'
assert not (selective and not causal), 'selective attention is designed for autoregressive'
self.selective = selective
# l2 distance attention
self.l2_distance = l2_distance
# add a key / value token composed of zeros
# in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
self.add_zero_kv = add_zero_kv
# learned sink concatted pre-softmax, working solution from gpt-oss
assert not (head_learned_sink and flash), f'not supported for flash attention yet'
self.head_learned_sink = head_learned_sink
self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
# soft clamp attention logit value
if softclamp_logits:
assert not flash, 'flash attention not compatible with logit softclamp value yet'
assert logit_softclamp_value > 0.
self.softclamp_logits = softclamp_logits
self.logit_softclamp_value = logit_softclamp_value
# contextual positional encoding
self.cope = cope
# flash attention
self.flash = flash
torch_version = version.parse(torch.__version__)
assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# torch 2.3 uses new backend and context manager
if self.flash:
if torch_version >= version.parse('2.3'):
from torch.nn.attention import SDPBackend
str_to_backend = dict(
enable_flash = SDPBackend.FLASH_ATTENTION,
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
enable_math = SDPBackend.MATH,
enable_cudnn = SDPBackend.CUDNN_ATTENTION
)
sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
else:
self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
def flash_attn(
self,
q, k, v,
mask = None,
attn_bias = None
):
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if k.ndim == 3:
k = repeat(k, 'b ... -> b h ...', h = q.shape[1])
if v.ndim == 3:
v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
# handle maybe l2 distance
if self.l2_distance:
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k = F.pad(k, (0, 1), value = -1.)
k = cat((k, k_norm_sq), dim = -1)
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q = cat((2 * q, q_norm_sq), dim = -1)
q = F.pad(q, (0, 1), value = -1.)
# handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L
causal = self.causal
# in the case of kv caching with one token (q_len == 1), just turn off causal masking
# in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
if q_len == 1 and causal:
causal = False
# expand key padding mask
if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)
# handle kv cache - this should be bypassable in updated flash attention 2
if k_len > q_len and causal:
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
if not exists(mask):
mask = ~causal_mask
else:
mask = mask & ~causal_mask
causal = False
# manually handle causal mask, if another mask was given
if exists(mask) and causal:
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
mask = mask & ~causal_mask
causal = False
# protect against an entire row being masked out
row_is_entirely_masked = None
if exists(mask):
row_is_entirely_masked = ~mask.any(dim = -1)
# handle alibi positional bias
# convert from bool to float
if exists(attn_bias):
attn_bias = attn_bias.expand(batch, heads, -1, -1)
# if mask given, the mask would already contain the causal mask from above logic
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
mask_value = -torch.finfo(q.dtype).max
if exists(mask):
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
elif causal:
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
causal = False
# scaled_dot_product_attention handles attn_mask either as bool or additive bias
# make it an additive bias here
mask = attn_bias
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with self.sdp_context_manager():
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = causal
)
# for a row that is entirely masked out, should zero out the output of that row token
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
return out, Intermediates()
def forward(
self,
q, k, v,
mask = None,
attn_bias = None,
prev_attn = None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)
causal = self.causal
# handle key padding mask
if exists(mask) and mask.ndim == 2:
mask = rearrange(mask, 'b j -> b 1 1 j')
# handle kv cached decoding
if n == 1 and causal:
causal = False
# handle grouped multi-query attention
if kv_heads == 1:
k, v = tuple(rearrange(t, 'b 1 n d -> b n d') for t in (k, v))
elif kv_heads < heads:
k, v = tuple(repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads) for t in (k, v))
# handle zero kv, as means for allowing network to attend to nothing
if self.add_zero_kv:
k, v = tuple(F.pad(t, (0, 0, 1, 0), value = 0.) for t in (k, v))
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
if exists(attn_bias):
attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
if self.flash:
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
if not self.l2_distance:
sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
else:
sim = -qk_l2_dist_squared(q, k)
sim = sim * scale
if exists(prev_attn):
sim = sim + prev_attn
qk_similarities = sim.clone()
if exists(self.pre_scale_post_talking_heads):
pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
if exists(self.pre_softmax_talking_heads):
sim = sim + self.pre_softmax_talking_heads(sim)
if exists(attn_bias):
sim = sim + attn_bias
if self.softclamp_logits:
sim = softclamp(sim, self.logit_softclamp_value)
i, j, dtype = *sim.shape[-2:], sim.dtype
mask_value = -torch.finfo(sim.dtype).max
if exists(mask):
sim = sim.masked_fill(~mask, mask_value)
if causal:
causal_mask = self.create_causal_mask(i, j, device = device)
sim = sim.masked_fill(causal_mask, mask_value)
row_is_entirely_masked = None
if exists(mask):
row_is_entirely_masked = ~mask.any(dim = -1)
if exists(self.cope):
sim = sim + self.cope(q, sim)
if self.selective:
sim = selective_attn(sim)
if self.head_learned_sink:
# add learned attention sink
attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
sim = cat((attn_sink, sim), dim = -1)
pre_softmax_attn = sim
attn = self.attn_fn(sim)
attn = attn.type(dtype)
post_softmax_attn = attn
if self.head_learned_sink:
# remove attention sink
attn = attn[..., 1:]
attn = self.attn_dropout(attn)
if exists(self.post_softmax_talking_heads):
attn = self.post_softmax_talking_heads(attn)
if exists(self.pre_scale_post_talking_heads):
attn = attn * pre_to_post_scale
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
intermediates = Intermediates(
qk_similarities = qk_similarities,
pre_softmax_attn = pre_softmax_attn,
post_softmax_attn = post_softmax_attn
)
if exists(row_is_entirely_masked) and row_is_entirely_masked.any():
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
return out, intermediates

581
Amadeus/custom_wrapper.py Normal file
View File

@ -0,0 +1,581 @@
from __future__ import annotations
from math import ceil, log
from typing import Tuple, Callable
import torch
from torch import nn, tensor, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat, pack, unpack
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t, *args, **kwargs):
return t
def join(arr, delimiter = ', '):
return delimiter.join(arr)
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else (t,) * length
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
# gumbel topk
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def gumbel_noise(t):
return -log(-log(torch.rand_like(t)))
# function for modifying all the cached key / values
def modify_cached_kv(cache, fn):
for inter in cache.attn_intermediates:
if inter.layer_type == 'a':
inter.cached_kv = [fn(t) for t in inter.cached_kv]
# for variable lengthed prefixes
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
if pad == (0, 0):
return t
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
def align_right(t, lens, pad_id = 0):
batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
assert lens.ndim == 1 and lens.shape[0] == batch
assert lens.amax() <= seq_len
pad_lens = seq_len - lens
max_pad_len = pad_lens.amax()
batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
offset = max_pad_len - pad_lens
aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
return aligned
# nucleus
def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending = True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
# topk
def top_k(logits, frac_num_tokens = 0.1, k = None):
num_tokens = logits.shape[-1]
k = default(k, ceil(frac_num_tokens * num_tokens))
k = min(k, num_tokens)
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# top_a
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
return torch.where(probs < limit, float('-inf'), logits)
# min_p
# https://arxiv.org/abs/2407.01082
def min_p(logits, min_p = 0.1):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = min_p * max_probs
return torch.where(probs < limit, float('-inf'), logits)
# filter logits functions dict[str -> Callable]
FILTER_LOGITS_FN = dict(
top_p = top_p,
top_k = top_k,
top_a = top_a,
min_p = min_p
)
# contrastive decoding function
def contrastive_decode_fn(
expert_logits,
amateur_logits,
alpha = 0.1,
beta = 0.5
):
"""
Appendix A Algorithm 2
https://arxiv.org/abs/2309.09117
"""
cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
diffs = (1 + beta) * expert_logits - beta * amateur_logits
contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
return contrastive_decode_logits
# autoregressive wrapper class
class AutoregressiveWrapper(Module):
def __init__(
self,
net,
ignore_index = -100,
pad_value = 0,
mask_prob = 0.,
add_attn_z_loss = False,
next_embed_loss_weight = 0.1
):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.max_seq_len
# paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
assert mask_prob < 1.
self.mask_prob = mask_prob
# whether to add router z-loss
self.add_attn_z_loss = add_attn_z_loss
# whether to add a continuous loss
self.add_continuous_pred_head = net.add_continuous_pred_head
self.next_embed_loss_weight = next_embed_loss_weight
@torch.no_grad()
@eval_decorator
def beam_search(
self,
prompts,
seq_len,
beams = 4,
return_beams_and_scores = False,
eos_token = None,
temperature = 1.,
stochastic = False,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = identity,
restrict_to_max_seq_len = True,
filter_kwargs: dict = dict(),
cache_kv = True,
**kwargs
):
assert not exists(eos_token), 'eos token not supported yet'
max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
prompts, packed_shape = pack([prompts], '* n')
batch, orig_seq_len = prompts.shape
# handle filter logits fn given as string
if isinstance(filter_logits_fn, str):
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
# handle variable lengthed prompts (prefixes)
seq_start_pos = None
if exists(prompt_lens):
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
seq_start_pos = orig_seq_len - prompt_lens
# output from which sampled tokens appended to
out = prompts
# kv caches
cache = None
should_cache = cache_kv and self.net.can_cache_kv
# scores for the beams
scores = torch.zeros((batch,), device = device)
batch_arange = torch.arange(batch, device = device)
# sampling up to seq_len
for i in range(seq_len):
is_first = i == 0
if restrict_to_max_seq_len:
max_len_exceeded = out.shape[-1] > max_seq_len
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
x = out[:, -max_seq_len:]
if exists(cache):
modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
logits, new_cache = self.net(
x,
return_intermediates = True,
cache = cache,
seq_start_pos = seq_start_pos,
**kwargs
)
if should_cache:
cache = new_cache
logits = logits[:, -1]
# to add to the scores
log_probs = logits.log_softmax(dim = -1)
# maybe filter by top_k, top_p (nucleus) for stochastic beam search
if stochastic and not greedy:
logits = filter_logits_fn(logits, **filter_kwargs)
logits = (logits / temperature) + gumbel_noise(logits)
# (gumbel) topk
samples = logits.topk(beams, dim = -1).indices
# get the scores for keeping track of beams
next_scores = log_probs.gather(-1, samples)
# expand beam times
scores = repeat(scores, 'b -> b beams', beams = beams)
scores = scores + next_scores
out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
samples = rearrange(samples, 'b beams -> (b beams) 1')
if should_cache and is_first:
modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
# concat sample
out = torch.cat((out, samples), dim=-1)
# sort by score and excise
# excise out the beams
scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
curr_num_beams = scores.shape[-1]
if curr_num_beams > beams:
scores, sort_indices = scores.sort(dim = -1, descending = True)
scores = scores[:, :beams]
top_beams_indices = sort_indices[:, :beams]
top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
out = out[flattened_beam_indices]
scores = rearrange(scores, 'b beams -> (b beams)')
if not exists(eos_token):
continue
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
break
if exists(eos_token):
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, self.pad_value)
# select out the top beam
out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
out = out[..., orig_seq_len:]
out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
if not return_beams_and_scores:
return out[..., 0, :]
scores = rearrange(scores, '(b beams) -> beams b', b = batch)
out = rearrange(out, 'b beams n -> beams b n')
return out, scores
@torch.no_grad()
@eval_decorator
def generate(
self,
prompts: list[Tensor] | Tensor,
seq_len,
eos_token = None,
temperature = 1.,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = top_k,
restrict_to_max_seq_len = True,
amateur_model: Module | Tuple[Module] | None = None,
filter_kwargs: dict = dict(),
contrastive_decode_kwargs: dict | Tuple[dict] = dict(
beta = 0.5,
alpha = 0.1
),
cache_kv = True,
**kwargs
):
max_seq_len, greedy = self.max_seq_len, temperature == 0.
# handle prompts given as list of variable lengthed token ids
if isinstance(prompts, list):
assert len(prompts) > 0, 'prompts cannot be empty list'
assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
prompts = pad_sequence(prompts, batch_first = True)
# pack maybe no batch
prompts, ps = pack([prompts], '* n')
b, t, device = *prompts.shape, prompts.device
# handle filter logits fn given as string
if isinstance(filter_logits_fn, str):
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
# handle variable lengthed prompts (prefixes)
seq_start_pos = None
if exists(prompt_lens):
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
seq_start_pos = t - prompt_lens
# output from which sampled tokens appended to
out = prompts
# kv caches
cache = None
# if doing contrastive decoding, turn off filter automatically
if exists(amateur_model):
amateur_model = cast_tuple(amateur_model)
contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
assert len(amateur_model) == len(contrastive_decode_kwargs)
amateur_caches = [None] * len(amateur_model)
filter_logits_fn = identity
for i, module in enumerate(amateur_model):
if isinstance(module, AutoregressiveWrapper):
amateur_model[i] = module.net
module.eval()
# sampling up to seq_len
for _ in range(seq_len):
if restrict_to_max_seq_len:
max_len_exceeded = out.shape[-1] > max_seq_len
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
x = out[:, -max_seq_len:]
if exists(cache):
for inter in cache.attn_intermediates:
if inter.layer_type == 'a':
inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
logits, new_cache = self.net(
x,
return_intermediates = True,
cache = cache,
seq_start_pos = seq_start_pos,
**kwargs
)
if cache_kv and self.net.can_cache_kv:
cache = new_cache
logits = logits[:, -1]
# handle contrastive decoding, Li et al.
# https://arxiv.org/abs/2210.15097
if exists(amateur_model):
for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
amateur_logits, next_amateur_cache = amateur(
x,
return_intermediates = True,
cache = amateur_cache,
seq_start_pos = seq_start_pos,
**kwargs
)
amateur_logits = amateur_logits[:, -1]
assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
if cache_kv and amateur.can_cache_kv:
amateur_caches[i] = next_amateur_cache
# filter by top_k, top_p (nucleus), top_a, or custom
if greedy:
sample = logits.argmax(dim = -1, keepdim = True)
else:
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
# concat sample
out = torch.cat((out, sample), dim=-1)
if not exists(eos_token):
continue
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
break
if exists(eos_token):
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, self.pad_value)
out = out[:, t:]
out, = unpack(out, ps, '* n')
return out
def forward(
self,
x,
return_outputs = False,
prepend_embeds = None,
**kwargs
):
seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
inp, target = x, x[:, 1:]
inp = torch.where(inp == ignore_index, self.pad_value, inp)
if self.mask_prob > 0.:
rand = torch.randn(inp.shape, device = x.device)
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
num_mask = min(int(seq * self.mask_prob), seq - 1)
indices = rand.topk(num_mask, dim = -1).indices
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
kwargs.update(self_attn_kv_mask = mask)
out, cache = self.net(
inp,
return_intermediates = True,
return_attn_z_loss = add_attn_z_loss,
return_next_embed_pred = add_next_embed_loss,
prepend_embeds = prepend_embeds,
**kwargs
)
# destruct differently if doing continuous pred
if add_next_embed_loss:
logits, (next_embed_pred, init_embeds) = out
else:
logits = out
# if there are prepended embeds, excise it out
if exists(prepend_embeds):
prepend_len = prepend_embeds.shape[1]
logits = logits[:, prepend_len:]
# take all tokens but the last
logits = logits[:, :-1]
# loss function
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
# cross entropy loss
loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
target,
ignore_index = ignore_index
)
if add_attn_z_loss:
loss = loss + cache.attn_z_loss
if add_next_embed_loss:
mask = target != ignore_index
embed_pred = next_embed_pred[:, :-1]
cont_targets = init_embeds[:, 1:].detach()
cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
cont_loss = cont_loss[mask].mean()
loss = loss + cont_loss * self.next_embed_loss_weight
if not return_outputs:
return loss
return loss, (logits, cache)

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
- nn_params: nb8_embSum_diff_t2m_150M_finetunning
# - nn_params: nb8_embSum_diff_t2m_150M_pretraining
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
- nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
from Amadeus.custom_x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
from transformers import T5EncoderModel
from data_representation.vocab_utils import LangTokenVocab