1013 update

This commit is contained in:
FelixChan
2025-10-13 17:56:36 +08:00
parent d077e3210e
commit d6b68ef90b
17 changed files with 815 additions and 70 deletions

View File

@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
# temporarily apply the sampling method to logits
logits = logits / temperature
# logits = add_gumbel_noise(logits, temperature)
if sampling_method == "top_p":
modified_logits = top_p_sampling(logits, thres=threshold)
elif sampling_method == "typical":