first commit
Amadeus/train_utils.py: 428 additions (new file)
@@ -0,0 +1,428 @@
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
from collections import defaultdict
import torch.nn.functional as F

def add_conti_for_single_feature(tensor):
    new_target = tensor.clone()
    # Assuming tensor shape is [batch, sequence, features]
    # Create a shifted version of the tensor
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # The first element of each sequence cannot be a duplicate by definition
    shifted_tensor[:, 0] = new_target[:, 0] + 1

    # Identify where the original and shifted tensors are the same (duplicates)
    duplicates = new_target == shifted_tensor
    # Replace duplicates with 9999
    new_target[duplicates] = 9999
    return new_target
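
# Illustrative usage sketch for add_conti_for_single_feature (shapes/values are assumptions, not from the repo):
#   x = torch.tensor([[1, 1, 2, 2, 3]])      # [batch=1, seq=5] sub-token ids for one feature
#   add_conti_for_single_feature(x)
#   # -> tensor([[1, 9999, 2, 9999, 3]])     repeated consecutive values are marked as "conti" (9999)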

def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
    feature_prediction_order_dict = {
        4: ["type", "beat", "pitch", "duration"],
        5: ["type", "beat", "instrument", "pitch", "duration"],
        7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
        8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
    }

    if encoding_scheme == 'remi':
        prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'cp':
        if nn_params.get("partial_sequential_prediction", False):
            default_prediction_order = feature_prediction_order_dict[num_features]
            prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
        else:
            prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'nb':
        assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
        default_prediction_order = feature_prediction_order_dict[num_features]

        # Reorganize the prediction order based on the target_feature
        target_index = default_prediction_order.index(target_feature)
        prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]

    return prediction_order
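
# Illustrative example of the 'nb' reordering (argument values are assumptions):
#   adjust_prediction_order('nb', 4, 'pitch', nn_params={})
#   # the default order ["type", "beat", "pitch", "duration"] is rotated so the target feature comes first:
#   # -> ["pitch", "duration", "type", "beat"]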

########################### Loss function ################################

class NLLLoss4REMI():
    def __init__(
        self,
        focal_alpha:float,
        focal_gamma:float,
    ):
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
        loss_seq = loss * mask.flatten(0, 1) # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
        return loss, loss_seq

    def __call__(self, logits, shifted_tgt, mask, vocab):
        if vocab is not None:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            loss_by_class_normal = defaultdict(float)
            shifted_tgt_with_mask = shifted_tgt * mask # [b, t]
            answers_idx = shifted_tgt_with_mask.flatten(0,1) # [b*t]
            for feature in vocab.feature_list:
                feature_mask = vocab.total_mask[feature].to(answers_idx.device) # [327,]
                mask_for_target = feature_mask[answers_idx] # [b*t]
                normal_loss_seq_by_class = loss_seq * mask_for_target
                if mask_for_target.sum().item() != 0:
                    loss_by_class_normal[feature+'_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
            return loss, loss_by_class_normal
        else:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            return loss, None
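
# The focal weighting used by get_nll_loss, written out:
#   loss_t = -alpha * (1 - p_t) ** gamma * log(p_t)
# With focal_alpha=1.0 and focal_gamma=0.0 this reduces to the plain masked negative log-likelihood.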

class NLLLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
        loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
        loss = loss.sum() / mask.sum() # calculating mean loss considering mask
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
        probs = logits.softmax(dim=-1)

        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target) # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size, seq_len]
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)

        if probs.ndim == 3:
            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        # clamp as in get_nll_loss to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
        total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
        loss = loss * total_mask # [batch_size*seq_len]
        loss = loss.sum() / total_mask.sum() # calculating mean loss considering mask
        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, valid):
        train_loss_list = []
        log_loss_dict_normal = {}
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
                elif key == 'beat':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
                elif key == 'chord' or key == 'tempo' or key == 'instrument':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
                k_normal = key + '_normal'
                log_loss_dict_normal[k_normal] = log_normal_loss
        total_loss = sum(train_loss_list) / len(train_loss_list)
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None
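
# Illustrative usage sketch (feature names, shapes and hyper-parameters are assumptions):
#   criterion = NLLLoss4CompoundToken(['type', 'beat', 'pitch', 'duration'], focal_alpha=1.0, focal_gamma=0.0)
#   logits_dict: {feature: [batch, seq_len, vocab_size_of_feature]}
#   shifted_tgt: [batch, seq_len, num_features], mask: [batch, seq_len]
#   loss, per_feature_log = criterion(logits_dict, shifted_tgt, mask, valid=True)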

def dispersive_loss(z, tau=0.5, eps=1e-8):
    """Dispersive Loss computed with cosine distance."""
    B = z.size(0)

    # Compute the cosine similarity matrix [B, B]
    z_norm = torch.nn.functional.normalize(z, p=2, dim=1)  # L2-normalize the feature vectors
    sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1))  # cosine similarity

    # Convert to cosine distance (1 - similarity) and zero out the diagonal
    mask = 1 - torch.eye(B, device=z.device)
    cos_dist = (1 - sim_matrix) * mask

    # Dispersion loss (same form as the L2-distance variant)
    exp_term = torch.exp(-cos_dist / tau)
    mean_exp = exp_term.sum() / (B * (B - 1) + eps)
    loss = -torch.log(mean_exp + eps)
    return loss
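
# Illustrative usage sketch (shapes are assumptions):
#   feats = hidden_states.mean(dim=1)        # pool [batch, seq_len, dim] -> [batch, dim]
#   reg = dispersive_loss(feats, tau=0.5)    # scalar regularizer combined with the main training loss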

class DiffusionLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask, mask_indices, p_mask):
        if logits.ndim == 3:
            logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        if mask.ndim == 2:
            mask = mask.flatten(0, 1)
        # datatype of logits, target, mask_indices, p_mask should be the same
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions only
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target) # [batch_size, seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size, seq_len]
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)

        if logits.ndim == 3:
            logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions only
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
        loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()

        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None, lambda_weight=0.5, tau=0.5):
        train_loss_list = []
        log_loss_dict_normal = {}
        mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
        p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
        disp_loss = None
        if input_dict is not None:
            hidden_vec = input_dict['hidden_vec']  # [bs, seq_len, dim]
            feat = hidden_vec.mean(dim=1)  # [bs, dim]
            disp_loss = dispersive_loss(feat, tau=tau)  # scalar
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key == 'beat':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key == 'chord' or key == 'tempo' or key == 'instrument':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                k_normal = key + '_normal'
                log_loss_dict_normal[k_normal] = log_normal_loss
        total_loss = sum(train_loss_list) / len(train_loss_list)
        if disp_loss is not None:
            total_loss = total_loss + lambda_weight * disp_loss
            log_loss_dict_normal['dispersion'] = disp_loss.item()
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None
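
# Sketch of the inputs this loss expects (assumptions about the data pipeline, which lives outside this file):
#   mask_indices: bool tensor marking which sub-tokens were masked for the current diffusion step
#   p_mask:       the per-token masking probability from which mask_indices was drawn
#   Each masked token's cross-entropy is divided by its p_mask, the usual importance-weighted
#   NLL objective for masked (absorbing-state) discrete diffusion training.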

class EncodecFlattenLoss():
    def __init__(self, feature_list):
        self.feature_list = feature_list

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1) # [batch_size*seq_len]
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
        loss_seq = -torch.log(pt) # [batch_size*seq_len]
        loss_seq = loss_seq * mask.flatten(0, 1) # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
        return loss

    def __call__(self, logits, shifted_tgt, mask):
        loss = self.get_nll_loss(logits, shifted_tgt, mask)
        return loss


class EncodecMultiClassLoss(EncodecFlattenLoss):
    def __init__(self, feature_list):
        super().__init__(feature_list)

    def __call__(self, logits_dict, shifted_tgt, mask):
        train_loss_list = []
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
        total_loss = sum(train_loss_list) / len(train_loss_list)
        return total_loss
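
# Illustrative usage sketch (codebook names and shapes are assumptions):
#   criterion = EncodecMultiClassLoss(['codebook_0', 'codebook_1', 'codebook_2', 'codebook_3'])
#   logits_dict: {key: [batch, seq_len, codebook_size]}, shifted_tgt: [batch, seq_len, num_codebooks]
#   loss = criterion(logits_dict, shifted_tgt, mask)   # mean of the per-codebook masked NLLs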

########################### Learning rate Scheduler ################################
'''
This scheduler is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is essentially cosine annealing with warm restarts, extended with two features: a linear warm-up phase (T_up)
and a maximum learning rate that is reduced by a factor of gamma after each restart.
'''

class CosineAnnealingWarmUpRestarts(_LRScheduler):
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0
        self.gamma = gamma
        self.cycle = 0
        self.T_cur = last_epoch
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
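
# Illustrative usage sketch (hyper-parameter values are assumptions):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)   # base_lr acts as the starting/floor LR
#   scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=10000, T_mult=2, eta_max=1e-4, T_up=1000, gamma=0.5)
#   for step in range(num_steps):
#       optimizer.step()
#       scheduler.step()   # linear warm-up to eta_max over T_up steps, cosine decay back toward base_lr,
#                          # then restart with eta_max scaled by gamma and the cycle length scaled by T_mult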

class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with linear warm-up.
    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
        cycle_length (float): Cycle length.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
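
# Worked example of _get_sched_lr (numbers are assumptions):
#   warmup_steps=100, total_steps=1000, lr_min_ratio=0.1, cycle_length=1.0, base lr=1e-3
#   step 50   -> 0.50 * 1e-3    (linear warm-up: 50/100)
#   step 550  -> 0.55 * 1e-3    (s=0.5: 0.1 + 0.45*(1+cos(pi/2)))
#   step 1000 -> 0.10 * 1e-3    (s=1.0: the cosine has reached lr_min_ratio)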


class DispersiveLoss(nn.Module):
    def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
        super().__init__()
        self.loss_type = loss_type
        self.tau = tau
        self.lambda_weight = lambda_weight

    def forward(self, features, diffusion_loss):
        """
        features: batch feature matrix of shape [batch_size, feature_dim]
        diffusion_loss: the original diffusion loss
        """
        batch_size = features.size(0)

        # Compute the pairwise distance matrix
        if self.loss_type == 'infonce_l2':
            # Squared L2 distances
            dist_matrix = torch.cdist(features, features, p=2) ** 2
            # Dispersion loss
            exp_dist = torch.exp(-dist_matrix / self.tau)
            disp_loss = torch.log(exp_dist.mean())
        elif self.loss_type == 'hinge':
            # Hinge loss with margin epsilon = 1.0
            dist_matrix = torch.cdist(features, features, p=2)
            disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
        elif self.loss_type == 'covariance':
            # Covariance loss
            normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
            cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
            # Mean of the squared off-diagonal elements
            mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool, device=features.device)
            disp_loss = (cov_matrix[mask] ** 2).mean()
        else:
            raise ValueError("Unsupported loss type")

        # Total loss = diffusion loss + lambda * dispersion loss
        total_loss = diffusion_loss + self.lambda_weight * disp_loss
        return total_loss, disp_loss
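
# Illustrative usage sketch (shapes are assumptions):
#   disp = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.5)
#   feats = hidden_states.mean(dim=1)                     # [batch, dim] pooled representation
#   total_loss, disp_term = disp(feats, diffusion_loss)   # diffusion_loss is the task loss to regularize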