1021 add flexable attr control
This commit is contained in:
@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
|
||||
scores = logits.masked_fill(indices_to_remove, float("-inf"))
|
||||
return scores
|
||||
|
||||
def min_p_sampling(logits, alpha=0.05):
|
||||
"""
|
||||
logits: Tensor of shape [B, L, V]
|
||||
alpha: float, relative probability threshold (e.g., 0.05)
|
||||
"""
|
||||
# 计算 softmax 概率
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
# 找到每个位置的最大概率
|
||||
max_probs, _ = probs.max(dim=-1, keepdim=True) # [B, L, 1]
|
||||
|
||||
# 保留概率 >= alpha * max_prob 的 token
|
||||
mask = probs < (alpha * max_probs) # True 表示要屏蔽
|
||||
masked_logits = logits.masked_fill(mask, float('-inf'))
|
||||
return masked_logits
|
||||
|
||||
def add_gumbel_noise(logits, temperature):
|
||||
'''
|
||||
The Gumbel max is a method for sampling categorical distributions.
|
||||
@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
|
||||
modified_logits = typical_sampling(logits, thres=threshold)
|
||||
elif sampling_method == "eta":
|
||||
modified_logits = eta_sampling(logits, epsilon=threshold)
|
||||
elif sampling_method == "min_p":
|
||||
modified_logits = min_p_sampling(logits, alpha=threshold)
|
||||
else:
|
||||
modified_logits = logits # 其他情况直接使用原始logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user