1021 add flexable attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
scores = logits.masked_fill(indices_to_remove, float("-inf"))
return scores
def min_p_sampling(logits, alpha=0.05):
"""
logits: Tensor of shape [B, L, V]
alpha: float, relative probability threshold (e.g., 0.05)
"""
# 计算 softmax 概率
probs = F.softmax(logits, dim=-1)
# 找到每个位置的最大概率
max_probs, _ = probs.max(dim=-1, keepdim=True) # [B, L, 1]
# 保留概率 >= alpha * max_prob 的 token
mask = probs < (alpha * max_probs) # True 表示要屏蔽
masked_logits = logits.masked_fill(mask, float('-inf'))
return masked_logits
def add_gumbel_noise(logits, temperature):
'''
The Gumbel max is a method for sampling categorical distributions.
@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
modified_logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
modified_logits = eta_sampling(logits, epsilon=threshold)
elif sampling_method == "min_p":
modified_logits = min_p_sampling(logits, alpha=threshold)
else:
modified_logits = logits # 其他情况直接使用原始logits