# MIDIFoundationModel/Amadeus/train_utils.py
import math
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

def add_conti_for_single_feature(tensor):
new_target = tensor.clone()
    # Applied per sub-token feature to a [batch, sequence] tensor; runs are compared along dim 1
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0] = new_target[:, 0] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
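# Minimal usage sketch (hypothetical values): repeated values along the sequence
# axis become the 9999 "conti" marker; only the first element of each run keeps
# its original value. Illustrative only, not part of the training pipeline.
def _demo_add_conti_for_single_feature():
    x = torch.tensor([[3, 3, 5, 5, 5, 2]])  # [batch=1, seq=6]
    assert add_conti_for_single_feature(x).tolist() == [[3, 9999, 5, 9999, 9999, 2]]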
def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
feature_prediction_order_dict = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
oct_prediction_order = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
if encoding_scheme == 'oct':
prediction_order = oct_prediction_order[num_features]
elif encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
default_prediction_order = feature_prediction_order_dict[num_features]
prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
else:
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'nb':
assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
default_prediction_order = feature_prediction_order_dict[num_features]
# Reorganize the prediction order based on the target_feature
target_index = default_prediction_order.index(target_feature)
        prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]
    else:
        raise ValueError(f"Unknown encoding scheme: {encoding_scheme}")
    return prediction_order
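# Illustrative sketch (hypothetical values): under the 'nb' scheme the default
# order is rotated so the target feature comes first. nn_params is only consulted
# by the 'cp' scheme, so None is safe here.
def _demo_adjust_prediction_order():
    order = adjust_prediction_order('nb', num_features=4, target_feature='pitch', nn_params=None)
    assert order == ["pitch", "duration", "type", "beat"]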
########################### Loss function ################################
class NLLLoss4REMI():
def __init__(
self,
focal_alpha:float,
focal_gamma:float,
):
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss_seq = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss, loss_seq
    def __call__(self, logits, shifted_tgt, mask, vocab):
        loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
        if vocab is None:
            return loss, None
        # Per-feature loss breakdown for logging
        loss_by_class_normal = defaultdict(float)
        shifted_tgt_with_mask = shifted_tgt * mask # [b, t]
        answers_idx = shifted_tgt_with_mask.flatten(0, 1) # [b*t]
        for feature in vocab.feature_list:
            feature_mask = vocab.total_mask[feature].to(answers_idx.device) # [vocab_size]
            mask_for_target = feature_mask[answers_idx] # [b*t]
            normal_loss_seq_by_class = loss_seq * mask_for_target
            if mask_for_target.sum().item() != 0:
                loss_by_class_normal[feature + '_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
        return loss, loss_by_class_normal
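# Illustrative usage sketch with random logits and hypothetical shapes. With
# focal_gamma=0 the focal term vanishes and this reduces to masked NLL; passing
# vocab=None skips the per-feature logging path (which needs a vocab object
# exposing `feature_list` and `total_mask`, not constructed here).
def _demo_nll_loss_4_remi():
    criterion = NLLLoss4REMI(focal_alpha=1.0, focal_gamma=0.0)
    logits = torch.randn(2, 8, 100)            # [batch, seq, vocab]
    target = torch.randint(0, 100, (2, 8))     # [batch, seq]
    mask = torch.ones(2, 8, dtype=torch.bool)  # True = real token, False = padding
    loss, _ = criterion(logits, target, mask, vocab=None)
    assert loss.ndim == 0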
class NLLLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
        probs = logits.softmax(dim=-1)
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target) # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size, seq_len]
        elif ignore_token is not None:
            valid_mask = (target != ignore_token)
        else:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        # clamp to avoid log(0), as in get_nll_loss
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
        total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
        loss = loss * total_mask # [batch_size*seq_len]
        loss = loss.sum() / total_mask.sum() # mean loss over valid, unmasked positions
        return loss
def __call__(self, logits_dict, shifted_tgt, mask, valid):
train_loss_list = []
log_loss_dict_normal = {}
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
                log_loss_dict_normal[key + '_normal'] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
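# Illustrative sketch (hypothetical shapes): one logits tensor per sub-token
# feature, a [batch, seq, num_features] target, and a shared padding mask. With
# valid=True a per-feature loss dict is also returned, computed with padding (0)
# and repeated-value ("conti", 9999) positions excluded.
def _demo_nll_loss_4_compound_token():
    features = ["type", "beat", "pitch", "duration"]
    criterion = NLLLoss4CompoundToken(features, focal_alpha=1.0, focal_gamma=2.0)
    logits_dict = {k: torch.randn(2, 8, 50) for k in features}
    target = torch.randint(1, 50, (2, 8, len(features)))
    mask = torch.ones(2, 8, dtype=torch.bool)
    loss, per_feature = criterion(logits_dict, target, mask, valid=True)
    assert set(per_feature) == {f + '_normal' for f in features}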
def dispersive_loss(z, tau=0.5, eps=1e-8):
    """InfoNCE-style dispersive loss on cosine distances.

    Encourages batch representations to spread apart; this is the cosine
    counterpart of the 'infonce_l2' variant in DispersiveLoss below.
    """
    B = z.size(0)
    # Cosine-similarity matrix [B, B] from L2-normalized vectors
    z_norm = F.normalize(z, p=2, dim=1)
    sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1))
    # Convert to cosine distance (1 - similarity); the mask excludes the diagonal
    mask = 1 - torch.eye(B, device=z.device)
    cos_dist = (1 - sim_matrix) * mask
    # Same form as the L2 version: log-mean of exp(-distance / tau), with the
    # diagonal masked out of both the sum and the normalizer
    exp_term = torch.exp(-cos_dist / tau) * mask
    mean_exp = exp_term.sum() / (B * (B - 1) + eps)
    loss = torch.log(mean_exp + eps)
    return loss
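# Sanity sketch (hypothetical values): a collapsed batch (identical rows) yields
# the maximum of the dispersive loss (about 0), while pairwise-orthogonal rows
# yield a lower, more dispersed value.
def _demo_dispersive_loss():
    tight = torch.ones(4, 16)   # all representations collapse to one point
    spread = torch.eye(4, 16)   # pairwise-orthogonal representations
    assert dispersive_loss(tight) > dispersive_loss(spread)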
class DiffusionLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
    def get_nll_loss(self, logits, target, mask, mask_indices, p_mask):
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
if mask.ndim == 2:
mask = mask.flatten(0, 1)
        # Cross-entropy on masked positions only, importance-weighted by the
        # masking probability p_mask (the standard masked-diffusion reweighting)
        token_loss = F.cross_entropy(
            logits[mask_indices],  # gather logits at masked positions
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
return loss
    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target) # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size, seq_len]
        elif ignore_token is not None:
            valid_mask = (target != ignore_token)
        else:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
        token_loss = F.cross_entropy(
            logits[mask_indices],  # gather logits at masked positions
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
def get_aux_ar_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
    def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None, lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
aux_ar_logits = None
        if isinstance(logits_dict, tuple):  # (logits_dict, aux_ar_logits) from an auxiliary AR head
            logits_dict, aux_ar_logits = logits_dict
        if input_dict is not None:
            hidden_vec = input_dict['hidden_vec']   # [batch, seq_len, dim]
            feat = hidden_vec.mean(dim=1)           # [batch, dim], pooled over time
            disp_loss = dispersive_loss(feat, tau=tau)  # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
if aux_ar_logits is not None:
aux_ar_loss = self.get_aux_ar_nll_loss(aux_ar_logits[key], shifted_tgt[..., idx], mask)
training_loss = 0.5 * training_loss + 0.5 * aux_ar_loss
train_loss_list.append(training_loss)
            if valid:
                if key in ('type', 'timesig'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key in ('beat', 'position', 'bar', 'chord', 'tempo', 'instrument', 'program'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                log_loss_dict_normal[key + '_normal'] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if disp_loss is not None:
total_loss = total_loss + lambda_weight * disp_loss
log_loss_dict_normal['dispersion'] = disp_loss.item()
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
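# Illustrative sketch (hypothetical shapes). `mask_indices` marks the positions
# corrupted by the discrete-diffusion forward process and `p_mask` holds each
# position's masking probability (hence the 1/p_mask importance weight); both
# arrive flattened and are reshaped per feature inside __call__.
def _demo_diffusion_loss_4_compound_token():
    features = ["type", "beat", "pitch", "duration"]
    criterion = DiffusionLoss4CompoundToken(features, focal_alpha=1.0, focal_gamma=2.0)
    bsz, seq, nfeat = 2, 8, len(features)
    logits_dict = {k: torch.randn(bsz, seq, 50) for k in features}
    target = torch.randint(1, 50, (bsz, seq, nfeat))
    mask = torch.ones(bsz, seq, dtype=torch.bool)
    mask_indices = torch.ones(bsz * seq * nfeat, dtype=torch.bool)  # every position masked
    p_mask = torch.full((bsz * seq * nfeat,), 0.5)                  # per-position mask prob
    loss, _ = criterion(logits_dict, target, mask, mask_indices, p_mask, valid=False)
    assert loss.ndim == 0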
class EncodecFlattenLoss():
def __init__(self, feature_list):
self.feature_list = feature_list
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss_seq = -torch.log(pt) # [batch_size*seq_len]
loss_seq = loss_seq * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits, shifted_tgt, mask):
loss = self.get_nll_loss(logits, shifted_tgt, mask)
return loss
class EncodecMultiClassLoss(EncodecFlattenLoss):
def __init__(self, feature_list):
super().__init__(feature_list)
def __call__(self, logits_dict, shifted_tgt, mask):
train_loss_list = []
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
total_loss = sum(train_loss_list) / len(train_loss_list)
return total_loss
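# Illustrative sketch (hypothetical shapes): the flatten variant scores a single
# interleaved codebook stream; the multi-class variant scores one logits tensor
# per codebook against a [batch, seq, num_codebooks] target.
def _demo_encodec_losses():
    flat = EncodecFlattenLoss(feature_list=["codebook"])
    loss_flat = flat(torch.randn(2, 8, 1024), torch.randint(0, 1024, (2, 8)), torch.ones(2, 8))
    books = ["q1", "q2", "q3", "q4"]
    multi = EncodecMultiClassLoss(books)
    logits_dict = {k: torch.randn(2, 8, 1024) for k in books}
    loss_multi = multi(logits_dict, torch.randint(0, 1024, (2, 8, len(books))), torch.ones(2, 8))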
########################### Learning rate Scheduler ################################
'''
This scheduler is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is a cosine annealing scheduler with warm restarts, extended with a linear warm-up phase (T_up) and a maximum learning rate that decays by gamma after each restart.
'''
class CosineAnnealingWarmUpRestarts(_LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
self.T_0 = T_0
self.T_mult = T_mult
self.base_eta_max = eta_max
self.eta_max = eta_max
self.T_up = T_up
self.T_i = T_0
self.gamma = gamma
self.cycle = 0
self.T_cur = last_epoch
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.T_cur == -1:
return self.base_lrs
elif self.T_cur < self.T_up:
return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.cycle += 1
self.T_cur = self.T_cur - self.T_i
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
else:
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
self.cycle = epoch // self.T_0
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.cycle = n
self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch
self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
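# Usage sketch (hypothetical hyperparameters). The optimizer is created with a
# tiny base lr on purpose: the scheduler warms up linearly from base_lr to
# eta_max over T_up steps, cosine-anneals back over the rest of each T_0-step
# cycle, then restarts with the cycle scaled by T_mult and the peak by gamma.
def _demo_cosine_annealing_warmup_restarts():
    model = nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-8)  # base lr, not the peak
    scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=100, T_mult=2, eta_max=1e-4, T_up=10, gamma=0.5)
    for _ in range(5):
        optimizer.step()
        scheduler.step()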
class CosineLRScheduler(_LRScheduler):
"""Cosine LR scheduler.
Args:
optimizer (Optimizer): Torch optimizer.
warmup_steps (int): Number of warmup steps.
total_steps (int): Total number of steps.
lr_min_ratio (float): Minimum learning rate.
cycle_length (float): Cycle length.
"""
def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
self.warmup_steps = warmup_steps
assert self.warmup_steps >= 0
self.total_steps = total_steps
assert self.total_steps >= 0
self.lr_min_ratio = lr_min_ratio
self.cycle_length = cycle_length
super().__init__(optimizer)
def _get_sched_lr(self, lr: float, step: int):
if step < self.warmup_steps:
lr_ratio = step / self.warmup_steps
lr = lr_ratio * lr
elif step <= self.total_steps:
s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
(1. + math.cos(math.pi * s / self.cycle_length))
lr = lr_ratio * lr
else:
lr_ratio = self.lr_min_ratio
lr = lr_ratio * lr
return lr
def get_lr(self):
return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
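# Usage sketch (hypothetical hyperparameters). Unlike the scheduler above, here
# the optimizer lr is the peak: it is multiplied by a linear warmup ratio, then
# by the cosine ratio decaying toward lr * lr_min_ratio at total_steps.
def _demo_cosine_lr_scheduler():
    model = nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # peak lr
    scheduler = CosineLRScheduler(optimizer, total_steps=1000, warmup_steps=100, lr_min_ratio=0.1)
    for _ in range(5):
        optimizer.step()
        scheduler.step()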
class DispersiveLoss(nn.Module):
def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
super().__init__()
self.loss_type = loss_type
self.tau = tau
self.lambda_weight = lambda_weight
    def forward(self, features, diffusion_loss):
        """
        features: batch feature matrix of shape [batch_size, feature_dim]
        diffusion_loss: the original diffusion loss
        """
        batch_size = features.size(0)
        # Compute the pairwise-distance-based dispersion term
        if self.loss_type == 'infonce_l2':
            # Squared L2 distances (the zero diagonal only adds a constant to the mean)
            dist_matrix = torch.cdist(features, features, p=2) ** 2
            # InfoNCE-style dispersion: log-mean of exp(-distance / tau)
            exp_dist = torch.exp(-dist_matrix / self.tau)
            disp_loss = torch.log(exp_dist.mean())
        elif self.loss_type == 'hinge':
            # Hinge loss with a fixed margin epsilon = 1.0
            dist_matrix = torch.cdist(features, features, p=2)
            disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
        elif self.loss_type == 'covariance':
            # Covariance loss: penalize squared off-diagonal covariances
            normalized_features = (features - features.mean(dim=0)) / (features.std(dim=0) + 1e-8)
            cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
            # Boolean mask selecting the off-diagonal entries, on the same device
            mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool, device=cov_matrix.device)
            disp_loss = (cov_matrix[mask] ** 2).mean()
        else:
            raise ValueError("Unsupported loss type")
        # Total loss = diffusion loss + lambda * dispersion loss
        total_loss = diffusion_loss + self.lambda_weight * disp_loss
        return total_loss, disp_loss
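# Usage sketch (hypothetical values): wrap a scalar task loss with the chosen
# dispersion regularizer; mean-pooled hidden states stand in for the batch features.
def _demo_dispersive_loss_module():
    reg = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.5)
    features = torch.randn(8, 32)          # e.g. mean-pooled hidden states [batch, dim]
    diffusion_loss = torch.tensor(1.25)    # placeholder for the model's own loss
    total, disp = reg(features, diffusion_loss)
    assert total.ndim == 0 and disp.ndim == 0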