1127 update to latest
This commit is contained in:
@ -228,19 +228,39 @@ class DiffusionLoss4CompoundToken():
|
||||
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def get_aux_ar_nll_loss(self, logits, target, mask):
    """Return the masked mean of a focal-weighted negative log-likelihood.

    For every position the probability ``pt`` assigned to the target class
    is looked up and the per-position loss is
    ``-alpha * (1 - pt) ** gamma * log(pt)``, using ``self.alpha`` and
    ``self.gamma``. Positions are then weighted by ``mask`` and averaged
    over ``mask.sum()``.

    Args:
        logits: Unnormalised scores, either ``[batch, seq_len, vocab]`` or
            already flattened to ``[batch*seq_len, vocab]``.
        target: Target class indices, either ``[batch, seq_len]`` or
            already flattened to ``[batch*seq_len]``.
        mask: Per-position weights laid out like ``target``.

    Returns:
        A scalar tensor. When ``mask`` sums to zero the result is ``0``
        instead of ``NaN``.
    """
    probs = logits.softmax(dim=-1)
    if probs.ndim == 3:
        probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
    if target.ndim == 2:
        target = target.flatten(0, 1)  # [batch_size*seq_len]
    # Flatten the mask only when it is batched, mirroring the handling of
    # logits/target above, so already-flat inputs do not raise.
    if mask.ndim > 1:
        mask = mask.flatten(0, 1)  # [batch_size*seq_len]

    # Build the row index on the same device as probs to avoid a host-side
    # index tensor when running on an accelerator.
    rows = torch.arange(target.shape[0], device=probs.device)
    # Clamp away from 0 and 1 so log(pt) stays finite.
    pt = probs[rows, target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]

    loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)
    loss = loss * mask  # zero out positions excluded by the mask

    # Mean over the masked positions. If the mask is empty, bump the
    # denominator from 0 to 1 so the result (and its gradient) is 0 rather
    # than NaN; a non-zero denominator is left untouched.
    denom = mask.sum()
    denom = denom + (denom == 0)
    return loss.sum() / denom
|
||||
|
||||
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
|
||||
train_loss_list = []
|
||||
log_loss_dict_normal = {}
|
||||
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
|
||||
disp_loss = None
|
||||
aux_ar_logits = None
|
||||
# print(len(logits_dict))
|
||||
if len(logits_dict) == 2: # has aux ar loss
|
||||
logits_dict, aux_ar_logits = logits_dict
|
||||
if input_dict is not None:
|
||||
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
|
||||
feat = hidden_vec.mean(dim=1) #bs,dim
|
||||
disp_loss = dispersive_loss(feat, tau=tau) # scalar
|
||||
for idx, key in enumerate(self.feature_list):
|
||||
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
|
||||
if aux_ar_logits is not None:
|
||||
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
|
||||
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
|
||||
train_loss_list.append(training_loss)
|
||||
if valid:
|
||||
if key == 'type' or key == 'timesig':
|
||||
|
||||
Reference in New Issue
Block a user