1013 update
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def top_p_sampling(logits, thres=0.9):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
|
||||
# temporarily apply the sampling method to logits
|
||||
logits = logits / temperature
|
||||
# logits = add_gumbel_noise(logits, temperature)
|
||||
|
||||
|
||||
if sampling_method == "top_p":
|
||||
modified_logits = top_p_sampling(logits, thres=threshold)
|
||||
elif sampling_method == "typical":
|
||||
|
||||
Reference in New Issue
Block a user