1029 add octuple
This commit is contained in:
@ -29,8 +29,12 @@ def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_pa
|
||||
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
|
||||
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
|
||||
}
|
||||
|
||||
if encoding_scheme == 'remi':
|
||||
oct_prediction_order = {
|
||||
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
|
||||
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
|
||||
if encoding_scheme == 'oct':
|
||||
prediction_order = oct_prediction_order[num_features]
|
||||
elif encoding_scheme == 'remi':
|
||||
prediction_order = feature_prediction_order_dict[num_features]
|
||||
elif encoding_scheme == 'cp':
|
||||
if nn_params.get("partial_sequential_prediction", False):
|
||||
@ -239,11 +243,11 @@ class DiffusionLoss4CompoundToken():
|
||||
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
|
||||
train_loss_list.append(training_loss)
|
||||
if valid:
|
||||
if key == 'type':
|
||||
if key == 'type' or key == 'timesig':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
elif key == 'beat':
|
||||
elif key == 'beat' or key == 'position' or key == 'bar':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument':
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
else:
|
||||
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
|
||||
|
||||
Reference in New Issue
Block a user