1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@ -228,19 +228,39 @@ class DiffusionLoss4CompoundToken():
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
def get_aux_ar_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
aux_ar_logits = None
# print(len(logits_dict))
if len(logits_dict) == 2: # has aux ar loss
logits_dict, aux_ar_logits = logits_dict
if input_dict is not None:
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
feat = hidden_vec.mean(dim=1) #bs,dim
disp_loss = dispersive_loss(feat, tau=tau) # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
if aux_ar_logits is not None:
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
train_loss_list.append(training_loss)
if valid:
if key == 'type' or key == 'timesig':