import math
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def add_conti_for_single_feature(tensor):
    new_target = tensor.clone()
    # Assuming tensor shape is [batch, sequence, features]
    # Create a shifted version of the tensor
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # The first element of each sequence cannot be a duplicate by definition
    shifted_tensor[:, 0] = new_target[:, 0] + 1
    # Identify where the original and shifted tensors are the same (duplicates)
    duplicates = new_target == shifted_tensor
    # Replace duplicates with 9999
    new_target[duplicates] = 9999
    return new_target


def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
    feature_prediction_order_dict = {
        4: ["type", "beat", "pitch", "duration"],
        5: ["type", "beat", "instrument", "pitch", "duration"],
        7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
        8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"],
    }
    if encoding_scheme == 'remi':
        prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'cp':
        if nn_params.get("partial_sequential_prediction", False):
            default_prediction_order = feature_prediction_order_dict[num_features]
            # Predict the first sub-token first, then the remaining sub-tokens as one group
            prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
        else:
            prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'nb':
        assert target_feature in feature_prediction_order_dict[num_features], (
            f"Target feature {target_feature} not in the selected sub-token set. "
            "Please check target feature in the config and num_features in nn_params."
        )
        default_prediction_order = feature_prediction_order_dict[num_features]
        # Reorganize the prediction order based on the target_feature
        target_index = default_prediction_order.index(target_feature)
        prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]
    else:
        raise ValueError(f"Unknown encoding scheme: {encoding_scheme}")
    return prediction_order
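# Usage sketch (illustrative only, not part of the training pipeline): how the two helpers
# above behave for a hypothetical 8-feature 'nb' configuration. The tensor values and the
# _demo_* function name are made up for demonstration.
def _demo_prediction_order_and_conti():
    # 'pitch' is rotated to the front of the sub-token prediction order
    order = adjust_prediction_order('nb', 8, 'pitch', nn_params={})
    print(order)  # ['pitch', 'duration', 'velocity', 'type', 'beat', 'chord', 'tempo', 'instrument']

    # Repeated values along the sequence axis are replaced by the CONTI marker 9999
    beats = torch.tensor([[[1], [1], [2], [2], [3]]])  # [batch=1, seq=5, features=1]
    print(add_conti_for_single_feature(beats).squeeze(-1))  # -> [[1, 9999, 2, 9999, 3]]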
########################### Loss function ################################
class NLLLoss4REMI():
    def __init__(
        self,
        focal_alpha: float,
        focal_gamma: float,
    ):
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss_seq = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum()  # calculating mean loss considering mask
        return loss, loss_seq

    def __call__(self, logits, shifted_tgt, mask, vocab):
        if vocab is not None:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            loss_by_class_normal = defaultdict(float)
            shifted_tgt_with_mask = shifted_tgt * mask  # [b, t]
            answers_idx = shifted_tgt_with_mask.flatten(0, 1)  # [b*t]
            for feature in vocab.feature_list:
                feature_mask = vocab.total_mask[feature].to(answers_idx.device)  # [327,]
                mask_for_target = feature_mask[answers_idx]  # [b*t]
                normal_loss_seq_by_class = loss_seq * mask_for_target
                if mask_for_target.sum().item() != 0:
                    loss_by_class_normal[feature + '_normal'] += (
                        normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item()
                    )
            return loss, loss_by_class_normal
        else:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            return loss, None


class NLLLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha: float, focal_gamma: float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss.sum() / mask.sum()  # calculating mean loss considering mask
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
        probs = logits.softmax(dim=-1)
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)  # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)  # [batch_size, seq_len]
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
        total_mask = mask.flatten(0, 1) & valid_mask  # [batch_size*seq_len]
        loss = -self.alpha * (1 - pt) ** self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss = loss * total_mask  # [batch_size*seq_len]
        loss = loss.sum() / total_mask.sum()  # calculating mean loss considering mask
        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, valid):
        train_loss_list = []
        log_loss_dict_normal = {}
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
                k_normal = key + '_normal'
                log_loss_dict_normal[k_normal] = log_normal_loss
        total_loss = sum(train_loss_list) / len(train_loss_list)
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None
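# Usage sketch (illustrative only): the focal-style NLL losses above expect logits of shape
# [batch, seq, vocab] (one head per sub-token for the compound variant), integer targets, and
# a 0/1 padding mask. Shapes, vocab sizes, and the _demo_* function name are made up for
# demonstration.
def _demo_token_nll_losses():
    batch, seq, vocab_size = 2, 6, 16
    mask = torch.ones(batch, seq)  # 1 = real token, 0 = padding

    # Flat (REMI-style) sequence: one softmax over the whole vocabulary
    remi_criterion = NLLLoss4REMI(focal_alpha=1.0, focal_gamma=0.0)  # gamma=0 reduces to plain NLL
    logits = torch.randn(batch, seq, vocab_size)
    target = torch.randint(0, vocab_size, (batch, seq))
    remi_loss, _ = remi_criterion(logits, target, mask, vocab=None)

    # Compound tokens: one logit head and one target column per sub-token feature
    features = ['type', 'beat', 'pitch', 'duration']
    cp_criterion = NLLLoss4CompoundToken(features, focal_alpha=1.0, focal_gamma=2.0)
    logits_dict = {k: torch.randn(batch, seq, vocab_size) for k in features}
    shifted_tgt = torch.randint(0, vocab_size, (batch, seq, len(features)))
    cp_loss, _ = cp_criterion(logits_dict, shifted_tgt, mask, valid=False)
    print(remi_loss.item(), cp_loss.item())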
def dispersive_loss(z, tau=0.5, eps=1e-8):
    """Dispersive loss based on cosine distance."""
    B = z.size(0)
    # Cosine similarity matrix [B, B] from L2-normalized vectors
    z_norm = torch.nn.functional.normalize(z, p=2, dim=1)
    sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1))
    # Convert to cosine distance (1 - similarity) and zero out the diagonal
    mask = 1 - torch.eye(B, device=z.device)
    cos_dist = (1 - sim_matrix) * mask
    # Dispersive objective: log of the summed exp(-distance / tau), normalized by the number
    # of off-diagonal pairs (analogous to the 'infonce_l2' branch of DispersiveLoss below;
    # minimizing it pushes the embeddings apart)
    exp_term = torch.exp(-cos_dist / tau)
    mean_exp = exp_term.sum() / (B * (B - 1) + eps)
    loss = torch.log(mean_exp + eps)
    return loss


class DiffusionLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha: float, focal_gamma: float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask, mask_indices, p_mask):
        if logits.ndim == 3:
            logits = logits.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        if mask.ndim == 2:
            mask = mask.flatten(0, 1)
        # mask_indices is a boolean mask selecting the masked (noised) positions;
        # p_mask holds the per-position masking probability used as an importance weight
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions directly
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)  # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)  # [batch_size, seq_len]
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)
        if logits.ndim == 3:
            logits = logits.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions directly
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        total_mask = mask.flatten(0, 1) & valid_mask  # [batch_size*seq_len]
        loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid,
                 input_dict=None, lambda_weight=0.5, tau=0.5):
        train_loss_list = []
        log_loss_dict_normal = {}
        mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
        p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)

        disp_loss = None
        if input_dict is not None:
            hidden_vec = input_dict['hidden_vec']  # [bs, seq_len, dim]
            feat = hidden_vec.mean(dim=1)  # [bs, dim]
            disp_loss = dispersive_loss(feat, tau=tau)  # scalar

        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(
                logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None,
                        mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999,
                        mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(
                        logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None,
                        mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                k_normal = key + '_normal'
                log_loss_dict_normal[k_normal] = log_normal_loss

        total_loss = sum(train_loss_list) / len(train_loss_list)
        if disp_loss is not None:
            total_loss = total_loss + lambda_weight * disp_loss
            log_loss_dict_normal['dispersion'] = disp_loss.item()
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None
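# Usage sketch (illustrative only): the masked-diffusion loss above expects, per sub-token
# feature, a boolean mask of which positions were noised (mask_indices) and the masking
# probability used for each position (p_mask). All shapes, names, and values below are made
# up for demonstration; the _demo_* function is not part of the training pipeline.
def _demo_diffusion_compound_loss():
    batch, seq, vocab_size = 2, 6, 16
    features = ['type', 'beat', 'pitch', 'duration']
    criterion = DiffusionLoss4CompoundToken(features, focal_alpha=1.0, focal_gamma=0.0)

    logits_dict = {k: torch.randn(batch, seq, vocab_size) for k in features}
    shifted_tgt = torch.randint(0, vocab_size, (batch, seq, len(features)))
    mask = torch.ones(batch, seq)
    # For simplicity, treat every position as masked with probability 1.0
    mask_indices = torch.ones(batch, seq, len(features), dtype=torch.bool)
    p_mask = torch.ones(batch, seq, len(features))
    # Optional hidden states enable the dispersive regularizer on mean-pooled features
    input_dict = {'hidden_vec': torch.randn(batch, seq, 32)}

    loss, _ = criterion(logits_dict, shifted_tgt, mask, mask_indices, p_mask,
                        valid=False, input_dict=input_dict, lambda_weight=0.25)
    print(loss.item())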
class EncodecFlattenLoss():
    def __init__(self, feature_list):
        self.feature_list = feature_list

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1 - 1e-7)  # [batch_size*seq_len]
        loss_seq = -torch.log(pt)  # [batch_size*seq_len]
        loss_seq = loss_seq * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum()  # calculating mean loss considering mask
        return loss

    def __call__(self, logits, shifted_tgt, mask):
        loss = self.get_nll_loss(logits, shifted_tgt, mask)
        return loss


class EncodecMultiClassLoss(EncodecFlattenLoss):
    def __init__(self, feature_list):
        super().__init__(feature_list)

    def __call__(self, logits_dict, shifted_tgt, mask):
        train_loss_list = []
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
        total_loss = sum(train_loss_list) / len(train_loss_list)
        return total_loss
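# Usage sketch (illustrative only): the Encodec-style losses above are plain NLL without
# focal weighting. The codebook names, sizes, and the _demo_* function are made up for
# demonstration.
def _demo_encodec_losses():
    batch, seq, codebook_size = 2, 8, 1024
    mask = torch.ones(batch, seq)

    flat_criterion = EncodecFlattenLoss(feature_list=None)  # feature_list is unused in the flattened variant
    logits = torch.randn(batch, seq, codebook_size)
    target = torch.randint(0, codebook_size, (batch, seq))
    flat_loss = flat_criterion(logits, target, mask)

    codebooks = ['codebook_0', 'codebook_1']  # hypothetical codebook names
    multi_criterion = EncodecMultiClassLoss(codebooks)
    logits_dict = {k: torch.randn(batch, seq, codebook_size) for k in codebooks}
    targets = torch.randint(0, codebook_size, (batch, seq, len(codebooks)))
    multi_loss = multi_criterion(logits_dict, targets, mask)
    print(flat_loss.item(), multi_loss.item())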
########################### Learning rate Scheduler ################################
'''
This scheduler is from
https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It's basically a cosine annealing scheduler with warm restarts, extended with two features:
a warm-up start and a decaying maximum lr.
'''
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
        # NOTE: eta_min is accepted for interface compatibility but not used by this scheduler
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0
        self.gamma = gamma
        self.cycle = 0
        self.T_cur = last_epoch
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            return [(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            return [base_lr + (self.eta_max - base_lr)
                    * (1 + math.cos(math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.eta_max = self.base_eta_max * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr


class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warm-up.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate as a ratio of the base learning rate.
        cycle_length (float): Cycle length.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
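# Usage sketch (illustrative only): both schedulers are driven by calling .step() once per
# optimizer step. For CosineAnnealingWarmUpRestarts the optimizer's base lr acts as the
# warm-up floor and eta_max as the peak, while for CosineLRScheduler the base lr is the peak.
# The model, optimizer, and step counts below are made up for demonstration.
def _demo_schedulers():
    model = nn.Linear(4, 4)

    opt = torch.optim.Adam(model.parameters(), lr=1e-6)
    warm_restart_sched = CosineAnnealingWarmUpRestarts(
        opt, T_0=100, T_mult=2, eta_max=1e-3, T_up=10, gamma=0.5)

    opt2 = torch.optim.Adam(model.parameters(), lr=1e-3)
    cosine_sched = CosineLRScheduler(opt2, total_steps=1000, warmup_steps=100)

    for _ in range(5):
        opt.step()
        warm_restart_sched.step()
        opt2.step()
        cosine_sched.step()
    print(opt.param_groups[0]['lr'], opt2.param_groups[0]['lr'])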
class DispersiveLoss(nn.Module):
    def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
        super().__init__()
        self.loss_type = loss_type
        self.tau = tau
        self.lambda_weight = lambda_weight

    def forward(self, features, diffusion_loss):
        """
        features: batch feature matrix of shape [batch_size, feature_dim]
        diffusion_loss: the original diffusion loss
        """
        batch_size = features.size(0)
        # Compute the pairwise distance matrix
        if self.loss_type == 'infonce_l2':
            # Squared L2 distances
            dist_matrix = torch.cdist(features, features, p=2) ** 2
            # Dispersive loss: log of the mean of exp(-distance / tau)
            exp_dist = torch.exp(-dist_matrix / self.tau)
            disp_loss = torch.log(exp_dist.mean())
        elif self.loss_type == 'hinge':
            # Hinge loss with threshold epsilon = 1.0
            dist_matrix = torch.cdist(features, features, p=2)
            disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
        elif self.loss_type == 'covariance':
            # Covariance loss: standardize features, then penalize off-diagonal covariance
            normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
            cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
            # Mean of squared off-diagonal elements
            mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool)
            disp_loss = (cov_matrix[mask] ** 2).mean()
        else:
            raise ValueError("Unsupported loss type")
        # Total loss = diffusion loss + lambda * dispersive loss
        total_loss = diffusion_loss + self.lambda_weight * disp_loss
        return total_loss, disp_loss
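# Usage sketch (illustrative only): the module above wraps a main loss with a dispersion
# regularizer on batch features. The feature dimension, loss value, and the _demo_* function
# are made up for demonstration.
def _demo_dispersive_regularizer():
    regularizer = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.25)
    features = torch.randn(8, 32)        # e.g. mean-pooled hidden states, one row per sample
    diffusion_loss = torch.tensor(1.2)   # stand-in for the main training loss
    total_loss, disp_loss = regularizer(features, diffusion_loss)
    print(total_loss.item(), disp_loss.item())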