from __future__ import annotations

from math import ceil, log as math_log # aliased, since a tensor-level `log` helper is defined below
from typing import Tuple, Callable

import torch
from torch import nn, tensor, Tensor
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat, pack, unpack

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def identity(t, *args, **kwargs):
    return t

def join(arr, delimiter = ', '):
    return delimiter.join(arr)

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else (t,) * length

def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

# gumbel topk

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

# function for modifying all the cached key / values

def modify_cached_kv(cache, fn):
    for inter in cache.attn_intermediates:
        if inter.layer_type == 'a':
            inter.cached_kv = [fn(t) for t in inter.cached_kv]

# for variable lengthed prefixes

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    if pad == (0, 0):
        return t

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

def align_right(t, lens, pad_id = 0):
    batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    pad_lens = seq_len - lens
    max_pad_len = pad_lens.amax()

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
    offset = max_pad_len - pad_lens

    aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
    return aligned
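
# a small illustrative sketch (not part of the library) of what align_right does with
# variable length prompts - real tokens are pushed to the right so that generation can
# append directly after the last prompt token of every row:
#
#   t    = tensor([[5, 0, 0],
#                  [1, 2, 3]])
#   lens = tensor([1, 3])
#
#   align_right(t, lens, pad_id = 0)
#   # tensor([[0, 0, 5],
#   #         [1, 2, 3]])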
# nucleus

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, frac_num_tokens = 0.1, k = None):
    num_tokens = logits.shape[-1]

    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# top_a

def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
    probs = logits.softmax(dim = -1)
    max_probs = probs.amax(dim = -1, keepdim = True)
    limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
    return torch.where(probs < limit, float('-inf'), logits)

# min_p
# https://arxiv.org/abs/2407.01082

def min_p(logits, min_p = 0.1):
    probs = logits.softmax(dim = -1)
    max_probs = probs.amax(dim = -1, keepdim = True)
    limit = min_p * max_probs
    return torch.where(probs < limit, float('-inf'), logits)

# filter logits functions dict[str -> Callable]

FILTER_LOGITS_FN = dict(
    top_p = top_p,
    top_k = top_k,
    top_a = top_a,
    min_p = min_p
)

# contrastive decoding function

def contrastive_decode_fn(
    expert_logits,
    amateur_logits,
    alpha = 0.1,
    beta = 0.5
):
    """
    Appendix A Algorithm 2
    https://arxiv.org/abs/2309.09117
    """

    cutoff = math_log(alpha) + expert_logits.amax(dim = -1, keepdim = True) # math_log, as the tensor-level `log` above would fail on a python float
    diffs = (1 + beta) * expert_logits - beta * amateur_logits
    contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
    return contrastive_decode_logits
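
# an illustrative example (not part of the library) of how the filters above behave on a
# single-row batch with a 4-token vocabulary - the filters scatter along dim 1, so logits
# have shape (1, 4). with logits [[2.0, 1.0, 0.1, -1.0]] the softmax probabilities are
# roughly [0.64, 0.23, 0.10, 0.03]:
#
#   top_k(logits, k = 2)        - keeps only the two highest logits, -inf everywhere else
#   top_p(logits, thres = 0.9)  - keeps the smallest set of top tokens whose cumulative
#                                 probability exceeds 0.9 (here the first three), -inf elsewhere
#   min_p(logits, min_p = 0.1)  - keeps tokens whose probability is at least 0.1 * 0.64 ≈ 0.064,
#                                 so only the last token is set to -inf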
# autoregressive wrapper class

class AutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0,
        mask_prob = 0.,
        add_attn_z_loss = False,
        next_embed_loss_weight = 0.1
    ):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

        # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432

        assert mask_prob < 1.
        self.mask_prob = mask_prob

        # whether to add router z-loss

        self.add_attn_z_loss = add_attn_z_loss

        # whether to add a continuous loss

        self.add_continuous_pred_head = net.add_continuous_pred_head
        self.next_embed_loss_weight = next_embed_loss_weight
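
    # a hedged note on mask_prob: during training, forward() below masks out roughly
    # `mask_prob * seq_len` randomly chosen input positions (never the first token) through
    # `self_attn_kv_mask`, so e.g. AutoregressiveWrapper(net, mask_prob = 0.15) hides ~15%
    # of keys / values per sequence while the objective stays next-token prediction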
    @torch.no_grad()
    @eval_decorator
    def beam_search(
        self,
        prompts,
        seq_len,
        beams = 4,
        return_beams_and_scores = False,
        eos_token = None,
        temperature = 1.,
        stochastic = False,
        prompt_lens: Tensor | None = None,
        filter_logits_fn: str | Callable = identity,
        restrict_to_max_seq_len = True,
        filter_kwargs: dict = dict(),
        cache_kv = True,
        **kwargs
    ):
        assert not exists(eos_token), 'eos token not supported yet'

        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

        prompts, packed_shape = pack([prompts], '* n')

        batch, orig_seq_len = prompts.shape

        # handle filter logits fn given as string

        if isinstance(filter_logits_fn, str):
            assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

            filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

        # handle variable lengthed prompts (prefixes)

        seq_start_pos = None
        if exists(prompt_lens):
            prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
            seq_start_pos = orig_seq_len - prompt_lens

        # output from which sampled tokens appended to

        out = prompts

        # kv caches

        cache = None
        should_cache = cache_kv and self.net.can_cache_kv

        # scores for the beams

        scores = torch.zeros((batch,), device = device)

        batch_arange = torch.arange(batch, device = device)

        # sampling up to seq_len

        for i in range(seq_len):
            is_first = i == 0

            x = out # fallback when not restricting to the max sequence length

            if restrict_to_max_seq_len:
                max_len_exceeded = out.shape[-1] > max_seq_len

                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

                x = out[:, -max_seq_len:]

                if exists(cache):
                    modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])

            logits, new_cache = self.net(
                x,
                return_intermediates = True,
                cache = cache,
                seq_start_pos = seq_start_pos,
                **kwargs
            )

            if should_cache:
                cache = new_cache

            logits = logits[:, -1]

            # to add to the scores

            log_probs = logits.log_softmax(dim = -1)

            # maybe filter by top_k, top_p (nucleus) for stochastic beam search

            if stochastic and not greedy:
                logits = filter_logits_fn(logits, **filter_kwargs)
                logits = (logits / temperature) + gumbel_noise(logits)

            # (gumbel) topk

            samples = logits.topk(beams, dim = -1).indices

            # get the scores for keeping track of beams

            next_scores = log_probs.gather(-1, samples)

            # expand beam times

            scores = repeat(scores, 'b -> b beams', beams = beams)
            scores = scores + next_scores

            out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
            samples = rearrange(samples, 'b beams -> (b beams) 1')

            if should_cache and is_first:
                modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))

            # concat sample

            out = torch.cat((out, samples), dim = -1)

            # sort by score and excise out the beams

            scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)

            curr_num_beams = scores.shape[-1]

            if curr_num_beams > beams:
                scores, sort_indices = scores.sort(dim = -1, descending = True)
                scores = scores[:, :beams]

                top_beams_indices = sort_indices[:, :beams]
                top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices

                flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')

                out = out[flattened_beam_indices]

            scores = rearrange(scores, 'b beams -> (b beams)')

            if not exists(eos_token):
                continue

            is_eos_tokens = (out == eos_token)

            if is_eos_tokens.any(dim = -1).all():
                break

        if exists(eos_token):
            # mask out everything after the eos tokens
            shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
            mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
            out = out.masked_fill(mask, self.pad_value)

        # select out the top beam

        out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)

        out = out[..., orig_seq_len:]

        out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension

        if not return_beams_and_scores:
            return out[..., 0, :]

        scores = rearrange(scores, '(b beams) -> beams b', b = batch)
        out = rearrange(out, 'b beams n -> beams b n')

        return out, scores
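
    # illustrative calls (a hedged sketch - `wrapper` and `prompts` are stand-ins, assuming an
    # AutoregressiveWrapper around an x-transformers TransformerWrapper and a (batch, prompt_len)
    # Long tensor of token ids):
    #
    #   out = wrapper.beam_search(prompts, seq_len = 64, beams = 4)   # (batch, 64), top scoring beam per row
    #
    #   beams_out, scores = wrapper.beam_search(
    #       prompts, seq_len = 64, beams = 4,
    #       return_beams_and_scores = True)                           # (beams, batch, 64) and (beams, batch)
    #
    # decoding is deterministic beam search by default; stochastic = True with temperature > 0
    # applies filter_logits_fn and gumbel noise for stochastic beam search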
    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prompts: list[Tensor] | Tensor,
        seq_len,
        eos_token = None,
        temperature = 1.,
        prompt_lens: Tensor | None = None,
        filter_logits_fn: str | Callable = top_k,
        restrict_to_max_seq_len = True,
        amateur_model: Module | Tuple[Module] | None = None,
        filter_kwargs: dict = dict(),
        contrastive_decode_kwargs: dict | Tuple[dict] = dict(
            beta = 0.5,
            alpha = 0.1
        ),
        cache_kv = True,
        **kwargs
    ):
        max_seq_len, greedy = self.max_seq_len, temperature == 0.

        # handle prompts given as list of variable lengthed token ids

        if isinstance(prompts, list):
            assert len(prompts) > 0, 'prompts cannot be empty list'
            assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'

            prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
            prompts = pad_sequence(prompts, batch_first = True)

        # pack maybe no batch

        prompts, ps = pack([prompts], '* n')

        b, t, device = *prompts.shape, prompts.device

        # handle filter logits fn given as string

        if isinstance(filter_logits_fn, str):
            assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

            filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

        # handle variable lengthed prompts (prefixes)

        seq_start_pos = None
        if exists(prompt_lens):
            prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
            seq_start_pos = t - prompt_lens

        # output from which sampled tokens appended to

        out = prompts

        # kv caches

        cache = None

        # if doing contrastive decoding, turn off filter automatically

        if exists(amateur_model):
            amateur_model = list(cast_tuple(amateur_model)) # list, as entries may be swapped for their inner `net` below
            contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

            assert len(amateur_model) == len(contrastive_decode_kwargs)

            amateur_caches = [None] * len(amateur_model)
            filter_logits_fn = identity

            for i, module in enumerate(amateur_model):
                if isinstance(module, AutoregressiveWrapper):
                    amateur_model[i] = module.net

                module.eval()

        # sampling up to seq_len

        for _ in range(seq_len):

            x = out # fallback when not restricting to the max sequence length

            if restrict_to_max_seq_len:
                max_len_exceeded = out.shape[-1] > max_seq_len

                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

                x = out[:, -max_seq_len:]

                if exists(cache):
                    for inter in cache.attn_intermediates:
                        if inter.layer_type == 'a':
                            inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

            logits, new_cache = self.net(
                x,
                return_intermediates = True,
                cache = cache,
                seq_start_pos = seq_start_pos,
                **kwargs
            )

            if cache_kv and self.net.can_cache_kv:
                cache = new_cache

            logits = logits[:, -1]

            # handle contrastive decoding, Li et al.
            # https://arxiv.org/abs/2210.15097

            if exists(amateur_model):
                for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                    amateur_logits, next_amateur_cache = amateur(
                        x,
                        return_intermediates = True,
                        cache = amateur_cache,
                        seq_start_pos = seq_start_pos,
                        **kwargs
                    )

                    amateur_logits = amateur_logits[:, -1]

                    assert amateur_logits.shape == logits.shape, 'logits dimensions are not the same between amateur and expert model'

                    logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                    if cache_kv and amateur.can_cache_kv:
                        amateur_caches[i] = next_amateur_cache

            # filter by top_k, top_p (nucleus), top_a, or custom

            if greedy:
                sample = logits.argmax(dim = -1, keepdim = True)
            else:
                filtered_logits = filter_logits_fn(logits, **filter_kwargs)
                probs = F.softmax(filtered_logits / temperature, dim = -1)
                sample = torch.multinomial(probs, 1)

            # concat sample

            out = torch.cat((out, sample), dim = -1)

            if not exists(eos_token):
                continue

            is_eos_tokens = (out == eos_token)

            if is_eos_tokens.any(dim = -1).all():
                break

        if exists(eos_token):
            # mask out everything after the eos tokens
            shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
            mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
            out = out.masked_fill(mask, self.pad_value)

        out = out[:, t:]

        out, = unpack(out, ps, '* n')

        return out
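
    # illustrative calls (a hedged sketch - `wrapper`, `amateur` and `prompts` are stand-ins for an
    # AutoregressiveWrapper around an expert model, a smaller amateur model, and prompt token ids):
    #
    #   out = wrapper.generate(prompts, seq_len = 128, temperature = 0.8, filter_logits_fn = 'min_p')
    #
    #   # contrastive decoding (https://arxiv.org/abs/2210.15097) - expert and amateur logits are
    #   # combined as (1 + beta) * expert - beta * amateur below the expert's cutoff, and any
    #   # user-supplied filter_logits_fn is replaced with `identity`
    #   out = wrapper.generate(prompts, seq_len = 128, amateur_model = amateur.net)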
    def forward(
        self,
        x,
        return_outputs = False,
        prepend_embeds = None,
        **kwargs
    ):
        seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head

        inp, target = x, x[:, 1:]
        inp = torch.where(inp == ignore_index, self.pad_value, inp)

        if self.mask_prob > 0.:
            rand = torch.randn(inp.shape, device = x.device)
            rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
            num_mask = min(int(seq * self.mask_prob), seq - 1)
            indices = rand.topk(num_mask, dim = -1).indices
            mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
            kwargs.update(self_attn_kv_mask = mask)

        out, cache = self.net(
            inp,
            return_intermediates = True,
            return_attn_z_loss = add_attn_z_loss,
            return_next_embed_pred = add_next_embed_loss,
            prepend_embeds = prepend_embeds,
            **kwargs
        )

        # destructure differently if doing continuous pred

        if add_next_embed_loss:
            logits, (next_embed_pred, init_embeds) = out
        else:
            logits = out

        # if there are prepended embeds, excise them out

        if exists(prepend_embeds):
            prepend_len = prepend_embeds.shape[1]
            logits = logits[:, prepend_len:]

        # take all tokens but the last

        logits = logits[:, :-1]

        # loss function

        loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss

        # cross entropy loss

        loss = loss_fn(
            rearrange(logits, 'b n c -> b c n'),
            target,
            ignore_index = ignore_index
        )

        if add_attn_z_loss:
            loss = loss + cache.attn_z_loss

        if add_next_embed_loss:
            mask = target != ignore_index

            embed_pred = next_embed_pred[:, :-1]
            cont_targets = init_embeds[:, 1:].detach()

            cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
            cont_loss = cont_loss[mask].mean()

            loss = loss + cont_loss * self.next_embed_loss_weight

        if not return_outputs:
            return loss

        return loss, (logits, cache)
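
# minimal smoke test - an illustrative sketch, not part of the original module. it assumes a
# matching x-transformers version providing TransformerWrapper and Decoder; hyperparameters
# are arbitrary

if __name__ == '__main__':
    from x_transformers import TransformerWrapper, Decoder

    net = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
    )

    wrapper = AutoregressiveWrapper(net)

    # training step - the wrapper computes the shifted next-token cross entropy internally

    seq = torch.randint(0, 256, (2, 64))
    loss = wrapper(seq)
    loss.backward()

    # sampling from a (batch, prompt_len) prompt

    prompts = torch.randint(0, 256, (2, 4))
    sampled = wrapper.generate(prompts, seq_len = 16, temperature = 1., filter_logits_fn = 'top_p')
    print(sampled.shape) # (2, 16)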