# MIDIFoundationModel/Amadeus/sampling_utils.py
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask tokens whose cumulative probability exceeds the threshold, shifted
    # right by one so the first token crossing the boundary is still kept.
    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Scatter the mask back to the original (unsorted) vocabulary order.
    # Plain advanced indexing (logits[..., sorted_indices[mask]]) would mix
    # indices across batch/sequence positions for multi-dimensional inputs.
    indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
        -1, sorted_indices, sorted_indices_to_remove)
    new_logits = logits.masked_fill(indices_to_remove, float('-inf'))
    return new_logits
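# A minimal usage sketch (not part of the original module): nucleus-filter a
# random batch of logits and confirm every position keeps at least one token.
def _demo_top_p_sampling():
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 10)  # assumed [batch, seq_len, vocab] layout
    masked = top_p_sampling(logits, thres=0.9)
    assert torch.isfinite(masked).any(dim=-1).all()
    return masked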
# reference: https://github.com/cimeister/typical-sampling
def typical_sampling(logits, thres=0.99):
    # entropy of the predictive distribution at each position
    normalized = torch.nn.functional.log_softmax(logits, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)
    # rank tokens by how far their surprisal is from the entropy
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = logits.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens once the cumulative mass exceeds the threshold
    last_ind = (cumulative_probs < thres).sum(dim=-1)
    last_ind.clamp_(min=0, max=sorted_scores.shape[-1] - 1)  # keep the gather in range
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(-1, last_ind.unsqueeze(-1))
    # if self.min_tokens_to_keep > 1:
    #     # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
    #     sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    scores = logits.masked_fill(indices_to_remove, float("-inf"))
    return scores
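# Hedged sketch (assumption: [B, L, V] logits): tokens whose surprisal is far
# from the distribution's entropy are dropped, keeping the "typical set".
def _demo_typical_sampling():
    logits = torch.randn(1, 2, 8)
    scores = typical_sampling(logits, thres=0.95)
    return scores.softmax(dim=-1)  # renormalized over the surviving tokens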
def min_p_sampling(logits, alpha=0.05):
    """
    logits: Tensor of shape [B, L, V]
    alpha: float, relative probability threshold (e.g., 0.05)
    """
    # softmax probabilities
    probs = F.softmax(logits, dim=-1)
    # maximum probability at each position
    max_probs, _ = probs.max(dim=-1, keepdim=True)  # [B, L, 1]
    # keep tokens with probability >= alpha * max_prob
    mask = probs < (alpha * max_probs)  # True marks tokens to mask out
    masked_logits = logits.masked_fill(mask, float('-inf'))
    return masked_logits
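# Minimal sketch (not in the original file): with alpha=0.5, any token whose
# probability is under half of the top token's probability is masked.
def _demo_min_p_sampling():
    logits = torch.log(torch.tensor([[[0.5, 0.3, 0.15, 0.05]]]))
    masked = min_p_sampling(logits, alpha=0.5)
    return masked  # the 0.15 and 0.05 entries become -inf (cutoff is 0.25)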
def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick is a method for sampling from categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel-max improves the
    perplexity score but reduces generation quality. Thus, we use float64.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    # The argmax of exp(logits) / (-log u)^T equals the argmax of
    # logits + T * g with Gumbel noise g = -log(-log u), i.e. a sample at
    # temperature T; callers should take the argmax of the returned values.
    return logits.exp() / gumbel_noise
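# Hedged sketch: sampling with the Gumbel-noised scores reduces to an argmax.
def _demo_gumbel_sample(logits, temperature=1.0):
    scores = add_gumbel_noise(logits, temperature)
    return scores.argmax(dim=-1)  # one sampled token id per position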
# reference: https://github.com/john-hewitt/truncation-sampling
def eta_sampling(logits, epsilon) -> torch.FloatTensor:
    probabilities = logits.softmax(dim=-1)
    entropy = torch.distributions.Categorical(probs=probabilities).entropy()
    # entropy-adapted cutoff: eta = min(epsilon, sqrt(epsilon) * exp(-entropy)),
    # computed per position so batched inputs work
    eps = torch.tensor(epsilon, dtype=probabilities.dtype, device=logits.device)
    new_epsilon = torch.minimum(eps, torch.sqrt(eps) * torch.exp(-entropy))
    indices_to_remove = probabilities < new_epsilon.unsqueeze(-1)
    # always keep the argmax token so at least one candidate survives
    max_word = torch.argmax(logits, dim=-1, keepdim=True)
    indices_to_remove.scatter_(-1, max_word, False)
    new_scores = logits.masked_fill(indices_to_remove, float("-inf"))
    return new_scores
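# Sketch: entropy-adaptive truncation; a flatter (higher-entropy) distribution
# gets a smaller cutoff, so more tokens survive.
def _demo_eta_sampling():
    logits = torch.randn(1, 2, 16)
    scores = eta_sampling(logits, epsilon=3e-4)
    return torch.isfinite(scores).sum(dim=-1)  # surviving tokens per position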
def sample(logits, sampling_method, threshold, temperature):
    """Sample from the logits with a specific sampling strategy."""
    if sampling_method == "top_p":
        probs = F.softmax(top_p_sampling(logits, thres=threshold) / temperature, dim=-1)
    elif sampling_method == "typical":
        probs = F.softmax(typical_sampling(logits, thres=threshold) / temperature, dim=-1)
    elif sampling_method == "eta":
        probs = F.softmax(eta_sampling(logits, epsilon=threshold) / temperature, dim=-1)
    else:
        probs = F.softmax(logits / temperature, dim=-1)
    # draw a single token from the last position of the last sequence
    return torch.multinomial(probs[-1, -1, :], 1)
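# Usage sketch (hypothetical shapes): draw one token id from [B, L, V] logits.
def _demo_sample():
    logits = torch.randn(1, 4, 12)
    token = sample(logits, sampling_method="top_p", threshold=0.9, temperature=0.8)
    return token  # tensor of shape [1]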
def sample_with_prob(logits, sampling_method, threshold, temperature):
    """Sample from the logits with a specific sampling strategy and return the token and its probability."""
    # apply temperature scaling before the truncation method
    logits = logits / temperature
    # logits = add_gumbel_noise(logits, temperature)
    if sampling_method == "top_p":
        modified_logits = top_p_sampling(logits, thres=threshold)
    elif sampling_method == "typical":
        modified_logits = typical_sampling(logits, thres=threshold)
    elif sampling_method == "eta":
        modified_logits = eta_sampling(logits, epsilon=threshold)
    elif sampling_method == "min_p":
        modified_logits = min_p_sampling(logits, alpha=threshold)
    else:
        modified_logits = logits  # otherwise use the raw logits
    # probabilities over the (possibly truncated) vocabulary
    probs = F.softmax(modified_logits, dim=-1)
    # distribution at the last position of the last sequence
    probs_last = probs[-1, -1, :]
    # sample one token and look up its probability
    sampled_token = torch.multinomial(probs_last, num_samples=1)
    prob_value = probs_last[sampled_token]
    return sampled_token, prob_value.squeeze()
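# Sketch (assumed shapes): returns the sampled id and its model probability.
def _demo_sample_with_prob():
    logits = torch.randn(1, 4, 12)
    token, prob = sample_with_prob(logits, "min_p", threshold=0.1, temperature=1.0)
    return token, prob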
def top_p_sampling_fast(logits, thres=0.9):
    """
    logits: Tensor of shape [B, L, V]
    Returns: logits with low-prob tokens masked as -inf, shape [B, L, V]
    """
    # Step 1: sort logits and get indices
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)  # [B, L, V]
    # Step 2: compute cumulative probs
    probs = F.softmax(sorted_logits, dim=-1)  # [B, L, V]
    cum_probs = torch.cumsum(probs, dim=-1)  # [B, L, V]
    # Step 3: mask tokens beyond cumulative threshold
    sorted_mask = cum_probs > thres
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False  # always keep at least one token
    # Step 4: scatter back to original order
    # Create mask of same shape as logits, default False
    mask = torch.zeros_like(logits, dtype=torch.bool)  # [B, L, V]
    mask = mask.scatter(-1, sorted_indices, sorted_mask)
    # Step 5: mask logits
    logits = logits.masked_fill(mask, float('-inf'))  # final masked logits
    return logits
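# Consistency sketch: the batched version should agree with top_p_sampling,
# since both scatter the mask back to the unsorted vocabulary order.
def _demo_top_p_fast_matches():
    logits = torch.randn(2, 3, 10)
    a = top_p_sampling(logits, thres=0.9)
    b = top_p_sampling_fast(logits, thres=0.9)
    return torch.equal(a, b)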
def sample_with_prob_fast(logits, sampling_method="top_p", threshold=0.9, temperature=1.0, mask_indices=None):
    """
    logits: [B*T, num_sub_tokens, vocab_size]
    mask_indices: mask indicating which tokens to sample, shape = [B*T, num_sub_tokens] (currently unused)
    """
    if temperature != 1.0:
        logits = logits / temperature
    if sampling_method == "top_p":
        logits = top_p_sampling_fast(logits, thres=threshold)  # batched implementation
    elif sampling_method == "typical":
        logits = typical_sampling(logits, thres=threshold)
    elif sampling_method == "eta":
        logits = eta_sampling(logits, epsilon=threshold)
    # else: keep logits as-is
    probs = torch.softmax(logits, dim=-1)  # [B*T, num_sub_tokens, vocab_size]
    B, L, V = probs.shape
    probs_flat = probs.view(-1, V)  # [(B*T * num_sub_tokens), V]
    # multinomial cannot sample from a 3D tensor in one call, so flatten first
    sampled = torch.multinomial(probs_flat, num_samples=1)  # [(B*T * num_sub_tokens), 1]
    sampled = sampled.view(B, L)  # [B*T, num_sub_tokens]
    sampled_probs = torch.gather(probs, 2, sampled.unsqueeze(-1)).squeeze(-1)  # [B*T, num_sub_tokens]
    return sampled, sampled_probs
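# Hedged smoke test (not in the original file): exercises the batched path
# end to end on random logits.
if __name__ == "__main__":
    torch.manual_seed(0)
    demo_logits = torch.randn(4, 8, 32)  # assumed [B*T, num_sub_tokens, vocab_size]
    tokens, token_probs = sample_with_prob_fast(demo_logits, "top_p", 0.9, 1.0)
    print(tokens.shape, token_probs.shape)  # torch.Size([4, 8]) torch.Size([4, 8])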