first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

Amadeus/train_utils.py Normal file

@@ -0,0 +1,428 @@
import math
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
def add_conti_for_single_feature(tensor):
new_target = tensor.clone()
    # Expects a [batch, seq_len] tensor holding a single sub-token feature
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0] = new_target[:, 0] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
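# Illustrative example (toy input, assumed for documentation only): repeated values along
# the sequence axis are replaced by the continuation marker 9999.
#   add_conti_for_single_feature(torch.tensor([[1, 1, 2, 2, 3]]))
#   -> tensor([[1, 9999, 2, 9999, 3]])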
def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
feature_prediction_order_dict = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
if encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
default_prediction_order = feature_prediction_order_dict[num_features]
prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
else:
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'nb':
assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
default_prediction_order = feature_prediction_order_dict[num_features]
# Reorganize the prediction order based on the target_feature
target_index = default_prediction_order.index(target_feature)
prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]
return prediction_order
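# Illustrative example (toy arguments, assumed for documentation only): for the 'nb' scheme
# the default order is rotated so that the target sub-token is predicted first.
#   adjust_prediction_order('nb', 4, 'pitch', nn_params={})
#   -> ['pitch', 'duration', 'type', 'beat']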
########################### Loss function ################################
class NLLLoss4REMI():
def __init__(
self,
focal_alpha:float,
focal_gamma:float,
):
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss_seq = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss, loss_seq
def __call__(self, logits, shifted_tgt, mask, vocab):
if vocab is not None:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
loss_by_class_normal = defaultdict(float)
shifted_tgt_with_mask = shifted_tgt * mask # [b, t]
answers_idx = shifted_tgt_with_mask.flatten(0,1) # [b*t]
for feature in vocab.feature_list:
feature_mask = vocab.total_mask[feature].to(answers_idx.device) # [327,]
mask_for_target = feature_mask[answers_idx] # [b*t]
normal_loss_seq_by_class = loss_seq * mask_for_target
if mask_for_target.sum().item() != 0:
loss_by_class_normal[feature+'_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
return loss, loss_by_class_normal
else:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
return loss, None
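# Minimal usage sketch (toy shapes assumed; passing vocab=None skips the per-feature breakdown).
# With focal_alpha=1.0 and focal_gamma=0.0 the focal term vanishes and this reduces to plain NLL.
#   criterion = NLLLoss4REMI(focal_alpha=1.0, focal_gamma=0.0)
#   logits = torch.randn(2, 8, 100)          # [batch, seq_len, vocab_size]
#   target = torch.randint(0, 100, (2, 8))   # [batch, seq_len]
#   mask = torch.ones(2, 8)                  # 1 = real token, 0 = padding
#   loss, _ = criterion(logits, target, mask, vocab=None)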
class NLLLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
probs = logits.softmax(dim=-1)
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)  # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        else:
            valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * total_mask # [batch_size*seq_len]
loss = loss.sum() / total_mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits_dict, shifted_tgt, mask, valid):
train_loss_list = []
log_loss_dict_normal = {}
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
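# Minimal usage sketch (toy shapes assumed): one logits tensor per sub-token feature,
# with the matching target column taken from the last axis of the compound-token targets.
#   criterion = NLLLoss4CompoundToken(['type', 'beat', 'pitch', 'duration'], focal_alpha=1.0, focal_gamma=2.0)
#   logits_dict = {k: torch.randn(2, 8, 100) for k in criterion.feature_list}
#   tgt = torch.randint(0, 100, (2, 8, 4))   # [batch, seq_len, num_features]
#   loss, _ = criterion(logits_dict, tgt, torch.ones(2, 8), valid=False)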
def dispersive_loss(z, tau=0.5, eps=1e-8):
    """Dispersive loss computed with cosine distance."""
    B = z.size(0)
    # Cosine similarity matrix [B, B] between L2-normalised feature vectors.
    z_norm = torch.nn.functional.normalize(z, p=2, dim=1)
    sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1))
    # Convert to cosine distance (1 - similarity) and zero out the diagonal.
    mask = 1 - torch.eye(B, device=z.device)
    cos_dist = (1 - sim_matrix) * mask
    # Same aggregation as the squared-L2 variant: exponentiate the negative distances
    # and normalise by the number of off-diagonal pairs.
    exp_term = torch.exp(-cos_dist / tau)
    mean_exp = exp_term.sum() / (B * (B - 1) + eps)
    loss = -torch.log(mean_exp + eps)
    return loss
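# Minimal usage sketch (toy features assumed): the regulariser is computed on a batch of
# pooled hidden states and added to the main objective with a small weight.
#   feat = torch.randn(16, 512)              # [batch, feature_dim]
#   reg = dispersive_loss(feat, tau=0.5)     # scalar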
class DiffusionLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
    def get_nll_loss(self, logits, target, mask, mask_indices, p_mask):
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
if mask.ndim == 2:
mask = mask.flatten(0, 1)
        # mask_indices is a boolean mask over the flattened positions, selecting the tokens
        # masked by the diffusion forward process; p_mask holds each token's masking
        # probability, and dividing by it re-weights the masked cross-entropy accordingly.
        token_loss = F.cross_entropy(
            logits[mask_indices],   # score only the masked positions
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
        return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)  # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        else:
            valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
        token_loss = F.cross_entropy(
            logits[mask_indices],   # score only the masked positions
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
    def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None, lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
        if input_dict is not None:
            hidden_vec = input_dict['hidden_vec']       # [batch, seq_len, dim]
            feat = hidden_vec.mean(dim=1)               # [batch, dim], mean-pooled over time
            disp_loss = dispersive_loss(feat, tau=tau)  # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if disp_loss is not None:
total_loss = total_loss + lambda_weight * disp_loss
log_loss_dict_normal['dispersion'] = disp_loss.item()
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
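# Minimal usage sketch (toy shapes assumed): mask_indices is expected to mark the sub-tokens
# hidden by the diffusion forward process, and p_mask their masking probabilities, which the
# loss divides by (see get_nll_loss above).
#   b, t, f, v = 2, 8, 4, 100
#   criterion = DiffusionLoss4CompoundToken(['type', 'beat', 'pitch', 'duration'], 1.0, 0.0)
#   logits_dict = {k: torch.randn(b, t, v) for k in criterion.feature_list}
#   tgt = torch.randint(0, v, (b, t, f))
#   mask_indices = torch.ones(b, t * f, dtype=torch.bool)   # flattened per sub-token
#   p_mask = torch.full((b, t * f), 0.5)
#   loss, _ = criterion(logits_dict, tgt, torch.ones(b, t, dtype=torch.bool),
#                       mask_indices, p_mask, valid=False)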
class EncodecFlattenLoss():
def __init__(self, feature_list):
self.feature_list = feature_list
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss_seq = -torch.log(pt) # [batch_size*seq_len]
loss_seq = loss_seq * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits, shifted_tgt, mask):
loss = self.get_nll_loss(logits, shifted_tgt, mask)
return loss
class EncodecMultiClassLoss(EncodecFlattenLoss):
def __init__(self, feature_list):
super().__init__(feature_list)
def __call__(self, logits_dict, shifted_tgt, mask):
train_loss_list = []
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
total_loss = sum(train_loss_list) / len(train_loss_list)
return total_loss
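# Minimal usage sketch (toy shapes and codebook names assumed): the flatten variant scores a
# single stream of codebook tokens, while the multi-class variant scores one logits tensor
# per codebook.
#   books = ['codebook_0', 'codebook_1']
#   multi = EncodecMultiClassLoss(books)
#   logits_dict = {k: torch.randn(2, 8, 1024) for k in books}
#   loss = multi(logits_dict, torch.randint(0, 1024, (2, 8, 2)), torch.ones(2, 8))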
########################### Learning rate Scheduler ################################
'''
This scheduler is adapted from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is a cosine annealing scheduler with warm restarts, extended with a linear warm-up phase
(T_up steps) and a maximum learning rate that is reduced by a factor of gamma after each restart.
'''
class CosineAnnealingWarmUpRestarts(_LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
self.T_0 = T_0
self.T_mult = T_mult
self.base_eta_max = eta_max
self.eta_max = eta_max
self.T_up = T_up
self.T_i = T_0
self.gamma = gamma
self.cycle = 0
self.T_cur = last_epoch
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.T_cur == -1:
return self.base_lrs
elif self.T_cur < self.T_up:
return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.cycle += 1
self.T_cur = self.T_cur - self.T_i
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
else:
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
self.cycle = epoch // self.T_0
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.cycle = n
self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch
self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
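# Minimal usage sketch (toy model, optimizer, and hyper-parameters assumed): the optimizer's lr
# acts as the floor, the schedule warms up to eta_max over T_up steps, cosine-anneals back down,
# and restarts every T_0 steps with the peak decayed by gamma after each cycle.
#   model = nn.Linear(4, 4)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-8)
#   scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=1000, T_mult=1,
#                                             eta_max=1e-4, T_up=100, gamma=0.5)
#   for step in range(3000):
#       ...  # optimizer.step()
#       scheduler.step()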
class CosineLRScheduler(_LRScheduler):
"""Cosine LR scheduler.
Args:
optimizer (Optimizer): Torch optimizer.
warmup_steps (int): Number of warmup steps.
total_steps (int): Total number of steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
        cycle_length (float): Cosine cycle length relative to the decay phase (1.0 gives a single half-cosine decay).
"""
def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
self.warmup_steps = warmup_steps
assert self.warmup_steps >= 0
self.total_steps = total_steps
assert self.total_steps >= 0
self.lr_min_ratio = lr_min_ratio
self.cycle_length = cycle_length
super().__init__(optimizer)
def _get_sched_lr(self, lr: float, step: int):
if step < self.warmup_steps:
lr_ratio = step / self.warmup_steps
lr = lr_ratio * lr
elif step <= self.total_steps:
s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
(1. + math.cos(math.pi * s / self.cycle_length))
lr = lr_ratio * lr
else:
lr_ratio = self.lr_min_ratio
lr = lr_ratio * lr
return lr
def get_lr(self):
return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
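# Minimal usage sketch (toy optimizer assumed): linear warm-up to the optimizer's base lr,
# then a single cosine decay down to base_lr * lr_min_ratio at total_steps.
#   optimizer = torch.optim.AdamW(nn.Linear(4, 4).parameters(), lr=3e-4)
#   scheduler = CosineLRScheduler(optimizer, total_steps=10000, warmup_steps=500)
#   for step in range(10000):
#       ...  # optimizer.step()
#       scheduler.step()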
class DispersiveLoss(nn.Module):
def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
super().__init__()
self.loss_type = loss_type
self.tau = tau
self.lambda_weight = lambda_weight
    def forward(self, features, diffusion_loss):
        """
        features: batch feature matrix of shape [batch_size, feature_dim]
        diffusion_loss: the original diffusion (main) loss to be regularised
        """
batch_size = features.size(0)
        # Compute the pairwise distance matrix.
        if self.loss_type == 'infonce_l2':
            # Squared L2 distances between all pairs of features.
            dist_matrix = torch.cdist(features, features, p=2) ** 2
            # InfoNCE-style dispersion: log of the mean exponentiated negative distance.
            exp_dist = torch.exp(-dist_matrix / self.tau)
            disp_loss = torch.log(exp_dist.mean())
        elif self.loss_type == 'hinge':
            # Hinge variant with an assumed margin epsilon = 1.0.
            dist_matrix = torch.cdist(features, features, p=2)
            disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
        elif self.loss_type == 'covariance':
            # Covariance variant: penalise off-diagonal entries of the feature covariance.
            normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
            cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
            # Mean of the squared off-diagonal elements.
            mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool, device=cov_matrix.device)
            disp_loss = (cov_matrix[mask] ** 2).mean()
        else:
            raise ValueError("Unsupported loss type")
        # Total loss = diffusion loss + lambda * dispersion loss.
        total_loss = diffusion_loss + self.lambda_weight * disp_loss
        return total_loss, disp_loss
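# Minimal usage sketch (toy tensors assumed): wraps the main loss and adds the weighted
# dispersion term computed on pooled batch features.
#   criterion = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.25)
#   features = torch.randn(16, 512)          # [batch, feature_dim]
#   diffusion_loss = torch.tensor(2.3)       # stand-in for the main training loss
#   total_loss, disp_loss = criterion(features, diffusion_loss)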