first commit

This commit is contained in:
2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

56
Amadeus/catsample.py Normal file
View File

@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F
def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
logits = categorical_probs.clamp(min=1e-9).log()
return F.gumbel_softmax(logits, hard=hard)
def sample_categorical(categorical_probs, method="hard"):
if method == "hard":
gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
return (categorical_probs / gumbel_norm).argmax(dim=-1)
else:
raise ValueError(f"Method {method} for sampling categorical variables is not valid.")
def direct_sampling(logits):
probs = logits.softmax(dim=-1)
index = sample_categorical(probs.to(torch.float32))
return index
def top_p_sampling(logits, p=0.9):
probs = logits.softmax(dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
probs.masked_fill_(indices_to_remove, 0)
probs /= probs.sum(dim=-1).unsqueeze(-1)
index = sample_categorical(probs.to(torch.float32))
return index
def top_k_sampling(logits, k=400):
top_k_values, top_k_indices = torch.topk(logits, int(k))
top_k_probs = top_k_values.softmax(dim=-1)
index = sample_categorical(top_k_probs.to(torch.float32))
index = top_k_indices[torch.arange(index.size(0)), index]
return index
def sample_with_strategy(update_logits, strategy, para = None):
if strategy == "direct":
return direct_sampling(update_logits)
elif strategy == "top_p":
return top_p_sampling(update_logits, para)
elif strategy == "top_k":
return top_k_sampling(update_logits, para)
else:
raise ValueError(f"Strategy {strategy} is not valid.")