first commit
This commit is contained in:
56
Amadeus/catsample.py
Normal file
56
Amadeus/catsample.py
Normal file
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
|
||||
logits = categorical_probs.clamp(min=1e-9).log()
|
||||
return F.gumbel_softmax(logits, hard=hard)
|
||||
|
||||
|
||||
def sample_categorical(categorical_probs, method="hard"):
|
||||
if method == "hard":
|
||||
gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
|
||||
return (categorical_probs / gumbel_norm).argmax(dim=-1)
|
||||
else:
|
||||
raise ValueError(f"Method {method} for sampling categorical variables is not valid.")
|
||||
|
||||
def direct_sampling(logits):
|
||||
probs = logits.softmax(dim=-1)
|
||||
index = sample_categorical(probs.to(torch.float32))
|
||||
return index
|
||||
|
||||
|
||||
def top_p_sampling(logits, p=0.9):
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
||||
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
probs.masked_fill_(indices_to_remove, 0)
|
||||
probs /= probs.sum(dim=-1).unsqueeze(-1)
|
||||
index = sample_categorical(probs.to(torch.float32))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def top_k_sampling(logits, k=400):
|
||||
top_k_values, top_k_indices = torch.topk(logits, int(k))
|
||||
top_k_probs = top_k_values.softmax(dim=-1)
|
||||
index = sample_categorical(top_k_probs.to(torch.float32))
|
||||
index = top_k_indices[torch.arange(index.size(0)), index]
|
||||
|
||||
return index
|
||||
|
||||
def sample_with_strategy(update_logits, strategy, para = None):
|
||||
if strategy == "direct":
|
||||
return direct_sampling(update_logits)
|
||||
elif strategy == "top_p":
|
||||
return top_p_sampling(update_logits, para)
|
||||
elif strategy == "top_k":
|
||||
return top_k_sampling(update_logits, para)
|
||||
else:
|
||||
raise ValueError(f"Strategy {strategy} is not valid.")
|
||||
Reference in New Issue
Block a user