0925: use custom x_transformers for easier development

Author: FelixChan
Date:   2025-09-25 16:04:22 +08:00
Parent: 6f03357342
Commit: d077e3210e
6 changed files with 4775 additions and 3 deletions

Amadeus/custom_wrapper.py (new file, 581 lines)

@@ -0,0 +1,581 @@
from __future__ import annotations
from math import ceil, log as math_log  # aliased so the tensor log() helper below does not shadow it
from typing import Tuple, Callable
import torch
from torch import nn, tensor, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange, repeat, pack, unpack
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t, *args, **kwargs):
return t
def join(arr, delimiter = ', '):
return delimiter.join(arr)
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else (t,) * length
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
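
# A minimal sketch (not part of the original file) of what `eval_decorator`
# guarantees: the wrapped method runs with the module in eval mode, and the
# previous training state is restored afterwards. `_DemoNet` is illustrative.
def _demo_eval_decorator():
    class _DemoNet(nn.Module):
        @eval_decorator
        def sample(self):
            return self.training  # False while inside the decorated call

    net = _DemoNet()
    net.train()
    assert net.sample() is False  # ran in eval mode
    assert net.training           # training state restored afterwards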
# gumbel topk
def log(t, eps = 1e-20):
    # clamped log for tensors; logs of python floats use math_log imported above
    return t.clamp(min = eps).log()
def gumbel_noise(t):
return -log(-log(torch.rand_like(t)))
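
# A minimal sketch (not part of the original file): adding gumbel noise to logits
# and taking argmax draws an exact sample from softmax(logits) (the gumbel-max
# trick); topk over the noised logits extends this to sampling without
# replacement, which is what the stochastic beam search below relies on.
def _demo_gumbel_sampling():
    logits = torch.randn(2, 10)
    one_each = (logits + gumbel_noise(logits)).argmax(dim = -1)             # one sample per row
    three_each = (logits + gumbel_noise(logits)).topk(3, dim = -1).indices  # three distinct tokens per row
    return one_each, three_each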
# function for modifying all the cached key / values
def modify_cached_kv(cache, fn):
for inter in cache.attn_intermediates:
if inter.layer_type == 'a':
inter.cached_kv = [fn(t) for t in inter.cached_kv]
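# (used below both to slide cached keys / values along when the context window
# is exceeded, and to expand the cache across beams on the first beam search step)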
# for variable-length prefixes
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
if pad == (0, 0):
return t
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
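
# A minimal sketch (not part of the original file): `pad_at_dim` pads a single
# chosen dimension, e.g. left-padding the sequence axis of a (batch, seq) tensor.
def _demo_pad_at_dim():
    t = torch.ones(2, 3)
    return pad_at_dim(t, (2, 0), dim = 1, value = 0.)  # shape (2, 5), two zeros on the left of each row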
def align_right(t, lens, pad_id = 0):
batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
assert lens.ndim == 1 and lens.shape[0] == batch
assert lens.amax() <= seq_len
pad_lens = seq_len - lens
max_pad_len = pad_lens.amax()
batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
offset = max_pad_len - pad_lens
aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
return aligned
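
# A minimal sketch (not part of the original file) of how `align_right` prepares
# variable-length prompts: shorter prompts are left-padded so every prompt ends
# at the same position and decoding can append tokens to all rows in lockstep.
def _demo_align_right():
    prompts = tensor([[5, 6, 7, 0, 0],
                      [1, 2, 3, 4, 5]])
    lens = tensor([3, 5])
    aligned = align_right(prompts, lens, pad_id = 0)
    # aligned[0] is [0, 0, 5, 6, 7] - the short prompt pushed to the right
    # aligned[1] is [1, 2, 3, 4, 5] - the full-length prompt is unchanged
    return aligned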
# nucleus
def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending = True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
# topk
def top_k(logits, frac_num_tokens = 0.1, k = None):
num_tokens = logits.shape[-1]
k = default(k, ceil(frac_num_tokens * num_tokens))
k = min(k, num_tokens)
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
# top_a
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
return torch.where(probs < limit, float('-inf'), logits)
# min_p
# https://arxiv.org/abs/2407.01082
def min_p(logits, min_p = 0.1):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = min_p * max_probs
return torch.where(probs < limit, float('-inf'), logits)
# filter logits functions dict[str -> Callable]
FILTER_LOGITS_FN = dict(
top_p = top_p,
top_k = top_k,
top_a = top_a,
min_p = min_p
)
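
# A minimal sketch (not part of the original file) of the contract shared by the
# filter functions above: take raw logits of shape (batch, vocab), return the
# same shape with disallowed tokens at -inf, ready for temperature and sampling.
def _demo_filter_and_sample(logits, temperature = 1.):
    filtered = FILTER_LOGITS_FN['top_k'](logits, k = 50)  # or 'top_p', 'top_a', 'min_p'
    probs = F.softmax(filtered / temperature, dim = -1)
    return torch.multinomial(probs, 1)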
# contrastive decoding function
def contrastive_decode_fn(
expert_logits,
amateur_logits,
alpha = 0.1,
beta = 0.5
):
"""
Appendix A Algorithm 2
https://arxiv.org/abs/2309.09117
"""
    cutoff = math_log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
diffs = (1 + beta) * expert_logits - beta * amateur_logits
contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
return contrastive_decode_logits
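
# A minimal sketch (not part of the original file): the expert keeps only tokens
# within an alpha-probability band of its best token, and within that band the
# amateur's preferences are subtracted out, amplifying what the expert knows
# that the amateur does not.
def _demo_contrastive_decode():
    expert_logits = torch.randn(1, 10)
    amateur_logits = torch.randn(1, 10)
    out = contrastive_decode_fn(expert_logits, amateur_logits, alpha = 0.1, beta = 0.5)
    return out.argmax(dim = -1)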
# autoregressive wrapper class
class AutoregressiveWrapper(Module):
def __init__(
self,
net,
ignore_index = -100,
pad_value = 0,
mask_prob = 0.,
add_attn_z_loss = False,
next_embed_loss_weight = 0.1
):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.max_seq_len
# paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
assert mask_prob < 1.
self.mask_prob = mask_prob
# whether to add router z-loss
self.add_attn_z_loss = add_attn_z_loss
        # whether to add a continuous next-embedding prediction loss
self.add_continuous_pred_head = net.add_continuous_pred_head
self.next_embed_loss_weight = next_embed_loss_weight
@torch.no_grad()
@eval_decorator
def beam_search(
self,
prompts,
seq_len,
beams = 4,
return_beams_and_scores = False,
eos_token = None,
temperature = 1.,
stochastic = False,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = identity,
restrict_to_max_seq_len = True,
filter_kwargs: dict = dict(),
cache_kv = True,
**kwargs
):
assert not exists(eos_token), 'eos token not supported yet'
max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
prompts, packed_shape = pack([prompts], '* n')
batch, orig_seq_len = prompts.shape
# handle filter logits fn given as string
if isinstance(filter_logits_fn, str):
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
        # handle variable-length prompts (prefixes)
seq_start_pos = None
if exists(prompt_lens):
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
seq_start_pos = orig_seq_len - prompt_lens
        # output to which sampled tokens are appended
out = prompts
# kv caches
cache = None
should_cache = cache_kv and self.net.can_cache_kv
# scores for the beams
scores = torch.zeros((batch,), device = device)
batch_arange = torch.arange(batch, device = device)
# sampling up to seq_len
for i in range(seq_len):
is_first = i == 0
if restrict_to_max_seq_len:
max_len_exceeded = out.shape[-1] > max_seq_len
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
x = out[:, -max_seq_len:]
if exists(cache):
modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
logits, new_cache = self.net(
x,
return_intermediates = True,
cache = cache,
seq_start_pos = seq_start_pos,
**kwargs
)
if should_cache:
cache = new_cache
logits = logits[:, -1]
# to add to the scores
log_probs = logits.log_softmax(dim = -1)
# maybe filter by top_k, top_p (nucleus) for stochastic beam search
if stochastic and not greedy:
logits = filter_logits_fn(logits, **filter_kwargs)
logits = (logits / temperature) + gumbel_noise(logits)
# (gumbel) topk
samples = logits.topk(beams, dim = -1).indices
# get the scores for keeping track of beams
next_scores = log_probs.gather(-1, samples)
# expand beam times
scores = repeat(scores, 'b -> b beams', beams = beams)
scores = scores + next_scores
out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
samples = rearrange(samples, 'b beams -> (b beams) 1')
if should_cache and is_first:
modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
# concat sample
out = torch.cat((out, samples), dim=-1)
            # sort by score and excise out all but the top beams
scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
curr_num_beams = scores.shape[-1]
if curr_num_beams > beams:
scores, sort_indices = scores.sort(dim = -1, descending = True)
scores = scores[:, :beams]
top_beams_indices = sort_indices[:, :beams]
top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
out = out[flattened_beam_indices]
scores = rearrange(scores, 'b beams -> (b beams)')
if not exists(eos_token):
continue
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
break
if exists(eos_token):
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, self.pad_value)
# select out the top beam
out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
out = out[..., orig_seq_len:]
out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension
if not return_beams_and_scores:
return out[..., 0, :]
scores = rearrange(scores, '(b beams) -> beams b', b = batch)
out = rearrange(out, 'b beams n -> beams b n')
return out, scores
@torch.no_grad()
@eval_decorator
def generate(
self,
prompts: list[Tensor] | Tensor,
seq_len,
eos_token = None,
temperature = 1.,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = top_k,
restrict_to_max_seq_len = True,
amateur_model: Module | Tuple[Module] | None = None,
filter_kwargs: dict = dict(),
contrastive_decode_kwargs: dict | Tuple[dict] = dict(
beta = 0.5,
alpha = 0.1
),
cache_kv = True,
**kwargs
):
max_seq_len, greedy = self.max_seq_len, temperature == 0.
        # handle prompts given as a list of variable-length token ids
if isinstance(prompts, list):
assert len(prompts) > 0, 'prompts cannot be empty list'
assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
prompts = pad_sequence(prompts, batch_first = True)
        # pack (prompts may lack a batch dimension)
prompts, ps = pack([prompts], '* n')
b, t, device = *prompts.shape, prompts.device
# handle filter logits fn given as string
if isinstance(filter_logits_fn, str):
assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
        # handle variable-length prompts (prefixes)
seq_start_pos = None
if exists(prompt_lens):
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
seq_start_pos = t - prompt_lens
        # output to which sampled tokens are appended
out = prompts
# kv caches
cache = None
# if doing contrastive decoding, turn off filter automatically
if exists(amateur_model):
            amateur_model = list(cast_tuple(amateur_model))  # a list, so entries can be swapped for their inner nets below
contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
assert len(amateur_model) == len(contrastive_decode_kwargs)
amateur_caches = [None] * len(amateur_model)
filter_logits_fn = identity
for i, module in enumerate(amateur_model):
if isinstance(module, AutoregressiveWrapper):
amateur_model[i] = module.net
module.eval()
# sampling up to seq_len
for _ in range(seq_len):
if restrict_to_max_seq_len:
max_len_exceeded = out.shape[-1] > max_seq_len
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
x = out[:, -max_seq_len:]
                if exists(cache):
                    modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
logits, new_cache = self.net(
x,
return_intermediates = True,
cache = cache,
seq_start_pos = seq_start_pos,
**kwargs
)
if cache_kv and self.net.can_cache_kv:
cache = new_cache
logits = logits[:, -1]
# handle contrastive decoding, Li et al.
# https://arxiv.org/abs/2210.15097
if exists(amateur_model):
for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
amateur_logits, next_amateur_cache = amateur(
x,
return_intermediates = True,
cache = amateur_cache,
seq_start_pos = seq_start_pos,
**kwargs
)
amateur_logits = amateur_logits[:, -1]
                    assert amateur_logits.shape == logits.shape, 'logits dimensions do not match between amateur and expert model'
logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
if cache_kv and amateur.can_cache_kv:
amateur_caches[i] = next_amateur_cache
# filter by top_k, top_p (nucleus), top_a, or custom
if greedy:
sample = logits.argmax(dim = -1, keepdim = True)
else:
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
# concat sample
out = torch.cat((out, sample), dim=-1)
if not exists(eos_token):
continue
is_eos_tokens = (out == eos_token)
if is_eos_tokens.any(dim = -1).all():
break
if exists(eos_token):
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
out = out.masked_fill(mask, self.pad_value)
out = out[:, t:]
out, = unpack(out, ps, '* n')
return out
def forward(
self,
x,
return_outputs = False,
prepend_embeds = None,
**kwargs
):
seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
inp, target = x, x[:, 1:]
inp = torch.where(inp == ignore_index, self.pad_value, inp)
if self.mask_prob > 0.:
rand = torch.randn(inp.shape, device = x.device)
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
num_mask = min(int(seq * self.mask_prob), seq - 1)
indices = rand.topk(num_mask, dim = -1).indices
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
kwargs.update(self_attn_kv_mask = mask)
out, cache = self.net(
inp,
return_intermediates = True,
return_attn_z_loss = add_attn_z_loss,
return_next_embed_pred = add_next_embed_loss,
prepend_embeds = prepend_embeds,
**kwargs
)
        # destructure differently if predicting continuous embeddings
if add_next_embed_loss:
logits, (next_embed_pred, init_embeds) = out
else:
logits = out
# if there are prepended embeds, excise it out
if exists(prepend_embeds):
prepend_len = prepend_embeds.shape[1]
logits = logits[:, prepend_len:]
# take all tokens but the last
logits = logits[:, :-1]
# loss function
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
# cross entropy loss
loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
target,
ignore_index = ignore_index
)
if add_attn_z_loss:
loss = loss + cache.attn_z_loss
if add_next_embed_loss:
mask = target != ignore_index
embed_pred = next_embed_pred[:, :-1]
cont_targets = init_embeds[:, 1:].detach()
cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
cont_loss = cont_loss[mask].mean()
loss = loss + cont_loss * self.next_embed_loss_weight
if not return_outputs:
return loss
return loss, (logits, cache)
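
# A minimal usage sketch (not part of the original file). It assumes the
# customized `TransformerWrapper` / `Decoder` shipped alongside this commit
# (the stock x_transformers classes lack `add_continuous_pred_head`); the
# import path and all hyperparameters below are illustrative only.
if __name__ == '__main__':
    from x_transformers import TransformerWrapper, Decoder  # hypothetical import path

    model = AutoregressiveWrapper(
        TransformerWrapper(
            num_tokens = 256,
            max_seq_len = 1024,
            attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
        ),
        pad_value = 0
    )

    # training step: next-token cross entropy over a batch of token ids
    seq = torch.randint(0, 256, (2, 1024))
    loss = model(seq)
    loss.backward()

    # sampling: nucleus-filtered ancestral sampling, then deterministic beam search
    prompts = torch.randint(0, 256, (2, 16))
    sampled = model.generate(prompts, seq_len = 64, temperature = 0.8, filter_logits_fn = 'top_p')
    beamed = model.beam_search(prompts, seq_len = 64, beams = 4)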