first commit
BIN
Amadeus/.DS_Store
vendored
Normal file
Binary file not shown.
0
Amadeus/__init__.py
Normal file
BIN
Amadeus/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/evaluation_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/model_zoo.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/sampling_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/sub_decoder_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/sub_decoder_zoo.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/__pycache__/train_utils.cpython-310.pyc
Normal file
Binary file not shown.
56
Amadeus/catsample.py
Normal file
@@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F


def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
    logits = categorical_probs.clamp(min=eps).log()
    return F.gumbel_softmax(logits, hard=hard)


def sample_categorical(categorical_probs, method="hard"):
    if method == "hard":
        # Gumbel-max trick: divide by exponential noise and take the argmax
        gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
        return (categorical_probs / gumbel_norm).argmax(dim=-1)
    else:
        raise ValueError(f"Method {method} for sampling categorical variables is not valid.")


def direct_sampling(logits):
    probs = logits.softmax(dim=-1)
    index = sample_categorical(probs.to(torch.float32))
    return index


def top_p_sampling(logits, p=0.9):
    probs = logits.softmax(dim=-1)

    # Sort probabilities and drop everything beyond the nucleus of cumulative mass p
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    probs.masked_fill_(indices_to_remove, 0)
    probs /= probs.sum(dim=-1).unsqueeze(-1)
    index = sample_categorical(probs.to(torch.float32))

    return index


def top_k_sampling(logits, k=400):
    top_k_values, top_k_indices = torch.topk(logits, int(k))
    top_k_probs = top_k_values.softmax(dim=-1)
    index = sample_categorical(top_k_probs.to(torch.float32))
    # Map positions within the top-k subset back to vocabulary indices
    index = top_k_indices[torch.arange(index.size(0)), index]

    return index


def sample_with_strategy(update_logits, strategy, para=None):
    if strategy == "direct":
        return direct_sampling(update_logits)
    elif strategy == "top_p":
        return top_p_sampling(update_logits, para)
    elif strategy == "top_k":
        return top_k_sampling(update_logits, para)
    else:
        raise ValueError(f"Strategy {strategy} is not valid.")
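
A minimal usage sketch for these samplers (hypothetical, not part of the commit; it assumes 2-D [batch, vocab] logits, which is what the scatter(1, ...) in top_p_sampling and the arange indexing in top_k_sampling expect):

# Hypothetical demo of Amadeus/catsample.py; assumes [batch, vocab] logits.
import torch
from Amadeus.catsample import sample_with_strategy

logits = torch.randn(4, 128)                                   # 4 independent distributions
ids_direct = sample_with_strategy(logits, "direct")            # pure categorical sampling
ids_top_p = sample_with_strategy(logits, "top_p", para=0.9)    # nucleus sampling, keep mass 0.9
ids_top_k = sample_with_strategy(logits, "top_k", para=40)     # keep the 40 largest logits
print(ids_direct.shape, ids_top_p.shape, ids_top_k.shape)      # three tensors of shape [4]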
533
Amadeus/evaluation_utils.py
Normal file
@@ -0,0 +1,533 @@
from collections import defaultdict
from typing import Union
from math import log
from omegaconf import DictConfig
from pathlib import Path
import pickle
import json
import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel

from . import model_zoo
from .symbolic_encoding import data_utils
from .model_zoo import AmadeusModel
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.compile_utils import shift_and_pad
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding import decoding_utils
from .train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab


def wandb_style_config_to_omega_config(wandb_conf):
    # remove wandb-related config; it must not be overridden
    for wandb_key in ["wandb_version", "_wandb"]:
        if wandb_key in wandb_conf:
            del wandb_conf[wandb_key]
    # unwrap unnecessary fields such as 'desc' and 'value'
    for key in wandb_conf:
        # if 'desc' in wandb_conf[key]:
        #     del wandb_conf[key]['desc']
        if isinstance(wandb_conf[key], dict) and 'value' in wandb_conf[key]:
            wandb_conf[key] = wandb_conf[key]['value']
        # handle remaining entries that still carry a 'value' wrapper
        try:
            if 'value' in wandb_conf[key]:
                wandb_conf[key] = wandb_conf[key]['value']
        except TypeError:
            pass
    return wandb_conf


def get_dir_from_wandb_by_code(wandb_dir: Path, code: str) -> Path:
    for dir in wandb_dir.iterdir():
        if dir.name.endswith(code):
            return dir
    print(f'No such code in wandb_dir: {code}')
    return None


def get_best_ckpt_path_and_config(wandb_dir, code):
    dir = get_dir_from_wandb_by_code(wandb_dir, code)
    if dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = dir / 'files' / 'checkpoints'

    config_path = dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))
    metadata_path = next(ckpt_dir.glob('*metadata.json'))

    # if there is a pt file ending with 'last', return it;
    # otherwise fall back to the checkpoint with the highest iteration number
    if len(list(ckpt_dir.glob('*last.pt'))) > 0:
        last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
    else:
        pt_fns = sorted(list(ckpt_dir.glob('*.pt')), key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
        last_ckpt_fn = pt_fns[-1]

    return last_ckpt_fn, config_path, metadata_path, vocab_path


def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path: str, vocab_path: str):
    nn_params = config.nn_params
    dataset_name = config.dataset
    vocab_path = Path(vocab_path)

    if 'Encodec' in dataset_name:
        encodec_tokens_path = Path("dataset/maestro-v3.0.0-encodec_tokens")
        encodec_dataset = EncodecDataset(config, encodec_tokens_path, None, None)  # NOTE: EncodecDataset is not imported in this module
        vocab_sizes = encodec_dataset.vocab.get_vocab_size()
        train_set, valid_set, test_set = encodec_dataset.split_train_valid_test_set()

        lm_model: model_zoo.LanguageModelTransformer = getattr(model_zoo, nn_params.model_name)(config, vocab_sizes)
    else:
        encoding_scheme = config.nn_params.encoding_scheme
        num_features = config.nn_params.num_features

        # get vocab
        vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
        selected_vocab_name = vocab_name[encoding_scheme]

        vocab = getattr(vocab_utils, selected_vocab_name)(
            in_vocab_file_path=vocab_path,
            event_data=None,
            encoding_scheme=encoding_scheme,
            num_features=num_features)

        # Initialize symbolic dataset based on dataset name and configuration parameters
        symbolic_dataset = getattr(data_utils, dataset_name)(
            vocab=vocab,
            encoding_scheme=encoding_scheme,
            num_features=num_features,
            debug=config.general.debug,
            aug_type=config.data_params.aug_type,
            input_length=config.train_params.input_length,
            first_pred_feature=config.data_params.first_pred_feature,
            caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
            for_evaluation=True,
        )

        vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
        print(f"---{nn_params.main_decoder}--- is used")
        print(f"---{dataset_name}--- is used")
        print(f"---{encoding_scheme}--- is used")
        split_ratio = config.data_params.split_ratio
        train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)

        # get the proper prediction order according to the encoding scheme and target feature in the config
        prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)

        # Create the Transformer model based on configuration parameters
        model = getattr(model_zoo, nn_params.model_name)(
            vocab=symbolic_dataset.vocab,
            input_length=config.train_params.input_length,
            prediction_order=prediction_order,
            input_embedder_name=nn_params.input_embedder_name,
            main_decoder_name=nn_params.main_decoder_name,
            sub_decoder_name=nn_params.sub_decoder_name,
            sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
            sub_decoder_enricher_use=nn_params.sub_decoder.feature_enricher_use \
                if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use') else False,
            dim=nn_params.main_decoder.dim_model,
            heads=nn_params.main_decoder.num_head,
            depth=nn_params.main_decoder.num_layer,
            dropout=nn_params.model_dropout,
        )

    return model, test_set, symbolic_dataset.vocab


def add_conti_in_valid(tensor, encoding_scheme):
    new_target = tensor.clone()
    # Assuming tensor shape is [batch, sequence, features]
    # Create a shifted version of the tensor
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # The first element of each sequence cannot be a duplicate by definition
    shifted_tensor[:, 0, :] = new_target[:, 0, :] + 1

    # Identify where the original and shifted tensors are the same (duplicates)
    duplicates = new_target == shifted_tensor
    # TODO: convert hard-coded part
    # keep the duplicate flag only for features that may repeat; set the rest to False
    if encoding_scheme == 'nb':
        if tensor.shape[2] == 5:
            # change beat, instrument
            duplicates[:, :, 0] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 4:
            # change beat
            duplicates[:, :, 0] = False
            duplicates[:, :, 2] = False
            duplicates[:, :, 3] = False
        elif tensor.shape[2] == 7:
            # change beat, chord, tempo
            duplicates[:, :, 0] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False
    elif encoding_scheme == 'cp':
        if tensor.shape[2] == 5:
            # change instrument
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 7:
            # change chord, tempo
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False

    # Replace duplicates with 9999
    new_target[duplicates] = 9999
    return new_target


# TODO: hard coded
def add_conti(list_of_lists, encoding_scheme):
    if encoding_scheme == 'nb':
        if len(list_of_lists[0]) == 4:
            # type, beat, pitch, duration
            for i in range(0, len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
        elif len(list_of_lists[0]) == 5:
            # type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(0, len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_instrument and previous_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                else:
                    previous_instrument = list_of_lists[i][2]
        elif len(list_of_lists[0]) == 7:
            # type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(0, len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_chord and previous_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif list_of_lists[i][2] != previous_chord and list_of_lists[i][2] != 0:
                    previous_chord = list_of_lists[i][2]
                if list_of_lists[i][3] == previous_tempo and previous_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif list_of_lists[i][3] != previous_tempo and list_of_lists[i][3] != 0:
                    previous_tempo = list_of_lists[i][3]
    elif encoding_scheme == 'cp':
        if len(list_of_lists[0]) == 7:
            # type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(0, len(list_of_lists)):
                current_chord = list_of_lists[i][2]
                current_tempo = list_of_lists[i][3]
                if current_chord == previous_chord and current_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_chord != previous_chord and current_chord != 0:
                    previous_chord = current_chord
                if current_tempo == previous_tempo and current_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif current_tempo != previous_tempo and current_tempo != 0:
                    previous_tempo = current_tempo
        if len(list_of_lists[0]) == 5:
            # type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(0, len(list_of_lists)):
                current_instrument = list_of_lists[i][2]
                if current_instrument == previous_instrument and current_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_instrument != previous_instrument and current_instrument != 0:
                    previous_instrument = current_instrument
    return list_of_lists
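
To make the 'Conti' substitution concrete, here is a hypothetical toy input for the 5-feature nb layout (token names other than 'SSS' are made up for illustration):

# Hypothetical illustration of add_conti (nb, 5 features: type, beat, instrument, pitch, duration)
tokens = [
    ['NNN', 'Beat_0', 25, 60, 8],   # first note: nothing to merge
    ['SSS', 'Beat_0', 25, 64, 8],   # same-onset token: beat and repeated instrument collapse
    ['SSS', 'Beat_0', 41, 67, 8],   # instrument changed, so it is kept
]
print(add_conti(tokens, 'nb'))
# -> [['NNN', 'Beat_0', 25, 60, 8], ['SSS', 'Conti', 'Conti', 64, 8], ['SSS', 'Conti', 41, 67, 8]]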


class Evaluator:
    def __init__(self,
                 config: DictConfig,
                 model: AmadeusModel,
                 test_set: TuneCompiler,
                 vocab: LangTokenVocab,
                 device: str = 'cuda',
                 batch_size: int = 16):
        self.config = config
        self.device = device
        self.vocab = vocab

        self.model = model
        self.model.eval()
        self.model.to(device)
        self.test_set = test_set

        self.input_len = config.train_params.input_length
        self.loss_by_class = {key: [] for key in self.vocab.feature_list}
        self.count_by_class = {key: 0 for key in self.vocab.feature_list}
        self.batch_size = batch_size

        self.is_multiclass = config.nn_params.encoding_scheme in ('nb', 'cp')
        self.first_pred_feature = self.config.data_params.first_pred_feature

        self.neglect_keywords = ['SSS', 'SSN', 'Conti', 'Metrical', 'Note']
        self.valid_item_prob = []

        # we don't use focal loss on evaluation
        self.focal_alpha = 1
        self.focal_gamma = 0

    def save_results(self, save_fn):
        # move the loss_by_class tensors to cpu before saving
        for key in self.loss_by_class.keys():
            self.loss_by_class[key] = torch.tensor(self.loss_by_class[key]).cpu()
            self.count_by_class[key] = torch.tensor(self.count_by_class[key]).cpu()
        torch.save({'loss_by_class': self.loss_by_class, 'count_by_class': self.count_by_class}, save_fn)

    @torch.inference_mode()
    def get_perplexity(self, less_than=256):
        for data in tqdm(self.test_set.data_list, desc='Cal over dataset', position=0):
            data_tensor = torch.LongTensor(data[0])
            if self.config.nn_params.encoding_scheme == 'nb':
                data_tensor = shift_and_pad(data_tensor, self.first_pred_feature)
                data_tensor = data_tensor[:-1]

            x_seg = data_tensor[:-1].unsqueeze(0)
            y_seg = data_tensor[1:].unsqueeze(0)
            self._cal_initial_seg(x_seg, y_seg)

            if x_seg.shape[1] > self.input_len:
                cat_logits = []
                cat_y = []
                cat_mask_indices = []
                batch_x = x_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                batch_y = y_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                if self.is_multiclass:
                    batch_x = batch_x.transpose(1, 2)
                    batch_y = batch_y.transpose(1, 2)
                for batch_start_idx in tqdm(range(0, min(batch_x.shape[0], less_than), self.batch_size), desc='In piece iter', position=1, leave=False):
                    x = batch_x[batch_start_idx:batch_start_idx + self.batch_size]
                    y = batch_y[batch_start_idx:batch_start_idx + self.batch_size]
                    logits, y, mask_indices = self._cal_following_seg(x, y)
                    cat_logits.append(logits)
                    cat_y.append(y)
                    cat_mask_indices.append(mask_indices)
                if self.is_multiclass:
                    cat_dict = {}
                    for key in self.vocab.feature_list:
                        cat_dict[key] = torch.cat([logits_dict[key] for logits_dict in cat_logits], dim=0)
                    cat_logits = cat_dict
                else:
                    cat_logits = torch.cat(cat_logits, dim=0)
                cat_y = torch.cat(cat_y, dim=0)
                mask_indices = torch.cat(cat_mask_indices, dim=0)
                if self.is_multiclass:
                    self._update_loss_for_multi_class(cat_logits, cat_y, mask_indices)
                else:
                    cat_prob = torch.nn.functional.softmax(cat_logits, dim=-1)
                    pt = cat_prob[torch.arange(cat_prob.shape[0]), cat_y]
                    # focal_loss = -self.focal_alpha * (1-pt)**self.focal_gamma * torch.log(pt)  # [batch_size*seq_len]
                    loss = -torch.log(pt)
                    self._update_loss_for_single_class(loss, cat_y)

    @torch.inference_mode()
    def _update_loss_for_single_class(self, neg_log_prob: torch.Tensor, y: torch.Tensor):
        for key in self.vocab.feature_list:
            feature_mask = self.vocab.total_mask[key].to(y.device)  # [vocab_size,]
            mask_for_target = feature_mask[y]  # [b*t]
            normal_loss_seq_by_class = neg_log_prob[mask_for_target == 1]
            if mask_for_target.sum().item() != 0:
                self.loss_by_class[key] += normal_loss_seq_by_class.tolist()
                self.count_by_class[key] += mask_for_target.sum().item()

    @torch.inference_mode()
    def _update_loss_for_multi_class(self, logits_dict: dict, tgt: torch.Tensor, mask_indices: torch.Tensor = None):
        correct_token_prob = []
        for index, key in enumerate(self.vocab.feature_list):
            feat_tgt = tgt[:, index]
            logit_values = logits_dict[key]
            prob_values = torch.nn.functional.softmax(logit_values, dim=-1)
            # probability assigned to the ground-truth token for this feature
            correct_token_prob.append(prob_values[torch.arange(prob_values.shape[0]), feat_tgt])
        correct_token_prob = torch.stack(correct_token_prob, dim=1)
        # tgt = reverse_shift_and_pad_for_tensor(tgt, self.first_pred_feature)
        y_decoded = self.vocab.decode(tgt)
        y_decoded = add_conti(y_decoded, self.config.nn_params.encoding_scheme)
        # correct_token_prob = reverse_shift_and_pad_for_tensor(correct_token_prob, self.first_pred_feature)
        num_notes = logits_dict['pitch'].shape[0]
        cum_prob = 1
        max_num = mask_indices.size(0)
        for idx in range(max_num):
            if max_num != num_notes:
                print("not equal", max_num, num_notes)
            token = y_decoded[idx]
            valid_mask = mask_indices[idx, :]
            token_prob = correct_token_prob[idx].tolist()
            for j, key in enumerate(self.vocab.feature_list):
                cur_feature = token[j]
                whether_predicted = valid_mask[j]
                # clamp cur_prob to avoid log(0)
                cur_prob = max(token_prob[j], 1e-10)
                if cur_feature == 0:  # ignore token
                    continue
                if whether_predicted is False:  # skip provided token
                    continue
                if cur_feature in self.neglect_keywords:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'time_signature' in cur_feature:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'Bar' in cur_feature:
                    cum_prob = 1
                    continue
                self.valid_item_prob.append([cur_feature, cur_prob, cur_prob * cum_prob])
                pt = cur_prob * cum_prob
                loss = -log(pt)
                self.loss_by_class[key].append(loss)
                self.count_by_class[key] += 1
                cum_prob = 1

    @torch.inference_mode()
    def _cal_initial_seg(self, x_seg, y_seg):
        x, y = x_seg[:, :self.input_len].to(self.device), y_seg[:, :self.input_len].to(self.device)
        mask_indices = torch.ones_like(y).bool().to(self.device).flatten(0, 1)
        if self.config.use_diff is True:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        y = y.flatten(0, 1)
        if self.is_multiclass:
            for key in logits.keys():
                feat_tensor = logits[key].flatten(0, 1)
                logits[key] = feat_tensor
            self._update_loss_for_multi_class(logits, y, mask_indices)
        else:
            prob = torch.nn.functional.softmax(logits, dim=-1)
            prob = prob.flatten(0, 1)
            pt = prob[torch.arange(len(y)), y]
            loss = -torch.log(pt)
            self._update_loss_for_single_class(loss, y)

    @torch.inference_mode()
    def _cal_following_seg(self, x: torch.Tensor, y: torch.Tensor):
        x, y = x.to(self.device), y.to(self.device)
        mask_indices = torch.ones_like(y).bool().to(self.device)
        if self.config.use_diff is True:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        y = y[:, -1:].flatten(0, 1).cpu()
        mask_indices = mask_indices.reshape(x.shape)[:, -1:].flatten(0, 1).cpu()
        if self.is_multiclass:
            logits_dict = {}
            for key in self.vocab.feature_list:
                logits_dict[key] = logits[key][:, -1:].flatten(0, 1).cpu()
            return logits_dict, y, mask_indices
        else:
            logits = logits[:, -1:].flatten(0, 1).cpu()
            return logits, y, mask_indices

    def prepare_prompt_and_ground_truth(self, save_dir, num_target_samples, num_target_measures):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # Default resolution if dataset is not found

        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        for i, (tuneidx, tune_name) in enumerate(self.test_set):
            ground_truth_sample = tuneidx
            try:
                decoder(ground_truth_sample, output_path=str(save_dir / f"{i}_{tune_name}_gt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}.mid")

            prompt = self.model.decoder._prepare_inference(start_token=self.model.decoder.net.start_token, manual_seed=0, condition=tuneidx, num_target_measures=num_target_measures)
            try:
                decoder(prompt, output_path=str(save_dir / f"{i}_{tune_name}_prompt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}_prompt.mid")

            if i == num_target_samples:
                break

    def generate_samples_with_prompt(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0, generation_length=3072):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # Default resolution if dataset is not found

        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        tuneidx = tuneidx.cuda()
        generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))

        prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

    def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # Default resolution if dataset is not found
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        for i in range(num_samples):
            generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            if encoding_scheme == 'nb':
                generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
            decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))

    def generate_samples_with_text_prompt(self, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme
        tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
        encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(self.device)
        print(f"Using T5EncoderModel for text prompt: {prompt}")
        context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(self.device)
        context = encoder(**context).last_hidden_state
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # Default resolution if dataset is not found
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        # Open the jsonl file and count the number of lines to determine the current index
        jsonl_path = save_dir / "name2prompt.jsonl"
        if jsonl_path.exists():
            with open(jsonl_path, 'r') as f:
                current_idx = sum(1 for _ in f)
        else:
            current_idx = 0

        name = f"prompt_{current_idx}"
        name2prompt_dict = defaultdict(list)
        name2prompt_dict[name].append(prompt)
        with open(jsonl_path, 'a') as f:
            f.write(json.dumps(name2prompt_dict) + '\n')
        decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
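
A plausible end-to-end wiring of the pieces above (a sketch, not part of the commit: the wandb directory layout is assumed to match get_best_ckpt_path_and_config, and the 'model' key in the checkpoint dict is an assumption):

# Hypothetical evaluation driver for Amadeus/evaluation_utils.py
from pathlib import Path
import torch
from omegaconf import OmegaConf

ckpt_fn, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(Path('wandb'), 'run_code')
config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
model.load_state_dict(torch.load(ckpt_fn, map_location='cpu')['model'])  # 'model' key is an assumption

evaluator = Evaluator(config, model, test_set, vocab, device='cuda', batch_size=16)
evaluator.get_perplexity(less_than=256)
evaluator.save_results('eval_results.pt')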
512
Amadeus/model_zoo.py
Normal file
@@ -0,0 +1,512 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import time
import json

from . import transformer_utils
from . import sub_decoder_zoo
from x_transformers.x_transformers import LayerIntermediates, AbsolutePositionalEmbedding
from data_representation.vocab_utils import LangTokenVocab
import os


class AmadeusModelWrapper(nn.Module):
    def __init__(
        self,
        *,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        '''
        This class wraps the three main components of the AmadeusModel model:
        the input embedding layer, the main transformer decoder, and the sub-decoder.
        '''

        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.get_vocab_size()
        self.start_token = vocab.sos_token if hasattr(vocab, 'sos_token') else None
        self.end_token = vocab.eos_token if hasattr(vocab, 'eos_token') else None
        self.input_length = input_length
        self.prediction_order = prediction_order
        self._get_input_embedder(input_embedder_name, vocab, dropout, dim)
        self._get_main_decoder(main_decoder_name, input_length, dim, heads, depth, dropout)
        self._get_sub_decoder(sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout)
        self.bos_token_hidden = None

    def _get_input_embedder(self, input_embedder_name, vocab, dropout, dim):
        self.emb_dropout = nn.Dropout(dropout)
        self.input_embedder = getattr(transformer_utils, input_embedder_name)(
            vocab=vocab,
            dim_model=dim
        )

    def _get_main_decoder(self, main_decoder_name, input_length, dim, heads, depth, dropout):
        self.pos_enc = AbsolutePositionalEmbedding(dim, input_length)
        self.main_norm = nn.LayerNorm(dim)
        self.main_decoder = getattr(transformer_utils, main_decoder_name)(
            dim=dim,
            depth=depth,
            heads=heads,
            dropout=dropout
        )

    def _get_sub_decoder(self, sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout):
        self.sub_decoder = getattr(sub_decoder_zoo, sub_decoder_name)(
            prediction_order=prediction_order,
            vocab=vocab,
            dim=dim,
            sub_decoder_depth=sub_decoder_depth,
            heads=heads,
            dropout=dropout,
            sub_decoder_enricher_use=sub_decoder_enricher_use
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        embedding = self.input_embedder(input_seq) + self.pos_enc(input_seq)
        embedding = self.emb_dropout(embedding)
        hidden_vec, layer_inter = self.main_decoder(embedding, train=True, context=context)  # B x T x d_model
        hidden_vec = self.main_norm(hidden_vec)
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': target, 'bos_token_hidden': self.bos_token_hidden}
        logits = self.sub_decoder(input_dict)
        # pick the intermediate layer closest to one third of the total depth
        num_layers = len(layer_inter.layer_hiddens)
        idx = round(num_layers / 3)
        idx = min(max(idx, 0), num_layers - 1)
        input_dict['hidden_vec'] = layer_inter.layer_hiddens[idx]
        return logits, input_dict

class AmadeusModelAutoregressiveWrapper(nn.Module):
    def __init__(self, net: AmadeusModelWrapper):
        '''
        Initializes an autoregressive wrapper around the AmadeusModelWrapper,
        which allows sequential token generation.

        Arguments:
        - net: The nested music transformer model that performs the token generation.
        '''
        super().__init__()
        self.net = net

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        return self.net(input_seq, target, context=context)

    def _prepare_inference(self, start_token, manual_seed, condition=None, num_target_measures=4):
        '''
        Prepares the initial tokens for autoregressive inference. If a manual seed is provided,
        it sets the seed for reproducibility. If a condition is given, it selects a subset of
        the tokens based on certain criteria related to the encoding scheme.

        Arguments:
        - start_token: The token that represents the start of a sequence.
        - manual_seed: A seed value for reproducibility in inference (if greater than 0).
        - condition: An optional tensor used for conditional generation, which helps select a
          portion of the input tokens based on the encoding scheme.

        Returns:
        - total_out: A tensor containing the initial tokens for inference, padded to ensure
          compatibility with the model.
        '''
        if manual_seed > 0:
            torch.manual_seed(manual_seed)

        total_out = []
        if condition is None:
            # Use the start token if no condition is given
            total_out.extend(start_token)
        else:
            # Extract the portion of the sequence depending on encoding scheme (remi, cp, or nb)
            if self.net.vocab.encoding_scheme == 'remi':
                type_boundaries = self.net.vocab.remi_vocab_boundaries_by_key['type']
                # vocab idx -> 0:SOS, 1:EOS, 2:Bar_without_time_signature, ... where_type_ends:Bar_time_signature_end, ...
                measure_bool = (2 <= condition) & (condition < type_boundaries[1])  # between Bar_ts_start and Bar_ts_end
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'cp':
                # find the start and end of the measure
                beat_event2idx = self.net.vocab.event2idx['beat']
                for event, idx in beat_event2idx.items():
                    if event == 0:
                        continue
                    if event == 'Bar':
                        start_idx = idx
                    elif event.startswith('Beat'):
                        end_idx = idx
                        break
                measure_bool = (condition[:, 1] >= start_idx) & (condition[:, 1] < end_idx)  # measure tokens
                # measure_bool = (condition[:, 1] == 1)  # measure tokens
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'nb':
                measure_bool = (condition[:, 0] == 2) | (condition[:, 0] >= 5)  # Empty measure or where a new measure starts
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()

            if conditional_input_len == 0:
                conditional_input_len = 50

            selected_tokens = condition[:conditional_input_len].tolist()
            total_out.extend(selected_tokens)

        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
        return total_out

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None, context=None):
        '''
        Runs one step of autoregressive generation by taking the input sequence, embedding it,
        passing it through the main decoder, and generating logits and a sampled token.

        Arguments:
        - input_seq: The input sequence tensor to be embedded and processed.
        - cache: Optional cache for attention mechanisms to avoid recomputation.
        - sampling_method: Sampling strategy used to select the next token.
        - threshold: Optional threshold value for sampling methods that require it.
        - temperature: Controls the randomness of predictions (higher temperature increases randomness).

        Returns:
        - logits: The predicted logits for the next token.
        - sampled_token: The token sampled from the logits.
        - intermediates: Intermediate states from the main decoder, useful for caching.
        - hidden_vec: The hidden state of the last time step.
        '''
        embedding = self.net.input_embedder(input_seq) + self.net.pos_enc(input_seq)
        embedding = self.net.emb_dropout(embedding)

        # Run through the main decoder and normalize
        hidden_vec, intermediates = self.net.main_decoder(embedding, cache, context_embedding=context)  # B x T x d_model
        hidden_vec = self.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # Keep only the last time step

        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}

        # Generate the next token
        logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
        return logits, sampled_token, intermediates, hidden_vec

    def _update_total_out(self, total_out, sampled_token):
        '''
        Updates the output sequence with the newly sampled token. Depending on the encoding scheme,
        it either appends the token directly or processes feature-based sampling.

        Arguments:
        - total_out: The tensor containing the previously generated tokens.
        - sampled_token: The newly generated token to be appended.

        Returns:
        - total_out: Updated output tensor with the newly generated token.
        - sampled_token: The processed sampled token.
        '''
        if self.net.vocab.encoding_scheme == 'remi':
            # For remi encoding, directly append the sampled token
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            # Handle other encoding schemes by concatenating features
            sampled_token_list = []
            for key in self.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)
            if len(sampled_token.shape) == 2:
                # batched sampling: (batch, num_features) -> (batch, 1, num_features)
                total_out = torch.cat([total_out, sampled_token.unsqueeze(1)], dim=1)
            else:
                total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)

        return total_out, sampled_token

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
        '''
        Autoregressively generates a sequence of tokens by repeatedly sampling the next token
        until the desired maximum sequence length is reached or the end token is encountered.

        Arguments:
        - manual_seed: A seed value for reproducibility in inference.
        - max_seq_len: The maximum length of the generated sequence.
        - condition: An optional conditioning sequence to start generation from.
        - sampling_method: The method used to sample the next token (e.g., greedy, top-k).
        - threshold: Optional threshold for sampling (used in methods like top-p sampling).
        - temperature: Controls the randomness of the token sampling process.
        - batch_size: The number of sequences to generate in parallel.

        Returns:
        - total_out: The generated sequence of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)

        # If a condition is provided, run one initial step
        if condition is not None:
            _, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
        else:
            cache = LayerIntermediates()

        # Continue generating tokens until the maximum sequence length is reached
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        bos_hidden_vec = None
        hidden_vec_list = []
        token_time_list = []
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]
            # Generate the next token and update the cache
            time_start = time.time()
            _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
            time_end = time.time()
            token_time_list.append(time_end - time_start)
            if bos_hidden_vec is None:
                bos_hidden_vec = hidden_vec
            hidden_vec_list.append(hidden_vec)
            # Trim the attention KV cache to the context window for autoregressive generation
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]

            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)

            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break

        # append hidden_vec to pkl
        # save_path = 'hidden/diffnoaug_hidden_vec.pt'
        # save_time_path = 'hidden/diff_noaug_token_time.json'
        # if os.path.exists(save_path):
        #     # Load existing list and append
        #     hidden_vec_all = torch.load(save_path, map_location="cpu")
        #     hidden_vec_all.extend(hidden_vec_list)
        #     torch.save(hidden_vec_all, save_path)
        # else:
        #     torch.save(hidden_vec_list, save_path)

        # if os.path.exists(save_time_path):
        #     # Load existing list and append
        #     token_time_all = json.load(open(save_time_path, 'r'))
        #     token_time_all = token_time_all['token_time_list']
        #     token_time_all.extend(token_time_list)
        #     average_time = sum(token_time_all) / len(token_time_all)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_all
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)
        # else:
        #     average_time = sum(token_time_list) / len(token_time_list)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_list
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)

        return total_out

    def generate_batch(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
        '''
        Batched variant of `generate`: autoregressively samples the next token for
        `batch_size` sequences in parallel until the maximum sequence length is
        reached or the end token is encountered. Arguments mirror `generate`.

        Returns:
        - total_out: The generated sequences of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
        # total_out (1,1,num) -> (bs,1,num)
        total_out = total_out.repeat(batch_size, 1, 1)
        # If a condition is provided, run one initial step
        if condition is not None:
            _, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
        else:
            cache = LayerIntermediates()

        # Continue generating tokens until the maximum sequence length is reached
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]

            # Generate the next token and update the cache
            _, sampled_token, cache, _ = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)

            # Trim the attention KV cache to the context window
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]

            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)

            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break

        return total_out

class AmadeusModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab:LangTokenVocab,
|
||||||
|
input_length:int,
|
||||||
|
prediction_order:list,
|
||||||
|
input_embedder_name:str,
|
||||||
|
main_decoder_name:str,
|
||||||
|
sub_decoder_name:str,
|
||||||
|
sub_decoder_depth:int,
|
||||||
|
sub_decoder_enricher_use:bool,
|
||||||
|
dim:int,
|
||||||
|
heads:int,
|
||||||
|
depth:int,
|
||||||
|
dropout:float
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
This class combines the wrapper classes and initializes the full AmadeusModel model,
|
||||||
|
which can perform autoregressive sequence generation for symbolic music.
|
||||||
|
|
||||||
|
Vocabulary used for tokenization of the symbolic music data.
|
||||||
|
Length of the input seqkeuence in tokens.
|
||||||
|
Defines the order in which features are predicted in a sequence used for compound shift
|
||||||
|
Name of the input embedding model to be used (e.g., one-hot embedding or learned embeddings).
|
||||||
|
Name of the main transformer decoder model used for generating the hidden representations for compound tokens.
|
||||||
|
Name of the sub-decoder, which processes the hidden states and decodes the sub-tokens inside the compound tokens.
|
||||||
|
Depth (number of layers) of the sub-decoder.
|
||||||
|
Whether to use an additional enricher module in the sub-decoder to refine representations.
|
||||||
|
Dimensionality of the model (hidden size of the transformer layers).
|
||||||
|
Number of attention heads in the transformer layers.
|
||||||
|
Number of layers in the main decoder.
|
||||||
|
Dropout rate for all layers in the model.
|
||||||
|
'''
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
decoder = AmadeusModelWrapper(
|
||||||
|
vocab=vocab,
|
||||||
|
input_length=input_length,
|
||||||
|
prediction_order=prediction_order,
|
||||||
|
input_embedder_name=input_embedder_name,
|
||||||
|
main_decoder_name=main_decoder_name,
|
||||||
|
sub_decoder_name=sub_decoder_name,
|
||||||
|
sub_decoder_depth=sub_decoder_depth,
|
||||||
|
sub_decoder_enricher_use=sub_decoder_enricher_use,
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
depth=depth,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self.decoder = AmadeusModelAutoregressiveWrapper(
|
||||||
|
net=decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
|
||||||
|
return self.decoder(input_seq, target, context=context)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
|
||||||
|
if batch_size == 1:
|
||||||
|
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
|
||||||
|
else:
|
||||||
|
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
|
||||||
|
|
||||||
|
class AmadeusModel4Encodec(AmadeusModel):
|
||||||
|
def __init__(
|
||||||
|
        self,
        vocab:LangTokenVocab,
        input_length:int,
        prediction_order:list,
        input_embedder_name:str,
        main_decoder_name:str,
        sub_decoder_name:str,
        sub_decoder_depth:int,
        sub_decoder_enricher_use:bool,
        dim:int,
        heads:int,
        depth:int,
        dropout:float
    ):
        super().__init__(
            vocab=vocab,
            input_length=input_length,
            prediction_order=prediction_order,
            input_embedder_name=input_embedder_name,
            main_decoder_name=main_decoder_name,
            sub_decoder_name=sub_decoder_name,
            sub_decoder_depth=sub_decoder_depth,
            sub_decoder_enricher_use=sub_decoder_enricher_use,
            dim=dim,
            heads=heads,
            depth=depth,
            dropout=dropout
        )

    def _prepare_inference(self, start_token, manual_seed, condition=None):
        if manual_seed > 0:
            torch.manual_seed(manual_seed)
        total_out = []
        if condition is None:
            total_out.extend(start_token)
        else:
            if self.decoder.net.vocab.encoding_scheme == 'remi':
                selected_tokens = condition[:1500].tolist()
            else:
                selected_tokens = condition[:500].tolist()
            total_out.extend(selected_tokens)
        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.decoder.net.device)
        return total_out

    def _update_total_out(self, total_out, sampled_token):
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            sampled_token_list = []
            for key in self.decoder.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)  # B(1) x num_features
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
        return total_out, sampled_token

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1):
        embedding = self.decoder.net.input_embedder(input_seq) + self.decoder.net.pos_enc(input_seq)
        embedding = self.decoder.net.emb_dropout(embedding)
        hidden_vec, intermediates = self.decoder.net.main_decoder(embedding, cache)  # B x T x d_model
        hidden_vec = self.decoder.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # B x 1 x d_model
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None}
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            feature_class_idx = (input_seq.shape[1] - 1) % 4
            feature_type = self.decoder.net.vocab.feature_list[feature_class_idx]
            logits, sampled_token = self.decoder.net.sub_decoder.run_one_step(input_dict, sampling_method, threshold, temperature, feature_type)
        else:
            logits, sampled_token = self.decoder.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
        return logits, sampled_token, intermediates

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, sampling_method=None, threshold=None, temperature=1):
        total_out = self._prepare_inference(self.decoder.net.start_token, manual_seed, condition)
        if condition is not None:
            _, _, cache = self._run_one_step(total_out[:, -self.decoder.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        else:
            cache = LayerIntermediates()
        while total_out.shape[1] < max_seq_len:
            input_tensor = total_out[:, -self.decoder.net.input_length:]
            _, sampled_token, cache = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.decoder.net.input_length - 1):, :] for t in inter.cached_kv]  # B x num_heads x T x d_head
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)
            if sampled_token.tolist() == self.decoder.net.end_token[0]:
                break
        return total_out
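# A minimal usage sketch for the generation loop above (not part of the
# commit): `model` is a hypothetical instance of this wrapper class; the
# keyword arguments follow the `generate` signature.
#
#   generated = model.generate(
#       manual_seed=42,           # > 0 seeds torch for reproducible sampling
#       max_seq_len=2048,         # stop at this length or at the end token
#       condition=None,           # optionally a tensor of prompt tokens
#       sampling_method='top_p',  # 'top_p', 'typical', 'eta', or None
#       threshold=0.9,
#       temperature=1.1,
#   )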
168
Amadeus/sampling_utils.py
Normal file
@ -0,0 +1,168 @@
import torch
import torch.nn.functional as F


def top_p_sampling(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Create an empty tensor to hold the new logits
    new_logits = logits.clone()

    # Use the sorted indices to place the '-inf' in the original places
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    new_logits[..., indices_to_remove] = float('-inf')
    return new_logits
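# Sanity-check sketch (not part of the module): with thres=0.9, only the
# smallest prefix of sorted tokens whose cumulative probability stays within
# 0.9 survives. Here the 0.05 tail token is masked to -inf, and softmax over
# the masked logits renormalizes the rest.
_toy_logits = torch.log(torch.tensor([0.5, 0.3, 0.15, 0.05]))
_masked = top_p_sampling(_toy_logits, thres=0.9)
# F.softmax(_masked, dim=-1) -> approx. [0.526, 0.316, 0.158, 0.000]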


# reference: https://github.com/cimeister/typical-sampling
def typical_sampling(logits, thres=0.99):
    # calculate entropy
    normalized = torch.nn.functional.log_softmax(logits, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)

    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = logits.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative mass above the threshold
    last_ind = (cumulative_probs < thres).sum(dim=-1)
    last_ind[last_ind < 0] = 0
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(-1, last_ind.view(-1, 1, 1))
    # if self.min_tokens_to_keep > 1:
    #     # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
    #     sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)

    scores = logits.masked_fill(indices_to_remove, float("-inf"))
    return scores


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max trick is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, a low-precision Gumbel max improves the perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
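# For reference, the exact Gumbel-max identity the docstring alludes to
# (a sketch, not used elsewhere in this file): argmax_i(log p_i + g_i) with
# g_i ~ Gumbel(0, 1) is distributed as Categorical(p).
def _gumbel_max_sample(logits):
    u = torch.rand_like(logits).clamp_min(1e-20)  # avoid log(0)
    g = -torch.log(-torch.log(u))  # standard Gumbel(0, 1) noise
    return (logits + g).argmax(dim=-1)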


# reference: https://github.com/john-hewitt/truncation-sampling
def eta_sampling(logits, epsilon) -> torch.FloatTensor:
    probabilities = logits.softmax(dim=-1)
    entropy = torch.distributions.Categorical(probs=probabilities).entropy()
    new_epsilon = min(epsilon, torch.sqrt(torch.tensor(epsilon)) * torch.exp(-entropy))
    indices_to_remove = probabilities < new_epsilon
    max_word = torch.argmax(logits, dim=-1)
    indices_to_remove[..., max_word.squeeze()] = 0
    new_scores = logits.masked_fill(indices_to_remove, float("-inf"))
    return new_scores


def sample(logits, sampling_method, threshold, temperature):
    """Sample from the logits with a specific sampling strategy."""
    if sampling_method == "top_p":
        probs = F.softmax(top_p_sampling(logits, thres=threshold) / temperature, dim=-1)
    elif sampling_method == "typical":
        probs = F.softmax(typical_sampling(logits, thres=threshold) / temperature, dim=-1)
    elif sampling_method == "eta":
        probs = F.softmax(eta_sampling(logits, epsilon=threshold) / temperature, dim=-1)
    else:
        probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs[-1, -1, :], 1)
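# Usage sketch (not part of the module): `sample` indexes probs[-1, -1, :],
# so it expects logits with at least three dims and draws one token id for
# the last position. The shapes below are hypothetical.
_logits = torch.randn(1, 1, 512)
_token = sample(_logits, sampling_method='top_p', threshold=0.9, temperature=1.0)
# _token.shape -> torch.Size([1])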


def sample_with_prob(logits, sampling_method, threshold, temperature):
    """Sample from the logits with a specific sampling strategy and return the token and its probability."""
    # apply the temperature first, then the truncation method
    logits = logits / temperature
    # logits = add_gumbel_noise(logits, temperature)

    if sampling_method == "top_p":
        modified_logits = top_p_sampling(logits, thres=threshold)
    elif sampling_method == "typical":
        modified_logits = typical_sampling(logits, thres=threshold)
    elif sampling_method == "eta":
        modified_logits = eta_sampling(logits, epsilon=threshold)
    else:
        modified_logits = logits  # otherwise, use the raw logits as-is

    # compute probabilities (the logits are already temperature-scaled above)
    probs = F.softmax(modified_logits, dim=-1)

    # take the probability distribution at the last position
    probs_last = probs[-1, -1, :]

    # sample a token
    sampled_token = torch.multinomial(probs_last, num_samples=1)
    # look up the probability of the sampled token
    prob_value = probs_last[sampled_token]

    return sampled_token, prob_value.squeeze()


def top_p_sampling_fast(logits, thres=0.9):
    """
    logits: Tensor of shape [B, L, V]
    Returns: logits with low-prob tokens masked as -inf, shape [B, L, V]
    """
    # Step 1: sort logits and get indices
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)  # [B, L, V]

    # Step 2: compute cumulative probs
    probs = F.softmax(sorted_logits, dim=-1)  # [B, L, V]
    cum_probs = torch.cumsum(probs, dim=-1)  # [B, L, V]

    # Step 3: mask tokens beyond cumulative threshold
    sorted_mask = cum_probs > thres
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False  # always keep at least one token

    # Step 4: scatter back to original order
    # Create mask of same shape as logits, default False
    mask = torch.zeros_like(logits, dtype=torch.bool)  # [B, L, V]
    mask = mask.scatter(-1, sorted_indices, sorted_mask)

    # Step 5: mask logits
    logits = logits.masked_fill(mask, float('-inf'))  # final masked logits

    return logits


def sample_with_prob_fast(logits, sampling_method="top_p", threshold=0.9, temperature=1.0, mask_indices=None):
    """
    logits: [B*T, num_sub_tokens, vocab_size]
    mask_indices: mask indicating which tokens to sample, shape = [B*T, num_sub_tokens]
    """
    if temperature != 1.0:
        logits = logits / temperature

    if sampling_method == "top_p":
        logits = top_p_sampling_fast(logits, thres=threshold)  # supports batched input
    elif sampling_method == "typical":
        logits = typical_sampling(logits, thres=threshold)
    elif sampling_method == "eta":
        logits = eta_sampling(logits, epsilon=threshold)
    # else: keep logits as-is

    probs = torch.softmax(logits, dim=-1)  # [B*T, num_sub_tokens, vocab_size]

    B, L, V = probs.shape
    probs_flat = probs.view(-1, V)  # [(B*T * num_sub_tokens), V]

    # torch.multinomial cannot handle 3D input at once, so flatten before sampling
    sampled = torch.multinomial(probs_flat, num_samples=1)  # [(B*T * num_sub_tokens), 1]
    sampled = sampled.view(B, L)  # [B*T, num_sub_tokens]

    sampled_probs = torch.gather(probs, 2, sampled.unsqueeze(-1)).squeeze(-1)  # [B*T, num_sub_tokens]

    return sampled, sampled_probs
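# Batched usage sketch (not part of the module), with hypothetical shapes
# matching the docstring: 8 positions, 4 sub-token slots, a 512-token vocab.
_fast_logits = torch.randn(8, 4, 512)
_tokens, _probs = sample_with_prob_fast(_fast_logits, sampling_method="top_p", threshold=0.9, temperature=1.0)
# _tokens.shape -> [8, 4]; _probs.shape -> [8, 4]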
228
Amadeus/sub_decoder_utils.py
Normal file
@ -0,0 +1,228 @@
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout):
        super().__init__()
        self.out_size = out_size
        self.layer = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x):
        return self.layer(x)


class extendedMLP(nn.Module):
    def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
        super().__init__()
        self.input_size = in_size

        self.layers = nn.ModuleList()
        if num_layers == 1:
            # Only one layer
            self.layers.append(nn.Linear(in_size, out_size))
            return
        elif num_layers > 1:
            # First layer
            self.layers.append(nn.Linear(in_size, hidden_size))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
            # Intermediate layers
            if num_layers > 2:
                for _ in range(num_layers - 2):  # -2 because we're manually adding the first and last layers
                    self.layers.append(nn.Linear(hidden_size, hidden_size))
                    self.layers.append(nn.Dropout(dropout))
                    self.layers.append(nn.ReLU())
            # Last layer
            self.layers.append(nn.Linear(hidden_size, out_size))
        else:
            raise ValueError("num_layers should be a positive integer")

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
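# Construction sketch (not part of the module): num_layers=3 yields
# Linear-Dropout-ReLU, Linear-Dropout-ReLU, Linear, i.e. exactly one
# intermediate block between the first and last layers.
_mlp = extendedMLP(in_size=512, out_size=128, num_layers=3, hidden_size=1024, dropout=0.1)
# _mlp(torch.randn(2, 16, 512)).shape -> [2, 16, 128]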


class multiMLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
        super().__init__()
        self.out_size = out_size
        self.pred_order = pred_order  # stored so forward() can map a token type to its MLP
        self.layer = nn.ModuleList([MLP(in_size, out_size, hidden_size, dropout) for _ in pred_order])

    def forward(self, x, choice):
        '''
        x: B x T x d_model
        choice: token type from self.pred_order (str or list of str)
        '''
        if isinstance(choice, str):
            idx = self.pred_order.index(choice)
            return self.layer[idx](x)
        elif len(choice) > 1 and not isinstance(choice, str):
            raise ValueError("multiMLP doesn't support parallel prediction")


class ResidualLayerNormModule(nn.Module):
    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule
        if submodule.__class__.__name__ == 'MultiheadAttention':
            self.layer_norm = nn.LayerNorm(self.submodule.embed_dim)
        else:
            self.layer_norm = nn.LayerNorm(self.submodule.input_size)

    def forward_attention(self, q, k, v, attn_mask, type):
        attn_output, _ = self.submodule(q, k, v, attn_mask=attn_mask, need_weights=False, average_attn_weights=False)
        return self.layer_norm(attn_output + q)

    def forward_mlp(self, x):
        return self.layer_norm(self.submodule(x) + x)


class MultiProj_hidden2logit(nn.Module):
    def __init__(self, dim, vocab_sizes):
        super().__init__()
        self.layers = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
        })

    def forward(self, hidden_vec, feature):
        logit = self.layers[f"layer_{feature}"](hidden_vec)
        return logit


class MultiProj_catvec2hidden(nn.Module):
    def __init__(self, config, par_pred_keys, seq_pred_keys):
        super().__init__()
        '''
        This class is used in SQstyleEachEmbStrategy
        par_pred_keys: list of independent features (these tokens are predicted in parallel)
        seq_pred_keys: list of sequential features (these tokens are predicted sequentially)
        '''
        net_param = config.nn_params
        self.d_model = net_param.model.d_model
        independent_emb_size = 0
        for key in par_pred_keys:
            independent_emb_size += net_param.emb[key]
        self.layers = nn.ModuleDict({
            'layer_independent': nn.Linear(self.d_model + independent_emb_size, self.d_model),
            **{f"layer_{key}": nn.Linear(self.d_model + net_param.emb[key], self.d_model) for key in seq_pred_keys}
        })
        self.par_pred_keys = par_pred_keys
        self.seq_pred_keys = seq_pred_keys
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def forward(self, x, choice):
        '''
        x: B x T x (d_model + emb_size)
        choice: key type (str or list of str)
        '''
        if isinstance(choice, str):  # single key
            assert choice in self.seq_pred_keys
            output = self.layers[f"layer_{choice}"](x)
            return self.relu(self.dropout(output))
        elif len(choice) > 1 and not isinstance(choice, str):  # multiple keys, parallel
            assert choice == self.par_pred_keys  # the order of choice should be the same as the order of self.par_pred_keys
            output = self.layers['layer_independent'](x)
            return self.relu(self.dropout(output))


def mask_tensor(tensor, mask_rate=0.15):
    # Get the size of the tensor
    batch_size, seq_len, dim = tensor.size()
    # Calculate the total number of elements and the number to mask
    total_elements = batch_size * seq_len
    num_to_mask = int(total_elements * mask_rate)
    # Create a 1D binary mask where 1 indicates that element will be masked.
    # Start by creating a tensor of zeros with length equal to the total number of elements.
    mask = torch.zeros(total_elements).to(tensor.device)
    # Set `num_to_mask` random indices to 1 (masking)
    indices_to_mask = torch.randperm(total_elements)[:num_to_mask]
    mask[indices_to_mask] = 1
    # Reshape the mask to match the original tensor's shape
    mask = mask.reshape(batch_size, seq_len)
    mask = mask.unsqueeze(2)  # B x T x 1
    masked_tensor = tensor * (mask == 0).float()  # B x T x d_model
    return masked_tensor


def generate_causality_mask_on_window(size, window_size):
    mask = torch.zeros((size, size))
    for i in range(size):
        mask[i, i + window_size:] = 1
    return mask.bool()


# generate a boolean mask; if the value is 1 or True, the position is masked
# considers the BOS token and a mask margin
def generate_CA_mask(tgt_len, memory_len, mask_margin=0):
    mask = torch.triu(torch.ones((tgt_len, memory_len)), diagonal=mask_margin + 1)
    return mask.bool()


# generate a boolean mask; if the value is 1 or True, the position is masked
def generate_SA_mask(tgt_len):
    mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1)
    return mask.bool()


def generate_none_causality_mask(tgt_len, memory_len):
    mask = torch.zeros((tgt_len, memory_len))
    return mask.bool()
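# Illustration (not part of the module): generate_SA_mask(4) returns
#   tensor([[False,  True,  True,  True],
#           [False, False,  True,  True],
#           [False, False, False,  True],
#           [False, False, False, False]])
# i.e. each query position may only attend to itself and earlier positions.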


class DecoderLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict


class TransformerLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.self_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # self attention
        attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')

        input_dict['input_seq'] = attn_output
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict


class FeatureEnricher(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory}
        '''
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], None, type='feature_enrichment')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
        return output_dict
1280
Amadeus/sub_decoder_zoo.py
Normal file
File diff suppressed because it is too large
Load Diff
0
Amadeus/symbolic_encoding/__init__.py
Normal file
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/augmentor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/augmentor.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/data_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/data_utils.cpython-311.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-312.pyc
Normal file
Binary file not shown.
46
Amadeus/symbolic_encoding/anylazesf.py
Normal file
@ -0,0 +1,46 @@
from sf2utils.sf2parse import Sf2File


def print_sorted_presets(sf2_path):
    presets_info = []

    with open(sf2_path, 'rb') as f:
        sf2 = Sf2File(f)

        for preset in sf2.presets:
            try:
                # try to read the attributes directly
                name = getattr(preset, 'name', '???').strip('\x00')
                bank = getattr(preset, 'bank', None)
                program = getattr(preset, 'preset', None)

                # if they are unavailable, try to read them from a sub-attribute
                if bank is None or program is None:
                    for attr in dir(preset):
                        attr_value = getattr(preset, attr)
                        if hasattr(attr_value, 'bank') and hasattr(attr_value, 'preset'):
                            bank = attr_value.bank
                            program = attr_value.preset
                            name = getattr(attr_value, 'name', name).strip('\x00')
                            break

                # collect valid results
                if bank is not None and program is not None:
                    presets_info.append((program, bank, name))
            except Exception as e:
                print(f"Error reading preset: {e}")

    # sort by program in ascending order (to sort by bank and then program, use sorted(..., key=lambda x: (x[1], x[0])))
    presets_info.sort(key=lambda x: x[0])

    # print the results
    print(f"{'Program':<8} {'Bank':<6} {'Preset Name'}")
    print("-" * 40)
    for program, bank, name in presets_info:
        print(f"{program:<8} {bank:<6} {name}")


# DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'

# replace this with the path to your .sf2 file
sf2_path = "/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2"
print_sorted_presets(sf2_path)
94
Amadeus/symbolic_encoding/augmentor.py
Normal file
@ -0,0 +1,94 @@
import random
from typing import Union

import torch


class Augmentor:
    def __init__(
        self,
        vocab,
        aug_type:Union[str, None],
        input_length:int
    ):
        self.vocab = vocab
        self.aug_type = aug_type
        self.input_length = input_length
        self.feature_list = vocab.feature_list
        self.num_features = len(self.feature_list)
        self.encoding_scheme = vocab.encoding_scheme

        self.pitch_idx = self.feature_list.index('pitch')
        if 'chord' in self.feature_list:
            self.chord_idx = self.feature_list.index('chord')

    def _get_shift(self, segment):
        # the pitch vocab has an ignore token at index 0
        if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
            pitch_mask = segment != 0
            pitch_segment = segment[pitch_mask[:, self.pitch_idx], self.pitch_idx]
            # check if the tensor is empty
            if pitch_segment.numel() == 0:
                shift = 0
            else:
                lowest_pitch = max(12, torch.min(pitch_segment))
                highest_pitch = min(119, torch.max(pitch_segment))
                lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) > 11)[0][-1].item()
                upper_shift_bound = torch.where(highest_pitch + torch.arange(7) < 120)[0][-1].item()
                shift = random.randint(-lower_shift_bound, upper_shift_bound)
        else:  # remi
            mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
            segment_pitch_mask = mask_for_pitch[segment]
            segment_pitch = segment * segment_pitch_mask
            segment_pitch = segment_pitch[segment_pitch != 0]
            # check if the tensor is empty
            if segment_pitch.numel() == 0:
                shift = 0
            else:
                lower_bound = torch.argwhere(mask_for_pitch == 1)[0].item()
                upper_bound = torch.argwhere(mask_for_pitch == 1)[-1].item()
                lowest_pitch = max(lower_bound, torch.min(segment_pitch))
                highest_pitch = min(upper_bound, torch.max(segment_pitch))
                lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) >= lower_bound)[0][-1].item()
                upper_shift_bound = torch.where(highest_pitch + torch.arange(7) <= upper_bound)[0][-1].item()
                shift = random.randint(-lower_shift_bound, upper_shift_bound)
        return shift

    # TODO: arrange hard coded part
    def __call__(self, segment):
        '''
        input_tensor is segments of x, y
        for transformer_xl, the shape of x, y is [max_num_segments, input_length, num_features]
        so we need to change the shape of x, y to [max_num_segments*input_length, num_features]
        '''
        if self.aug_type == 'random':
            shift = self._get_shift(segment)
            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
                # pitch augmentation
                segment_pitch_mask = segment != 0
                new_segment = segment.clone()
                new_segment[segment_pitch_mask[:, self.pitch_idx], self.pitch_idx] += shift
                if 'chord' in self.feature_list:
                    # chord augmentation
                    segment_chord_mask = (segment[:, self.chord_idx] != 0) & (segment[:, self.chord_idx] != 1)
                    new_segment[segment_chord_mask, self.chord_idx] = (((new_segment[segment_chord_mask, self.chord_idx] - 2) % 12) + shift) % 12 + ((new_segment[segment_chord_mask, self.chord_idx] - 2) // 12) * 12 + 2
                segment = new_segment
            else:  # remi
                # a random integer shift is used; shifts of -6 and +6 give the same result, so the range is -5 to 6
                # pitch augmentation
                mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
                segment_pitch_mask = mask_for_pitch[segment]
                new_segment = segment.clone()
                new_segment_valid = (new_segment + shift) * segment_pitch_mask
                new_segment = new_segment * (1 - segment_pitch_mask) + new_segment_valid
                if 'chord' in self.feature_list:
                    # chord augmentation
                    mask_for_chord = self.vocab.total_mask['chord'].clone().to(segment.device)
                    chord_n_n_idx = torch.argwhere(mask_for_chord == 1)[-1].item()
                    mask_for_chord[chord_n_n_idx] = 0
                    start_idx_chord = self.vocab.remi_vocab_boundaries_by_key['chord'][0]
                    segment_chord_mask = mask_for_chord[segment]
                    new_segment_valid = ((((new_segment - start_idx_chord) % 12 + shift) % 12) + ((new_segment - start_idx_chord) // 12) * 12 + start_idx_chord) * segment_chord_mask
                    new_segment = new_segment * (1 - segment_chord_mask) + new_segment_valid
                segment = new_segment
        return segment
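# Worked example of the chord transposition above (cp/nb case, a sketch):
# chord ids start at 2 and appear to be laid out as 12 roots per quality
# block, so the formula rotates the root within its block. For id 15
# (block 1, root 1) and shift +3:
#   (((15 - 2) % 12) + 3) % 12 + ((15 - 2) // 12) * 12 + 2  ->  18
# i.e. block 1, root 4 -- the quality is preserved, only the root moves.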
207
Amadeus/symbolic_encoding/compile_utils.py
Normal file
@ -0,0 +1,207 @@
import random
from collections import defaultdict

import torch
import numpy as np


def reverse_shift_and_pad(tune_in_idx, slice_boundary=4):
    new_lst = [curr_elems[:slice_boundary] + next_elems[slice_boundary:] for curr_elems, next_elems in zip(tune_in_idx, tune_in_idx[1:])]
    return new_lst


def reverse_shift_and_pad_for_tensor(tensor, first_pred_feature):
    '''
    tensor: [batch_size x seq_len x feature_size]
    '''
    if first_pred_feature == 'type':
        return tensor
    if tensor.shape[-1] == 8:
        slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
    elif tensor.shape[-1] == 7:
        slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'pitch':4, 'duration':5, 'velocity':6}
    elif tensor.shape[-1] == 5:
        slice_boundary_dict = {'type':0, 'beat':1, 'instrument':2, 'pitch':3, 'duration':4}
    elif tensor.shape[-1] == 4:
        slice_boundary_dict = {'type':0, 'beat':1, 'pitch':2, 'duration':3}
    slice_boundary = slice_boundary_dict[first_pred_feature]
    new_tensor = torch.zeros_like(tensor)
    new_tensor[..., :, :slice_boundary] = tensor[..., :, :slice_boundary]
    new_tensor[..., :-1, slice_boundary:] = tensor[..., 1:, slice_boundary:]
    return new_tensor


def shift_and_pad(tune_in_idx, first_pred_feature):
    if first_pred_feature == 'type':
        return tune_in_idx
    if len(tune_in_idx[0]) == 8:
        slice_boundary_dict = {'type':0, 'beat':-7, 'chord':-6, 'tempo':-5, 'instrument':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
    elif len(tune_in_idx[0]) == 7:
        slice_boundary_dict = {'type':0, 'beat':-6, 'chord':-5, 'tempo':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
    elif len(tune_in_idx[0]) == 5:
        slice_boundary_dict = {'type':0, 'beat':-4, 'instrument':-3, 'pitch':-2, 'duration':-1}
    elif len(tune_in_idx[0]) == 4:
        slice_boundary_dict = {'type':0, 'beat':-3, 'pitch':-2, 'duration':-1}
    slice_boundary = slice_boundary_dict[first_pred_feature]
    # Add an empty row padded with zeros at the beginning; sos and eos tokens are not shifted
    padded_tune_in_idx = torch.cat([torch.zeros(1, len(tune_in_idx[0]), dtype=torch.long), tune_in_idx], dim=0)
    new_tensor = torch.zeros_like(padded_tune_in_idx)
    new_tensor[:, slice_boundary:] = padded_tune_in_idx[:, slice_boundary:]
    new_tensor[:-1, :slice_boundary] = padded_tune_in_idx[1:, :slice_boundary]
    return new_tensor
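# Worked example (not part of the module), assuming 4 features
# (type, beat, pitch, duration) and first_pred_feature='pitch': a zero row
# is prepended, then the type/beat columns are pulled one row forward while
# pitch/duration stay in place.
_tune = torch.LongTensor([[1, 1, 60, 4],
                          [1, 2, 62, 4]])
# shift_and_pad(_tune, 'pitch') ->
# tensor([[ 1,  1,  0,  0],
#         [ 1,  2, 60,  4],
#         [ 0,  0, 62,  4]])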


class VanillaTransformer_compiler():
    def __init__(
        self,
        data_list,
        augmentor,
        eos_token,
        input_length,
        first_pred_feature,
        encoding_scheme
    ):
        self.data_list = data_list
        self.augmentor = augmentor
        self.eos_token = eos_token
        self.input_length = input_length
        self.first_pred_feature = first_pred_feature
        self.encoding_scheme = encoding_scheme

    def make_segments(self, data_type):
        segments = []
        tune_name2segment = defaultdict(list)
        segment2tune_name = []
        num_segments = 0
        for i in range(len(self.data_list)):
            tune_in_idx, tune_name = self.data_list[i]
            tune_in_idx = torch.LongTensor(tune_in_idx)
            # the eos token has the same tensor form for every encoding scheme
            eos_token = torch.LongTensor(self.eos_token)
            # shift and pad
            tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
            if data_type == 'train':
                if len(tune_in_idx) <= self.input_length+1:
                    if 'remi' in self.encoding_scheme:
                        padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
                    else:
                        padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
                    mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                    segment = torch.cat([tune_in_idx, padding_seq], dim=0)
                    segments.append([segment, mask])
                    segment2tune_name.append(tune_name)
                else:
                    start_point = 0
                    while start_point + self.input_length+1 < len(tune_in_idx):
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment = tune_in_idx[start_point:start_point + self.input_length+1]
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)
                        assert len(segment) == self.input_length+1
                        # Randomly choose the start point for the next segment, in the range from half of the current segment to the end of the current segment
                        start_point += random.randint((self.input_length+1)//2, self.input_length+1)
                    # if text-controlled, we only use the first segment
                    # add the last segment
                    if len(tune_in_idx[start_point:]) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
                        mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)

            else:  # for validset
                for i in range(0, len(tune_in_idx), self.input_length+1):
                    segment = tune_in_idx[i:i+self.input_length+1]
                    if len(segment) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
                        mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([segment, padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        segments.append([segment, mask])
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                    else:
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        # append the full-length segment exactly once (the original double append misaligned segments and segment2tune_name)
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        assert len(segment) == self.input_length+1

        return segments, tune_name2segment, segment2tune_name

    def make_segments_iters(self, data_type):
        tune_name2segment = defaultdict(list)
        segment2tune_name = []
        num_segments = 0
        # shuffle the data_list
        if data_type == 'train':
            random.shuffle(self.data_list)
        print("length of data_list:", len(self.data_list))
        for i in range(len(self.data_list)):
            tune_in_idx, tune_name = self.data_list[i]
            tune_in_idx = torch.LongTensor(tune_in_idx)
            # the eos token has the same tensor form for every encoding scheme
            eos_token = torch.LongTensor(self.eos_token)
            # shift and pad
            tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
            if data_type == 'train':
                if len(tune_in_idx) <= self.input_length+1:
                    if 'remi' in self.encoding_scheme:
                        padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
                    else:
                        padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
                    mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                    segment = torch.cat([tune_in_idx, padding_seq], dim=0)
                    segment2tune_name.append(tune_name)
                    yield [segment, mask], tune_name2segment, segment2tune_name
                else:
                    start_point = 0
                    while start_point + self.input_length+1 < len(tune_in_idx):
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment = tune_in_idx[start_point:start_point + self.input_length+1]
                        segment2tune_name.append(tune_name)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                        assert len(segment) == self.input_length+1
                        start_point += random.randint((self.input_length+1)//2, self.input_length+1)
                        # break
                    if len(tune_in_idx[start_point:]) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
                        mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        yield [segment, mask], tune_name2segment, segment2tune_name
            else:  # for validset
                for i in range(0, len(tune_in_idx), self.input_length+1):
                    segment = tune_in_idx[i:i+self.input_length+1]
                    if len(segment) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
                        mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([segment, padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                    else:
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                        assert len(segment) == self.input_length+1
1610
Amadeus/symbolic_encoding/data_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
404
Amadeus/symbolic_encoding/decoding_utils.py
Normal file
@ -0,0 +1,404 @@
import os, sys
from pathlib import Path

import matplotlib.pyplot as plt
from collections import defaultdict

from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature

from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP


class MuteWarn:
    def __enter__(self):
        self._init_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._init_stdout


def save_score_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        convert = converter.parse(midi_fn)
        convert.write('musicxml.png', fp=file_name)


def save_pianoroll_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    midi_obj_muspy = muspy.read_midi(midi_fn)
    midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
    plt.gcf().set_size_inches(20, 10)
    plt.savefig(file_name)
    plt.close()


def save_wav_from_midi(midi_fn, file_name, qpm=80):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        music = muspy.read_midi(midi_fn)
        music.tempos[0].qpm = qpm
        music.write_audio(file_name, rate=44100, gain=3)


def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
    assert isinstance(midi_fn, str)
    fs = FluidSynth(gain=gain)
    fs.midi_to_audio(midi_fn, file_name)


class MidiDecoder4REMI:
    def __init__(
        self,
        vocab,
        in_beat_resolution,
        dataset_name
    ):
        self.vocab = vocab
        self.in_beat_resolution = in_beat_resolution
        self.dataset_name = dataset_name
        if dataset_name == 'SymphonyMIDI':
            self.gain = 0.7
        elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
            self.gain = 1.1
        elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
            self.gain = 2.5
        else:
            self.gain = 1.5

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: a tensor of token indices (a leading batch dim of 1 is squeezed)
        '''
        idx2event = self.vocab.idx2event
        if generated_output.dim() == 2:
            generated_output = generated_output.squeeze(0)
        events = [idx2event[token.item()] for token in generated_output]

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if 'tempo' not in idx2event.keys():
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
        instr_notes_dict = defaultdict(list)
        for i in range(len(events)-2):
            cur_event = events[i]
            # print(cur_event)
            name = cur_event.split('_')[0]
            attr = cur_event.split('_')
            if name == 'Bar':
                bar_pos += cur_bar_resol
                if 'time' in cur_event:
                    cur_num, cur_denom = attr[-1].split('/')
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif name == 'Beat':
                beat_pos = int(attr[1])
                cur_pos = bar_pos + beat_pos * default_in_beat_ticks
            elif name == 'Chord':
                chord_text = attr[1] + '_' + attr[2]
                midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
            elif name == 'Tempo':
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=int(attr[1]), time=cur_pos))
            elif name == 'Instrument':
                cur_instr = int(attr[1])
            else:
                if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
                    if 'Note_Pitch' in events[i] and \
                            'Note_Duration' in events[i+1] and \
                            'Note_Velocity' in events[i+2]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = int(events[i+2].split('_')[-1])
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))
                elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
                    if 'Note_Pitch' in events[i] and \
                            'Note_Duration' in events[i+1]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = 90
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            if instr == 114: is_drum = True
            else: is_drum = False
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            # make subdirectories
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
                os.makedirs(prompt_music_path)
            # if the output path contains 'prompt', save the audio under prompt_music; otherwise under music
            if 'prompt' in output_path:
                music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            else:
                music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))

            midi_obj.dump(output_path)
            # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj


class MidiDecoder4CP(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, tempo, chord, beat, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if 'time_signature' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif token_with_7infos[feature2idx['type']] == 'Metrical':
                if 'time_signature' in token_with_7infos[feature2idx['beat']]:
                    cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
                    bar_pos += cur_bar_resol
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
                elif token_with_7infos[feature2idx['beat']] == 'Bar':
                    bar_pos += cur_bar_resol
                elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks  # ticks
                    # chord and tempo
                    midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
            elif token_with_7infos[feature2idx['type']] == 'Note':
                # instrument token
                if len(feature2idx) == 8 or len(feature2idx) == 5:
                    if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                        cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
                else:
                    cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 80
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except:
                    continue
            else:  # when a new bar starts without a beat
                continue

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            if instr == 114: is_drum = True
            else: is_drum = False
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
            if not os.path.exists(output_music_dir):
                os.makedirs(output_music_dir)
            midi_obj.dump(output_path)
            save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            # write the audio into the music subdir (the original passed the bare directory path here)
            save_wav_from_midi_fluidsynth(output_path, os.path.join(output_music_dir, os.path.basename(output_path).replace('.mid', '.wav')), gain=self.gain)
        return midi_obj
class MidiDecoder4NB(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
                bar_pos += cur_bar_resol
            elif 'NNN' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            # instrument token
            if len(feature2idx) == 8 or len(feature2idx) == 5:
                if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                    cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
            else:
                cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
            if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
                if 'Beat' in str(token_with_7infos[feature2idx['beat']]):  # when beat is not CONTI, the beat position is updated
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks  # ticks
                # update chord and tempo
                midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
                # note
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1]  # duration between 1~192
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 90
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except:
                    continue
            else:  # when a new bar starts without a beat
                continue

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            if instr == 114: is_drum = True
            else: is_drum = False
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
|
os.makedirs(prompt_music_path)
|
||||||
|
# if not contain 'prompt' in output_path, save prompt music
|
||||||
|
if 'prompt' in output_path:
|
||||||
|
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||||||
|
else:
|
||||||
|
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||||||
|
midi_obj.dump(output_path)
|
||||||
|
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||||||
|
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||||
|
return midi_obj
|
||||||
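
# Usage sketch (hypothetical driver code; vocab, in_beat_resolution and the generated
# token tensor come from the surrounding pipeline):
#   decoder = MidiDecoder4NB(vocab, in_beat_resolution=4, dataset_name='LakhClean')
#   midi_obj = decoder(generated_output, output_path='outputs/sample_0.mid')
# When output_path is given, the call dumps the .mid file and renders a .wav into the
# sibling 'music' (or 'prompt_music') directory via FluidSynth; it returns the
# miditoolkit MidiFile object either way.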
208
Amadeus/symbolic_encoding/metric_utils.py
Normal file
@ -0,0 +1,208 @@
import torch
import numpy as np

from collections import Counter

# TODO: refactor hard-coded values
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
    generated_output = generated_output.squeeze(0)
    type_idx = feature_list.index('type')
    beat_idx = feature_list.index('beat')
    type_beat_list = []
    for token in generated_output:
        type_beat_list.append((token[type_idx].item(), token[beat_idx].item()))  # type, beat

    last_note = 1
    beat_type_unmatched_error_list = []
    num_unmatched_errors = 0
    beat_backwards_error_list = []
    num_backwards_errors = 0
    for type_beat in type_beat_list:
        if type_beat[0] == 4:  # same bar, new beat
            if type_beat[1] == 0 or type_beat[1] == 1:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(type_beat)
            if type_beat[1] <= last_note:
                num_backwards_errors += 1
                beat_backwards_error_list.append([last_note, type_beat])
            else:
                last_note = type_beat[1]  # update last note
        elif type_beat[0] >= 5:  # new bar, new beat
            if type_beat[1] == 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(type_beat)
            last_note = 1
    unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
    backwards_error_rate = num_backwards_errors / len(type_beat_list)
    type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
    return type_beat_errors_dict
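
# Reading the result: both values are rates over all generated tokens, e.g.
#   {'beat_type_unmatched_error': 0.01, 'beat_backwards_error': 0.002}
# would mean 1% of tokens carry a beat sub-token that contradicts their type token and
# 0.2% step backwards in time within a bar (the numbers here are illustrative only).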

def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
    generated_output = generated_output.squeeze(0)
    type_idx = feature_list.index('type')
    beat_idx = feature_list.index('beat')
    pitch_idx = feature_list.index('pitch')
    duration_idx = feature_list.index('duration')
    last_note = 1
    beat_type_unmatched_error_list = []
    num_unmatched_errors = 0
    beat_backwards_error_list = []
    num_backwards_errors = 0
    for token in generated_output:
        if token[type_idx].item() == 2:  # Metrical
            if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(token)
            if token[beat_idx].item() == 1:  # new bar
                last_note = 1  # last note will be updated in the next token
            elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
                num_backwards_errors += 1
                beat_backwards_error_list.append([last_note, token])  # record the previous beat before updating
                last_note = token[beat_idx].item()  # update last note
            else:
                last_note = token[beat_idx].item()  # update last note
        if token[type_idx].item() == 3:  # Note
            if token[beat_idx].item() != 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(token)
    unmatched_error_rate = num_unmatched_errors / len(generated_output)
    backwards_error_rate = num_backwards_errors / len(generated_output)
    type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
    return type_beat_errors_dict

def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
    generated_output = generated_output.squeeze(0)
    # extract beat and bar tokens to check beat-ordering errors
    beat_mask = vocab.total_mask['beat'].to(generated_output.device)
    beat_mask_for_target = beat_mask[generated_output]
    beat_target = generated_output * beat_mask_for_target
    bar_mask = vocab.total_mask['type'].to(generated_output.device)
    bar_mask_for_target = bar_mask[generated_output]
    bar_target = (generated_output + 1) * bar_mask_for_target  # the bar token id is 0 in the remi vocab, so add 1 to keep it after masking
    target = beat_target + bar_target
    target = target[target != 0]
    # collect beats in between bars (idx=1)
    num_backwards_errors = 0
    collected_beats = []
    total_beats = 0
    for token in target:
        if token == 1 or 3 <= token <= 26:  # Bar_None, or Bar_time_signature
            collected_beats_tensor = torch.tensor(collected_beats)
            diff = torch.diff(collected_beats_tensor)
            num_error_beats = torch.where(diff <= 0)[0].shape[0]
            num_backwards_errors += num_error_beats
            collected_beats = []
        else:
            collected_beats.append(token.item())
            total_beats += 1
    if total_beats != 0:
        backwards_error_rate = num_backwards_errors / total_beats
    else:
        backwards_error_rate = 0
    # print(f"error rate in beat backwards: {backwards_error_rate}")
    return {'beat_backwards_error': backwards_error_rate}

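# Sketch of the masking trick above (hypothetical toy values): total_mask['beat'] is a
# 0/1 lookup over the whole vocabulary, so indexing it with the generated ids keeps
# only beat tokens, e.g.
#   beat_mask = torch.tensor([0, 0, 1, 1])   # ids 2 and 3 are beat tokens
#   ids = torch.tensor([0, 2, 3, 1])
#   ids * beat_mask[ids]                     # -> tensor([0, 2, 3, 0])
# after which the zeros (non-beat tokens) are filtered out.
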
def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
    bool_mask = mask.bool().flatten()  # (b*t)
    pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten()  # (b*t)
    valid_pred_beat_idx = pred_beat_idx[bool_mask]  # valid beat_idx
    answer_type = answer_type.flatten()  # (b*t)
    valid_type_input = answer_type[bool_mask]  # valid answer_type
    type_beat_list = []
    for i in range(len(valid_pred_beat_idx)):
        type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item()))  # type, beat
    input_beat = input_beat.flatten()
    valid_input_beat = input_beat[bool_mask]

    last_note = 1
    num_unmatched_errors = 0
    num_backwards_errors = 0
    for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
        # update last note
        if input_beat_idx.item() >= 1:  # beat
            last_note = input_beat_idx.item()
        if type_beat[0] == 4:  # same bar, new beat
            if type_beat[1] == 0 or type_beat[1] == 1:
                num_unmatched_errors += 1
            if type_beat[1] <= last_note:
                num_backwards_errors += 1
        elif type_beat[0] >= 5:  # new bar, new beat
            if type_beat[1] == 0:
                num_unmatched_errors += 1
    return len(type_beat_list), num_unmatched_errors, num_backwards_errors

def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
    bool_mask = mask.bool().flatten()  # (b*t)
    beat_idx = torch.argmax(beat_prob, dim=-1).flatten()  # (b*t)
    valid_beat_idx = beat_idx[bool_mask]  # valid beat_idx
    answer_type = answer_type.flatten()  # (b*t)
    valid_type_input = answer_type[bool_mask]  # valid answer_type
    type_beat_list = []
    for i in range(len(valid_beat_idx)):
        type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item()))  # type, beat
    input_beat = input_beat.flatten()
    valid_input_beat = input_beat[bool_mask]

    last_note = 1
    num_unmatched_errors = 0
    num_backwards_errors = 0
    for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
        # update last note
        if input_beat_idx.item() == 1:  # bar
            last_note = 1
        elif input_beat_idx.item() >= 2:  # new beat
            last_note = input_beat_idx.item()
        # check errors
        if type_beat[0] == 2:  # Metrical
            if type_beat[1] == 0:  # ignore
                num_unmatched_errors += 1
            elif type_beat[1] >= 2:  # new beat
                if type_beat[1] <= last_note:
                    num_backwards_errors += 1
        elif type_beat[0] == 3:  # Note
            if type_beat[1] != 0:
                num_unmatched_errors += 1
    return len(type_beat_list), num_unmatched_errors, num_backwards_errors

def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
    origin_beat_prob = prob_dict['beat']  # b x t x vocab_size
    arranged_beat_prob = arranged_prob_dict['beat']  # b x t x vocab_size

    # calculate similarity between the original beat prob and the arranged beat prob
    origin_beat_token = torch.argmax(origin_beat_prob, dim=-1) * mask  # b x t
    arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask  # b x t
    num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask == 0)
    num_beat = torch.sum(mask == 1)
    beat_sim = (num_same_beat / num_beat).item()  # scalar

    # apply mask, shape of mask: b x t
    origin_beat_prob = origin_beat_prob * mask.unsqueeze(-1)  # b x t x vocab_size
    arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)

    # calculate cosine similarity between the original beat prob and the arranged beat prob
    origin_beat_prob = origin_beat_prob.flatten(0, 1)  # (b*t) x vocab_size
    arranged_beat_prob = arranged_beat_prob.flatten(0, 1)  # (b*t) x vocab_size
    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
    beat_cos_sim = cos(origin_beat_prob, arranged_beat_prob)  # (b*t)
    # exclude invalid tokens (zero-padding tokens)
    beat_cos_sim = beat_cos_sim[mask.flatten().bool()]  # num_valid_tokens
    beat_cos_sim = torch.mean(beat_cos_sim).item()  # scalar
    return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}

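# Note on the masking above: padded positions are zeroed out in both probability
# tensors, so their cosine similarity is meaningless; that is why the per-token
# similarities are filtered with mask.flatten().bool() before averaging. 'beat_sim'
# compares hard argmax choices, while 'beat_cos_sim' compares the full distributions.
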
def get_gini_coefficient(generated_output):
    if len(generated_output.shape) == 3:
        generated_output = generated_output.squeeze(0).tolist()
        gen_list = [tuple(x) for x in generated_output]
    else:
        gen_list = generated_output.squeeze(0).tolist()
    counts = Counter(gen_list).values()
    sorted_counts = sorted(counts)
    n = len(sorted_counts)
    cumulative_counts = np.cumsum(sorted_counts)
    cumulative_proportion = cumulative_counts / cumulative_counts[-1]

    lorenz_area = sum(cumulative_proportion[:-1]) / n  # exclude the last element
    equality_area = 0.5  # the area under the line of perfect equality

    gini = (equality_area - lorenz_area) / equality_area
    return gini
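
# Interpretation sketch: the Gini coefficient above measures how unevenly token (or
# token-tuple) counts are distributed within one generated sequence. Uniform token
# usage drives the value toward 0 and a few dominating tokens drive it toward 1, up to
# the left-endpoint Lorenz-curve approximation used above. For example:
#   get_gini_coefficient(torch.randint(0, 100, (1, 512)))   # near-uniform, low gini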
78
Amadeus/symbolic_encoding/midi2audio.py
Normal file
@ -0,0 +1,78 @@
import argparse
import os
import subprocess
from pydub import AudioSegment

'''
This file is a modified version of midi2audio.py from https://github.com/bzamecnik/midi2audio
Author: Bohumír Zámečník (@bzamecnik)
License: MIT, see the LICENSE file
'''

__all__ = ['FluidSynth']

DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# DEFAULT_SAMPLE_RATE = 16000
# DEFAULT_GAIN = 0.20

class FluidSynth():
    def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
        self.sample_rate = sample_rate
        self.sound_font = os.path.expanduser(sound_font)
        self.gain = gain

    def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
        if verbose:
            stdout = None
        else:
            stdout = subprocess.DEVNULL

        # Convert MIDI to WAV
        subprocess.call(
            ['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
            stdout=stdout
        )

        # Convert WAV to MP3
        # mp3_path = audio_file.replace('.wav', '.mp3')
        # AudioSegment.from_wav(audio_file).export(mp3_path, format="mp3")

        # # Delete the temporary WAV file
        # os.remove(audio_file)

    def play_midi(self, midi_file):
        subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])

def parse_args(allow_synth=True):
    parser = argparse.ArgumentParser(description='Convert MIDI to audio via FluidSynth')
    parser.add_argument('midi_file', metavar='MIDI', type=str)
    if allow_synth:
        parser.add_argument('audio_file', metavar='AUDIO', type=str, nargs='?')
    parser.add_argument('-s', '--sound-font', type=str,
                        default=DEFAULT_SOUND_FONT,
                        help='path to a SF2 sound font (default: %s)' % DEFAULT_SOUND_FONT)
    parser.add_argument('-r', '--sample-rate', type=int, nargs='?',
                        default=DEFAULT_SAMPLE_RATE,
                        help='sample rate in Hz (default: %s)' % DEFAULT_SAMPLE_RATE)
    return parser.parse_args()

def main(allow_synth=True):
    args = parse_args(allow_synth)
    fs = FluidSynth(args.sound_font, args.sample_rate)
    if allow_synth and args.audio_file:
        fs.midi_to_audio(args.midi_file, args.audio_file)
    else:
        fs.play_midi(args.midi_file)

def main_play():
    """
    A method for the `midiplay` entry point. It omits the audio file from args.
    """
    main(allow_synth=False)

if __name__ == '__main__':
    main()
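
# Usage sketch (assumes the fluidsynth binary and the configured SoundFont exist):
#   fs = FluidSynth(sound_font='~/.fluidsynth/default_sound_font.sf2', sample_rate=44100)
#   fs.midi_to_audio('input.mid', 'output.wav')
# or from the shell, via the argparse entry point above:
#   python midi2audio.py input.mid output.wav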
65
Amadeus/symbolic_yamls/config-accelerate.yaml
Normal file
@ -0,0 +1,65 @@
defaults:
  # - nn_params: nb8_embSum_NMT
  # - nn_params: remi8
  - nn_params: nb8_embSum_diff_t2m_150M_finetunning
  # - nn_params: nb8_embSum_diff_t2m_150M_pretraining
  # - nn_params: nb8_embSum_subPararell
  # - nn_params: nb8_embSum_diff_t2m_150M

  # - nn_params: nb8_embSum_subFeedForward
  # - nn_params: nb8_embSum_diff
  # - nn_params: nb8_SA_diff
  # - nn_params: nb8_embSum_diff_main12head16dim512_ave
  # - nn_params: nb8_embSum_NMT_main12_head_16_dim512
  # - nn_params: remi8_main12_head_16_dim512
  # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json

# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
# captions_path: dataset/symphonyNet/syd-caption.json

use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True, False | use diffusion in the sub-decoder
diff_steps: 8 # number of diffusion steps
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5

train_params:
  device: cuda
  batch_size: 3
  grad_clip: 1.0
  num_iter: 300000 # total number of iterations
  num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
  num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
  iterations_per_training_cycle: 10 # number of iterations for logging training loss
  iterations_per_validation_cycle: 5000 # number of iterations for validation process
  input_length: 3072 # input sequence length
  # focal loss is optional; if it's not used, set focal_gamma to 0
  focal_alpha: 1
  focal_gamma: 0
  # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using'; please check train_utils.py for more details
  scheduler: cosinelr
  initial_lr: 0.00005
  decay_step_rate: 0.8 # the schedule reaches its lowest point at decay_step_rate * total_num_iter
  num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
  warmup_steps: 2000 # number of warmup steps
  max_lr: 0.00015
  gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
  # Distributed Data Parallel
  world_size: 5 # 0 means no distributed training
  gradient_accumulation_steps: 4 # 1 means no gradient accumulation
inference_params:
  num_uncond_generation: 1 # number of unconditional generations
  num_cond_generation: 3 # number of conditional generations
data_params:
  first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
  split_ratio: 0.998 # train-validation-test split ratio
  aug_type: pitch # random, null | pitch and chord augmentation type
general:
  debug: False
  make_log: True # True, False | update the log file in wandb online to your designated project and entity
  infer_and_log: True # True, False | inference and log the results
54
Amadeus/symbolic_yamls/config.yaml
Normal file
@ -0,0 +1,54 @@
defaults:
  # - nn_params: nb8_embSum_NMT
  # - nn_params: remi8
  # - nn_params: nb8_embSum_diff
  - nn_params: nb8_embSum_subFeedForward
  # - nn_params: nb8_SA_diff
  # - nn_params: nb8_embSum_diff_main12head16dim512_ave
  # - nn_params: nb8_embSum_NMT_main12_head_16_dim512
  # - nn_params: remi8_main12_head_16_dim512
  # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True, False | use diffusion in the sub-decoder
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5
diff_steps: 8 # number of diffusion steps
train_params:
  device: cuda
  batch_size: 8
  grad_clip: 1.0
  num_iter: 25000 # total number of iterations
  num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
  num_cycles_for_model_checkpoint: 10 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
  iterations_per_training_cycle: 10 # number of iterations for logging training loss
  iterations_per_validation_cycle: 500 # number of iterations for validation process
  input_length: 3072 # input sequence length
  # focal loss is optional; if it's not used, set focal_gamma to 0
  focal_alpha: 1
  focal_gamma: 0
  # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using'; please check train_utils.py for more details
  scheduler: cosinelr
  initial_lr: 0.0001
  decay_step_rate: 0.4 # the schedule reaches its lowest point at decay_step_rate * total_num_iter
  num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
  warmup_steps: 2000 # number of warmup steps
  max_lr: 0.00015
  gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
  # Distributed Data Parallel
  world_size: 5 # 0 means no distributed training
  gradient_accumulation_steps: 1 # 1 means no gradient accumulation
inference_params:
  num_uncond_generation: 1 # number of unconditional generations
  num_cond_generation: 3 # number of conditional generations
data_params:
  first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
  split_ratio: 0.99 # train-validation-test split ratio
  aug_type: null # random, null | pitch and chord augmentation type
general:
  debug: False
  make_log: True # True, False | update the log file in wandb online to your designated project and entity
  infer_and_log: True # True, False | inference and log the results
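
# Sketch of the 'cosinelr' schedule described by the train_params above: linear warmup
# for warmup_steps, then a cosine decay that bottoms out around decay_step_rate *
# num_iter. This illustrates the configured behaviour and is not the actual
# train_utils.py implementation; the function name and the zero floor are assumptions.
import math

def cosine_lr(step, warmup_steps=2000, max_lr=0.00015, num_iter=25000, decay_step_rate=0.4):
    decay_steps = max(int(num_iter * decay_step_rate), warmup_steps + 1)
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)               # linear warmup
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))    # cosine decay to 0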
20
Amadeus/symbolic_yamls/nn_params/cp5_embSum_NMT.yaml
Normal file
@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  input_length: 1024
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  input_length: 1024
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
19
Amadeus/symbolic_yamls/nn_params/cp7_embSum_NMT.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  input_length: 1024
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
19
Amadeus/symbolic_yamls/nn_params/nb5_embSum_NMT.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
19
Amadeus/symbolic_yamls/nn_params/nb5_embSum_diff.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 3
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 3
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
18
Amadeus/symbolic_yamls/nn_params/nb5_embSum_subPararell.yaml
Normal file
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
18
Amadeus/symbolic_yamls/nn_params/nb5_embSum_subRNN.yaml
Normal file
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
19
Amadeus/symbolic_yamls/nn_params/nb7_embSum_NMT.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
18
Amadeus/symbolic_yamls/nn_params/nb7_embSum_subPararell.yaml
Normal file
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
18
Amadeus/symbolic_yamls/nn_params/nb7_embSum_subRNN.yaml
Normal file
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
19
Amadeus/symbolic_yamls/nn_params/nb8_SA_diff.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SelfAttentionEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_NMT.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 3
feature_enricher_use: True
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_NMTsub6.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 6
feature_enricher_use: True
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_150M.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 16
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True

@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: AverageEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: True

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 3
feature_enricher_use: True
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_sub3.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 2
feature_enricher_use: False
19
Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_sub6.yaml
Normal file
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 6
feature_enricher_use: True
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 16
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerFinetuningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 20
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPrefixDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 16
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 20
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
feature_enricher_use: False
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
18
Amadeus/symbolic_yamls/nn_params/nb8_embSum_subPararell.yaml
Normal file
@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 512
  num_layer: 6
  num_head: 8
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
12
Amadeus/symbolic_yamls/nn_params/remi5.yaml
Normal file
@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 5
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
  dim_model: 512
  num_layer: 8
  num_head: 8
12
Amadeus/symbolic_yamls/nn_params/remi7.yaml
Normal file
@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 7
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
  dim_model: 512
  num_layer: 8
  num_head: 8
12
Amadeus/symbolic_yamls/nn_params/remi8.yaml
Normal file
@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
  dim_model: 512
  num_layer: 8
  num_head: 8
@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
  dim_model: 512
  num_layer: 12
  num_head: 16
17
Amadeus/symbolic_yamls/symbolic_sweep.yaml
Normal file
@ -0,0 +1,17 @@
program: train.py
method: grid
metric:
  name: valid.total
  goal: minimize
parameters:
  train_params.batch_size:
    values: [8]
  train_params.focal_gamma:
    values: [0, 1]
  nn_params.main_decoder.input_length:
    values: [8192]

command:
  - python3
  - ${program}
  - ${args_no_hyphens}
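This sweep file matches the Weights & Biases sweep schema (a grid over batch size, focal gamma, and input length, launched through the command list). A minimal sketch of registering it from Python, assuming wandb and PyYAML are installed; the project name is a placeholder:

import yaml
import wandb

with open('Amadeus/symbolic_yamls/symbolic_sweep.yaml') as f:
    sweep_config = yaml.safe_load(f)
sweep_id = wandb.sweep(sweep_config, project='amadeus')  # placeholder project name
# then start one or more workers, e.g. from a shell: wandb agent <entity>/<project>/<sweep_id>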
428
Amadeus/train_utils.py
Normal file
@ -0,0 +1,428 @@
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
from collections import defaultdict
import torch.nn.functional as F

def add_conti_for_single_feature(tensor):
    new_target = tensor.clone()
    # Assuming tensor shape is [batch, sequence, features]
    # Create a shifted version of the tensor
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # The first element of each sequence cannot be a duplicate by definition
    shifted_tensor[:, 0] = new_target[:, 0] + 1

    # Identify where the original and shifted tensors are the same (duplicates)
    duplicates = new_target == shifted_tensor
    # Replace duplicates with 9999 (the CONTI placeholder id)
    new_target[duplicates] = 9999
    return new_target
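# Editor's illustrative sketch (hypothetical helper, not in the original file):
# consecutive repeats along the sequence axis become the 9999 CONTI placeholder.
def _demo_add_conti():
    x = torch.tensor([[[3], [3], [5], [5]]])            # batch=1, seq=4, features=1
    return add_conti_for_single_feature(x).squeeze(-1)  # tensor([[3, 9999, 5, 9999]])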

def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
    feature_prediction_order_dict = {
        4: ["type", "beat", "pitch", "duration"],
        5: ["type", "beat", "instrument", "pitch", "duration"],
        7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
        8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
    }

    if encoding_scheme == 'remi':
        prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'cp':
        if nn_params.get("partial_sequential_prediction", False):
            default_prediction_order = feature_prediction_order_dict[num_features]
            prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
        else:
            prediction_order = feature_prediction_order_dict[num_features]
    elif encoding_scheme == 'nb':
        assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
        default_prediction_order = feature_prediction_order_dict[num_features]

        # Reorganize the prediction order based on the target_feature
        target_index = default_prediction_order.index(target_feature)
        prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]

    return prediction_order

########################### Loss function ################################

class NLLLoss4REMI():
    def __init__(
        self,
        focal_alpha: float,
        focal_gamma: float,
    ):
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7)  # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss_seq = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum()  # calculating mean loss considering mask
        return loss, loss_seq

    def __call__(self, logits, shifted_tgt, mask, vocab):
        if vocab is not None:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            loss_by_class_normal = defaultdict(float)
            shifted_tgt_with_mask = shifted_tgt * mask  # [b, t]
            answers_idx = shifted_tgt_with_mask.flatten(0, 1)  # [b*t]
            for feature in vocab.feature_list:
                feature_mask = vocab.total_mask[feature].to(answers_idx.device)  # [vocab_size,]
                mask_for_target = feature_mask[answers_idx]  # [b*t]
                normal_loss_seq_by_class = loss_seq * mask_for_target
                if mask_for_target.sum().item() != 0:
                    loss_by_class_normal[feature+'_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
            return loss, loss_by_class_normal
        else:
            loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
            return loss, None
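# Editor's illustrative sketch (hypothetical helper, not in the original file):
# with focal_gamma=0 the focal term vanishes and this reduces to a masked NLL;
# vocab=None skips the per-feature logging path.
def _demo_remi_loss():
    criterion = NLLLoss4REMI(focal_alpha=1.0, focal_gamma=0.0)
    logits = torch.randn(2, 6, 100)         # B x T x vocab_size
    target = torch.randint(0, 100, (2, 6))  # B x T
    mask = torch.ones(2, 6)
    loss, _ = criterion(logits, target, mask, vocab=None)
    return loss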

class NLLLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha: float, focal_gamma: float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        # clamp min value to 1e-7 to avoid log(0)
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7)  # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss = loss * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss.sum() / mask.sum()  # calculating mean loss considering mask
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
        probs = logits.softmax(dim=-1)

        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)  # [batch_size x seq_len]
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)

        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        pt = probs[torch.arange(len(target)), target]  # [batch_size*seq_len]
        total_mask = mask.flatten(0, 1) & valid_mask  # [batch_size*seq_len]
        loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt)  # [batch_size*seq_len]
        loss = loss * total_mask  # [batch_size*seq_len]
        loss = loss.sum() / total_mask.sum()  # calculating mean loss considering mask
        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, valid):
        train_loss_list = []
        log_loss_dict_normal = {}
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
                log_loss_dict_normal[key + '_normal'] = log_normal_loss
        total_loss = sum(train_loss_list) / len(train_loss_list)
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None

def dispersive_loss(z, tau=0.5, eps=1e-8):
    """Dispersive loss based on cosine distance."""
    B = z.size(0)

    # Cosine-similarity matrix [B, B]
    z_norm = torch.nn.functional.normalize(z, p=2, dim=1)  # L2-normalize each vector
    sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1))  # cosine similarity

    # Convert to cosine distance (1 - similarity), excluding the diagonal
    mask = 1 - torch.eye(B, device=z.device)
    cos_dist = (1 - sim_matrix) * mask

    # Dispersion term (same form as the L2 variant)
    exp_term = torch.exp(-cos_dist / tau)
    mean_exp = exp_term.sum() / (B * (B - 1) + eps)
    loss = -torch.log(mean_exp + eps)
    return loss
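# Editor's illustrative sketch (hypothetical helper, not in the original file):
# dispersive_loss expects one pooled feature vector per batch element.
def _demo_dispersive_loss():
    feats = torch.randn(8, 32)              # B x dim, e.g. mean-pooled hidden states
    return dispersive_loss(feats, tau=0.5)  # scalar regularizer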
class DiffusionLoss4CompoundToken():
    def __init__(self, feature_list, focal_alpha: float, focal_gamma: float):
        self.feature_list = feature_list
        self.alpha = focal_alpha
        self.gamma = focal_gamma

    def get_nll_loss(self, logits, target, mask, mask_indices, p_mask):
        if logits.ndim == 3:
            logits = logits.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        if mask.ndim == 2:
            mask = mask.flatten(0, 1)
        # datatype of logits, target, mask_indices, p_mask should be the same
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions directly
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
        return loss

    def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
        if ignore_token is not None and conti_token is not None:
            target_conti = add_conti_for_single_feature(target)
            valid_mask = (target_conti != ignore_token) & (target_conti != conti_token)
        elif ignore_token is not None and conti_token is None:
            valid_mask = (target != ignore_token)
        elif ignore_token is None and conti_token is None:
            valid_mask = torch.ones_like(target).bool()
        valid_mask = valid_mask.flatten(0, 1)

        if logits.ndim == 3:
            logits = logits.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        if mask_indices.ndim == 2:
            mask_indices = mask_indices.flatten(0, 1)
        if p_mask.ndim == 2:
            p_mask = p_mask.flatten(0, 1)
        token_loss = F.cross_entropy(
            logits[mask_indices],  # index logits at the masked positions directly
            target[mask_indices],
            reduction='none'
        ) / p_mask[mask_indices]
        total_mask = mask.flatten(0, 1) & valid_mask  # [batch_size*seq_len]
        loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()

        return loss

    def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None, lambda_weight=0.5, tau=0.5):
        train_loss_list = []
        log_loss_dict_normal = {}
        mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
        p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
        disp_loss = None
        if input_dict is not None:
            hidden_vec = input_dict['hidden_vec']  # bs, seq_len, dim
            feat = hidden_vec.mean(dim=1)  # bs, dim
            disp_loss = dispersive_loss(feat, tau=tau)  # scalar
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
            train_loss_list.append(training_loss)
            if valid:
                if key == 'type':
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                elif key in ('beat', 'chord', 'tempo', 'instrument'):
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                else:
                    log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
                log_loss_dict_normal[key + '_normal'] = log_normal_loss
        total_loss = sum(train_loss_list) / len(train_loss_list)
        if disp_loss is not None:
            total_loss = total_loss + lambda_weight * disp_loss
            log_loss_dict_normal['dispersion'] = disp_loss.item()
        if valid:
            return total_loss, log_loss_dict_normal
        else:
            return total_loss, None
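# Editor's illustrative sketch (hypothetical helper, not in the original file):
# cross-entropy is computed only at the masked positions and reweighted by
# 1/p_mask, the per-position masking probability, as in masked-diffusion objectives.
def _demo_diffusion_nll():
    crit = DiffusionLoss4CompoundToken(feature_list=['pitch'], focal_alpha=1.0, focal_gamma=0.0)
    logits = torch.randn(2, 6, 50)
    target = torch.randint(0, 50, (2, 6))
    mask = torch.ones(2, 6, dtype=torch.bool)
    mask_indices = torch.zeros(2, 6, dtype=torch.bool)
    mask_indices[:, ::2] = True        # positions that were masked out
    p_mask = torch.full((2, 6), 0.5)   # probability each position was masked
    return crit.get_nll_loss(logits, target, mask, mask_indices, p_mask)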

class EncodecFlattenLoss():
    def __init__(self, feature_list):
        self.feature_list = feature_list

    def get_nll_loss(self, logits, target, mask):
        probs = logits.softmax(dim=-1)
        if probs.ndim == 3:
            probs = probs.flatten(0, 1)  # [batch_size*seq_len x vocab_size]
        if target.ndim == 2:
            target = target.flatten(0, 1)  # [batch_size*seq_len]
        pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7)  # [batch_size*seq_len]
        loss_seq = -torch.log(pt)  # [batch_size*seq_len]
        loss_seq = loss_seq * mask.flatten(0, 1)  # [batch_size*seq_len]
        loss = loss_seq.sum() / mask.sum()  # calculating mean loss considering mask
        return loss

    def __call__(self, logits, shifted_tgt, mask):
        loss = self.get_nll_loss(logits, shifted_tgt, mask)
        return loss

class EncodecMultiClassLoss(EncodecFlattenLoss):
    def __init__(self, feature_list):
        super().__init__(feature_list)

    def __call__(self, logits_dict, shifted_tgt, mask):
        train_loss_list = []
        for idx, key in enumerate(self.feature_list):
            training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
            train_loss_list.append(training_loss)
        total_loss = sum(train_loss_list) / len(train_loss_list)
        return total_loss

########################### Learning rate Scheduler ################################
'''
This scheduler is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is basically a cosine annealing scheduler with warm restarts, extended with a warm-up start and a decaying maximum lr.
'''

class CosineAnnealingWarmUpRestarts(_LRScheduler):
    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.T_up = T_up
        self.T_i = T_0
        self.gamma = gamma
        self.cycle = 0
        self.T_cur = last_epoch
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.T_cur == -1:
            return self.base_lrs
        elif self.T_cur < self.T_up:
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
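# Editor's illustrative sketch (hypothetical helper, not in the original file):
# the optimizer's base lr acts as the floor and eta_max as the warm-up peak,
# so the optimizer is deliberately created with a small lr.
def _demo_warmup_restarts():
    model = nn.Linear(4, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-6)
    sched = CosineAnnealingWarmUpRestarts(opt, T_0=100, T_mult=2, eta_max=1e-3, T_up=10, gamma=0.5)
    for _ in range(300):
        opt.step()
        sched.step()
    return opt.param_groups[0]['lr']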

class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        lr_min_ratio (float): Minimum learning rate ratio.
        cycle_length (float): Cycle length.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]

class DispersiveLoss(nn.Module):
    def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
        super().__init__()
        self.loss_type = loss_type
        self.tau = tau
        self.lambda_weight = lambda_weight

    def forward(self, features, diffusion_loss):
        """
        features: batch feature matrix of shape [batch_size, feature_dim]
        diffusion_loss: the original diffusion loss
        """
        batch_size = features.size(0)

        # Compute the pairwise distance matrix
        if self.loss_type == 'infonce_l2':
            # Squared L2 distances
            dist_matrix = torch.cdist(features, features, p=2) ** 2
            # Dispersion term
            exp_dist = torch.exp(-dist_matrix / self.tau)
            disp_loss = torch.log(exp_dist.mean())
        elif self.loss_type == 'hinge':
            # Hinge loss with threshold epsilon=1.0
            dist_matrix = torch.cdist(features, features, p=2)
            disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
        elif self.loss_type == 'covariance':
            # Covariance loss
            normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
            cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
            # Sum of squared off-diagonal entries
            mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool)
            disp_loss = (cov_matrix[mask] ** 2).mean()
        else:
            raise ValueError("Unsupported loss type")

        # Total loss = diffusion loss + lambda * dispersion loss
        total_loss = diffusion_loss + self.lambda_weight * disp_loss
        return total_loss, disp_loss
1012
Amadeus/trainer_accelerate.py
Normal file
File diff suppressed because it is too large
949
Amadeus/transformer_utils.py
Normal file
@ -0,0 +1,949 @@
import torch
import torch.nn as nn

from x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
from transformers import T5EncoderModel
from data_representation.vocab_utils import LangTokenVocab

class PosEncoding(nn.Module):
    def __init__(self, emb_size, max_t):
        super().__init__()
        self.emb_size = emb_size
        self.max_t = max_t
        self.register_buffer('encoding', self._prepare_emb())

    def _prepare_emb(self):
        dim_axis = 10000**(torch.arange(self.emb_size//2) * 2 / self.emb_size)  # 10000 ** (exponents normalized between 0 and 1 across the embedding dim)
        timesteps = torch.arange(self.max_t)
        pos_enc_in = timesteps.unsqueeze(1) / dim_axis.unsqueeze(0)
        pos_enc_sin = torch.sin(pos_enc_in)  # sin/cos over geometrically spaced frequencies give each timestep a distinct code
        pos_enc_cos = torch.cos(pos_enc_in)

        pos_enc = torch.stack([pos_enc_sin, pos_enc_cos], dim=-1).reshape([self.max_t, self.emb_size])
        return pos_enc

    def forward(self, x):
        return self.encoding[x]
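# Editor's illustrative sketch (hypothetical helper, not in the original file):
# PosEncoding is a fixed sinusoidal table indexed by position ids.
def _demo_pos_encoding():
    pe = PosEncoding(emb_size=16, max_t=512)
    positions = torch.arange(10).unsqueeze(0)  # 1 x 10
    return pe(positions)                       # 1 x 10 x 16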

class ResidualLayerNormModule(nn.Module):
    def __init__(self, submodule):
        super().__init__()
        self.submodule = submodule
        self.layer_norm = nn.LayerNorm(self.submodule.input_size)

    def forward(self, x, mask=None, y=None):
        if y is not None:
            res_x = self.submodule(x, y, mask)
        elif mask is not None:
            res_x = self.submodule(x, mask)
        else:
            res_x = self.submodule(x)
        x = x + res_x
        return self.layer_norm(x)

class SingleEmbedding(nn.Module):
    def __init__(
        self,
        vocab,
        dim_model,
    ):
        '''
        Embedding layer for REMI
        '''
        super().__init__()
        vocab_size = vocab.get_vocab_size()
        self.embedding = nn.Embedding(vocab_size, dim_model)

    def forward(self, x):
        return self.embedding(x)

class MultiEmbedding(nn.Module):
    def __init__(
        self,
        vocab: LangTokenVocab,
        dim_model: int,
    ):
        super().__init__()
        '''
        Embedding layer for compound tokens
        '''
        self.vocab_size = vocab.get_vocab_size()
        self.feature_list = vocab.feature_list
        self.dim_model = dim_model
        self.layers = []

        self._make_emb_layers()
        self._init_params()
        self._make_emb_boundaries_by_key()

    def _init_params(self):
        # apply kaiming init
        for layer in self.layers:
            if isinstance(layer, nn.Embedding):
                nn.init.kaiming_normal_(layer.weight)

    def _make_emb_layers(self):
        vocab_sizes = [self.vocab_size[key] for key in self.feature_list]
        self.embedding_sizes = [self.dim_model for _ in self.feature_list]
        for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
            if embedding_size != 0:
                self.layers.append(nn.Embedding(vocab_size, embedding_size))
        self.layers = nn.ModuleList(self.layers)

    def _make_emb_boundaries_by_key(self):
        '''
        This function builds a dict of boundaries for each embedding layer
        '''
        self.emb_boundary_by_key = {}
        start_idx = 0
        for key, emb_size in zip(self.feature_list, self.embedding_sizes):
            if emb_size != 0:
                self.emb_boundary_by_key[key] = (start_idx, start_idx + emb_size)
                start_idx += emb_size

    def forward(self, x):
        emb = torch.cat([module(x[..., i]) for i, module in enumerate(self.layers)], dim=-1)
        return emb

    def __len__(self):
        return len(self.layers)

    def get_emb_by_key(self, key, token):
        layer_idx = self.feature_list.index(key)
        return self.layers[layer_idx](token)

class SummationEmbedder(MultiEmbedding):
    def __init__(
        self,
        vocab: LangTokenVocab,
        dim_model: int
    ):
        super().__init__(vocab, dim_model)

    def forward(self, seq):
        emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
        stacked_emb = torch.stack(emb_list, dim=2)  # B x T x num_features x emb_size
        output = torch.sum(stacked_emb, dim=2)  # B x T x emb_size
        return output
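# Editor's illustrative sketch (hypothetical stub, not in the original file):
# SummationEmbedder only needs feature_list and get_vocab_size() from the vocab,
# so a stub can stand in for LangTokenVocab; the sizes below are made up.
class _StubVocab:
    feature_list = ['type', 'beat', 'pitch', 'duration']
    def get_vocab_size(self):
        return {'type': 4, 'beat': 64, 'pitch': 130, 'duration': 80}

def _demo_summation_embedder():
    emb = SummationEmbedder(_StubVocab(), dim_model=16)
    tokens = torch.randint(0, 4, (2, 10, 4))  # B x T x num_features, ids valid for every feature
    return emb(tokens)                        # B x T x dim_model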

class AverageEmbedder(MultiEmbedding):
    def __init__(
        self,
        vocab: LangTokenVocab,
        dim_model: int
    ):
        super().__init__(vocab, dim_model)

    def forward(self, seq):
        emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
        stacked_emb = torch.stack(emb_list, dim=2)  # B x T x num_features x emb_size
        output = torch.mean(stacked_emb, dim=2)  # B x T x emb_size
        return output

class SelfAttentionEmbedder(MultiEmbedding):
    def __init__(
        self,
        vocab: LangTokenVocab,
        dim_model: int
    ):
        super().__init__(vocab, dim_model)
        self.dropout = 0.1

        self.transformer_encoder = Encoder(
            dim=dim_model,
            depth=1,
            heads=8,
            attn_dropout=self.dropout,
            ff_dropout=self.dropout,
            attn_flash=True)

        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, self.dim_model), requires_grad=True)

        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff()
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn()

    def _add_dropout_after_attn(self):
        for layer in self.transformer_encoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(self.dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(self.dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self):
        for layer in self.transformer_encoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(self.dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_encoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def _apply_window_on_input_vec(self, embeddings):
        window_size = 1
        zero_vec = torch.zeros(embeddings.shape[0], window_size-1, embeddings.shape[2], embeddings.shape[3]).to(embeddings.device)  # B x (window_size-1) x num_features x emb_size
        window_applied_input_vec = torch.cat([zero_vec, embeddings], dim=1)  # B x (T+window_size-1) x num_features x emb_size
        window_applied_input_vec = window_applied_input_vec.unfold(1, window_size, 1)  # B x T x window_size x emb_size x num_features
        window_applied_input_vec = window_applied_input_vec.transpose(3, 4)  # B x T x window_size x num_features x emb_size
        window_applied_input_vec = window_applied_input_vec.reshape(embeddings.shape[0]*embeddings.shape[1], -1, embeddings.shape[3])  # (B*T) x (num_features*window_size) x emb_size
        return window_applied_input_vec

    def _apply_pos_enc(self, tgt):
        # NOTE: relies on self.pos_enc, which is not defined in this class; unused by forward()
        pos = torch.arange(tgt.shape[1]).to(tgt.device)  # (num_features*window_size+1)
        pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)  # (B*T) x (num_features*window_size+1)
        tgt_pos = tgt + self.pos_enc(pos.long())  # (B*T) x (num_features*window_size+1) x emb_size
        return tgt_pos

    def forward(self, input_tokens):
        '''
        input_tokens: B x T x num_features
        '''
        # prepare input vector
        emb_list = [module(input_tokens[..., i]) for i, module in enumerate(self.layers)]  # B x T x 1 x emb_size
        stacked_emb = torch.stack(emb_list, dim=2)  # B x T x num_features x emb_size
        # apply window
        stacked_emb = self._apply_window_on_input_vec(stacked_emb)
        # add CLS
        cls = self.cls_embedding.repeat(stacked_emb.shape[0], 1, 1)  # (B*T) x 1 x emb_size
        input_emb = torch.cat([stacked_emb, cls], dim=1)  # (B*T) x (num_features*window_size+1) x emb_size
        output = self.transformer_encoder(input_emb)  # (B*T) x (num_features*window_size+1) x emb_size
        # extract CLS
        output = output[:, -1, :].reshape((input_tokens.shape[0], input_tokens.shape[1], -1))  # B x T x emb_size
        return output

class RVQMultiEmbedding(nn.Module):
    def __init__(
        self,
        vocab: LangTokenVocab,
        dim_model: int
    ):
        super().__init__()
        self.vocab_size = vocab.get_vocab_size()
        self.dim_model = dim_model
        self.features = vocab.feature_list
        self.layers = []
        self._make_emb_layers()

    def _make_emb_layers(self):
        vocab_sizes = [self.vocab_size[key] for key in self.features]
        self.embedding_sizes = [self.dim_model for _ in self.features]
        for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
            if embedding_size != 0:
                self.layers.append(nn.Embedding(vocab_size, embedding_size))
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x):
        embeddings = torch.zeros(x.shape[0], x.shape[1], self.dim_model).to(x.device)
        emb_list = [module(x[:, (idx+1)%4::4]) for idx, module in enumerate(self.layers)]
        for idx, emb in enumerate(emb_list):
            embeddings[:, (idx+1)%4::4] = emb
        return embeddings

    def get_emb_by_key(self, key: str, token: torch.Tensor):
        layer_idx = self.features.index(key)
        return self.layers[layer_idx](token)

class XtransformerDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)
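# Editor's illustrative sketch (hypothetical helper, not in the original file):
# the decoder consumes already-embedded sequences (B x T x dim); train=True also
# returns the layer intermediates that back the KV cache at inference time.
def _demo_main_decoder():
    dec = XtransformerDecoder(dim=64, depth=2, heads=4, dropout=0.1)
    hidden, intermediates = dec(torch.randn(2, 16, 64), train=True)
    return hidden.shape  # torch.Size([2, 16, 64])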

class XtransformerCrossAttendDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            cross_attend=True,
            only_cross=False)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
        else:
            context = context_embedding

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq, context=context)
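# Editor's illustrative sketch (hypothetical helper, not in the original file):
# `context` mirrors a Hugging Face tokenizer's output; flan-t5-base hidden states
# are 768-dim, so dim=768 is chosen here to match the cross-attention context width.
def _demo_cross_attend_decoder():
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained('google/flan-t5-base')
    context = tok(['a calm piano piece'], return_tensors='pt')
    dec = XtransformerCrossAttendDecoder(dim=768, depth=2, heads=8, dropout=0.1)
    hidden, intermediates = dec(torch.randn(1, 16, 768), train=True, context=context)
    return hidden.shape  # torch.Size([1, 16, 768])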

class XtransformerLargeCrossAttendDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            cross_attend=True,
            only_cross=False)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
        else:
            context = context_embedding

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq, context=context)

class NewCrossAttendDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            cross_attend=True,
            only_cross=False,
            use_rmsnorm=True,
            ff_swish=True,  # set this to True
            ff_glu=True,    # set to true to use for all feedforwards
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
        else:
            context = context_embedding

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq, context=context)

class NewCrossAttendwithRoPEDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            cross_attend=True,
            only_cross=False,
            use_rmsnorm=True,
            rotary_pos_emb=True,
            ff_swish=True,  # set this to True
            ff_glu=True,    # set to true to use for all feedforwards
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
        else:
            context = context_embedding

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq, context=context)

class XtransformerPrefixDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = PrefixDecoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None):
        assert context is not None, 'context should be provided for prefix decoder'
        input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
        attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
        assert input_ids is not None, 'input_ids should be provided for prefix decoder'
        assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
        assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
        context = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state
        # NOTE: the encoded `context` is computed above but not passed to the decoder calls below

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)
|
|
||||||
|
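# --- illustrative sketch (not part of the original module) -------------------
# How the `context` dict expected by XtransformerPrefixDecoder.forward can be
# built. The tokenizer call is the standard Hugging Face API; the prompt text
# and function name are made-up placeholders.
def _example_build_context():
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
    # returns a dict with (B, L) 'input_ids' and 'attention_mask' tensors,
    # matching the ndim checks in forward() above
    return tokenizer(['a calm piano piece in C major'], return_tensors='pt')
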
class XtransformerPretrainingDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # freeze the text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        # pretraining is unconditional: context and context_embedding are
        # accepted only for interface compatibility and are not used
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)

class XtransformerFinetuningDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # freeze the text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the finetuning decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for the finetuning decoder'
            assert attention_mask is not None, 'attention_mask should be provided for the finetuning decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            context = context_embedding

        # concatenate context with seq
        seq = torch.cat([context, seq], dim=1)  # B x (context_length+T) x emb_size
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                # cut to only return the seq part
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec, intermediates
            else:
                # cut to only return the seq part
                hidden_vec = self.transformer_decoder(seq)
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec

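# --- illustrative sketch (not part of the original module) -------------------
# The finetuning decoder above conditions by concatenation: the frozen T5
# embedding (B, C, D) is prepended to the music-token embeddings (B, T, D),
# the decoder runs over the (B, C+T, D) sequence, and the first C positions
# are sliced off again so downstream losses only see the T music positions.
def _example_trim_context(hidden_vec, context):
    # hidden_vec: (B, C+T, D) decoder output; context: (B, C, D) text prefix
    return hidden_vec[:, context.shape[1]:, :]  # -> (B, T, D)
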
class XtransformerLargeFinetuningDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
        # freeze the text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim=dim,
            depth=depth,
            heads=heads,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the finetuning decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for the finetuning decoder'
            assert attention_mask is not None, 'attention_mask should be provided for the finetuning decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            context = context_embedding

        # concatenate context with seq
        seq = torch.cat([context, seq], dim=1)  # B x (context_length+T) x emb_size
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None:
                cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                # cut to only return the seq part
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec, intermediates
            else:
                # cut to only return the seq part
                hidden_vec = self.transformer_decoder(seq)
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec

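A minimal usage sketch for these decoders (hypothetical hyperparameters and shapes; note that `dim` must equal the text encoder's hidden size, 768 for flan-t5-base, for the concatenation in the finetuning decoders to work):

```python
import torch
from transformers import AutoTokenizer

# hypothetical hyperparameters, for illustration only
decoder = XtransformerFinetuningDecoder(dim=768, depth=8, heads=8, dropout=0.1)

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
context = tokenizer(['an upbeat jazz tune'], return_tensors='pt')

seq = torch.randn(1, 128, 768)       # (B, T, dim) music-token embeddings
hidden, intermediates = decoder(seq, train=True, context=context)
print(hidden.shape)                  # torch.Size([1, 128, 768]): prefix trimmed off
```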
BIN
SongEval/.DS_Store
vendored
Normal file
BIN
SongEval/.DS_Store
vendored
Normal file
Binary file not shown.
201
SongEval/LICENSE
Normal file
201
SongEval/LICENSE
Normal file
@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
88
SongEval/README.md
Normal file
88
SongEval/README.md
Normal file
@ -0,0 +1,88 @@
# 🎵 SongEval: A Benchmark Dataset for Song Aesthetics Evaluation

[](https://huggingface.co/datasets/ASLP-lab/SongEval)
[](https://arxiv.org/pdf/2505.10793)
[](https://creativecommons.org/licenses/by-nc-sa/4.0/)

This repository provides a **trained aesthetic evaluation toolkit** based on [SongEval](https://huggingface.co/datasets/ASLP-lab/SongEval), the first large-scale, open-source dataset for human-perceived song aesthetics. The toolkit enables **automatic scoring of generated songs** across five perceptual aesthetic dimensions aligned with professional musician judgments.

---

## 🌟 Key Features

- 🧠 **Pretrained neural models** for perceptual aesthetic evaluation
- 🎼 Predicts **five aesthetic dimensions**:
  - Overall Coherence
  - Memorability
  - Naturalness of Vocal Breathing and Phrasing
  - Clarity of Song Structure
  - Overall Musicality
<!-- - 🧪 Supports **batch evaluation** for model benchmarking -->
- 🎧 Accepts **full-length songs** (vocals + accompaniment) as input
- ⚙️ Simple inference interface

---

## 📦 Installation

Clone the repository and install dependencies:

```bash
git clone https://github.com/ASLP-lab/SongEval.git
cd SongEval
pip install -r requirements.txt
```

## 🚀 Quick Start

- Evaluate a single audio file:

```bash
python eval.py -i /path/to/audio.mp3 -o /path/to/output
```

- Evaluate a list of audio files (see the note after this list on building `audio_list.txt`):

```bash
python eval.py -i /path/to/audio_list.txt -o /path/to/output
```

- Evaluate all audio files in a directory:

```bash
python eval.py -i /path/to/audio_directory -o /path/to/output
```

- Force evaluation on CPU (⚠️ CPU evaluation may be significantly slower):

```bash
python eval.py -i /path/to/audio.wav -o /path/to/output --use_cpu True
```
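The expected format of `audio_list.txt` is not documented above; assuming it is one audio path per line, such a list can be built with standard shell tools:

```bash
# collect every .wav under a directory into a list file
find /path/to/audio_directory -name '*.wav' | sort > audio_list.txt
```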
## 🙏 Acknowledgement

This project is mainly organized by the audio, speech and language processing lab [(ASLP@NPU)](http://www.npu-aslp.org/).

We sincerely thank the **Shanghai Conservatory of Music** for their expert guidance on music theory, aesthetics, and annotation design.
We also thank AISHELL for helping with the organization of the song annotations.

<p align="center"> <img src="assets/logo.png" alt="Shanghai Conservatory of Music Logo"/> </p>

## 📑 License

This project is released under the CC BY-NC-SA 4.0 license.

You are free to use, modify, and build upon it for non-commercial purposes, with attribution.

## 📚 Citation

If you use this toolkit or the SongEval dataset, please cite the following:

```
@article{yao2025songeval,
  title   = {SongEval: A Benchmark Dataset for Song Aesthetics Evaluation},
  author  = {Yao, Jixun and Ma, Guobin and Xue, Huixin and Chen, Huakang and Hao, Chunbo and Jiang, Yuepeng and Liu, Haohe and Yuan, Ruibin and Xu, Jin and Xue, Wei and others},
  journal = {arXiv preprint arXiv:2505.10793},
  year    = {2025}
}
```
BIN
SongEval/assets/logo.png
Normal file
BIN
SongEval/assets/logo.png
Normal file
Binary file not shown.
184
SongEval/clap_score.py
Normal file
184
SongEval/clap_score.py
Normal file
@ -0,0 +1,184 @@
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln


# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_speech_audioset_epoch_15_esc_89.98.pt'):
    """
    Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
    the LAION-CLAP audio embedding of the generated audio. LAION-CLAP: https://github.com/LAION-AI/CLAP

    This evaluation script assumes that audio_path files are identified with the ids in id2text.

    clap_score() evaluates all ids in id2text.

    GPU-based computation.

    Select one of the following models from https://github.com/LAION-AI/CLAP:
    - music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
    - music_audioset_epoch_15_esc_90.14.pt
    - music_speech_epoch_15_esc_89.25.pt
    - 630k-audioset-fusion-best.pt (with "fusion" to handle longer inputs)

    Params:
    -- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
       and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
    -- audio_path: path where the generated audio files to evaluate are available.
    -- audio_files_extension: file extension (default .wav) in eval_path.
    -- clap_model: choose one of the above clap_models (default: 'music_speech_audioset_epoch_15_esc_89.98.pt').
    Returns:
    -- LAION-CLAP score
    """
    # load model
    if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
        clap_path = 'load/clap_score/music_speech_audioset_epoch_15_esc_89.98.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
        clap_path = 'load/clap_score/music_audioset_epoch_15_esc_90.14.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
        clap_path = 'load/clap_score/music_speech_epoch_15_esc_89.25.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == '630k-audioset-fusion-best.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
        clap_path = 'load/clap_score/630k-audioset-fusion-best.pt'
        model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
    else:
        raise ValueError('clap_model not implemented')

    # download clap_model if not already downloaded
    if not os.path.exists(clap_path):
        print('Downloading ', clap_model, '...')
        os.makedirs(os.path.dirname(clap_path), exist_ok=True)

        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        with open(clap_path, 'wb') as file:
            with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size=8192):
                    file.write(data)
                    progress_bar.update(len(data))

    # fixing a LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
    pkg = load_state_dict(clap_path)
    pkg.pop('text_branch.embeddings.position_ids', None)
    model.model.load_state_dict(pkg)
    model.eval()

    if not os.path.isdir(audio_path):
        raise ValueError('audio_path does not exist')

    if id2text:
        print('[EXTRACTING TEXT EMBEDDINGS] ')
        batch_size = 64
        text_emb = {}
        for i in tqdm(range(0, len(id2text), batch_size)):
            batch_ids = list(id2text.keys())[i:i+batch_size]
            batch_texts = [id2text[id] for id in batch_ids]
            with torch.no_grad():
                embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
            for id, emb in zip(batch_ids, embeddings):
                text_emb[id] = emb

    else:
        raise ValueError('Must specify id2text')

    print('[EVALUATING GENERATIONS] ', audio_path)
    score = 0
    count = 0
    for id in tqdm(id2text.keys()):
        file_path = os.path.join(audio_path, str(id)+audio_files_extension)
        with torch.no_grad():
            audio, _ = librosa.load(file_path, sr=48000, mono=True)  # sample rate should be 48000
            audio = pyln.normalize.peak(audio, -1.0)
            audio = audio.reshape(1, -1)  # unsqueeze to (1, T)
            audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
            audio_embeddings = model.get_audio_embedding_from_data(x=audio, use_tensor=True)
            cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
        score += cosine_sim
        count += 1

    return score / count if count > 0 else 0


if __name__ == "__main__":

    import pandas as pd
    import json
    import argparse
    parser = argparse.ArgumentParser(description='Compute CLAP score for generated audio files.')
    parser.add_argument('--clap_model', type=str, default='630k-audioset-fusion-best.pt',
                        help='CLAP model to use for evaluation. Options: music_speech_audioset_epoch_15_esc_89.98.pt, music_audioset_epoch_15_esc_90.14.pt, music_speech_epoch_15_esc_89.25.pt, 630k-audioset-fusion-best.pt (default: 630k-audioset-fusion-best.pt)')
    parser.add_argument('--root_path', type=str, default='../wandb/run-20250627_172105-xpe7nh5n-worseInstr/generated_samples_text_conditioned_top_p_threshold_0.99_temperature_1.15_8',
                        help='Path to the directory containing generated audio files and id2text mapping.')
    args = parser.parse_args()
    clap_model = args.clap_model
    root_path = args.root_path
    json_file_path = os.path.join(root_path, 'name2prompt.jsonl')
    generated_path = os.path.join(root_path, 'prompt_music')
    if not os.path.exists(generated_path):
        generated_path = root_path  # if there is no 'prompt_music' subfolder, use root_path directly

    with open(json_file_path, 'r') as f:
        id2text_dict = {}
        for line in f:
            item = json.loads(line)
            for k, v in item.items():
                id2text_dict[k] = v[0]
    print('length of id2text:', len(id2text_dict))
    # each key maps to a list of prompts; take the first one
    id2text = {}
    for k, v in id2text_dict.items():
        if isinstance(v, list):
            id2text[k] = v[0]
            # check whether k exists as a wav file
            if os.path.exists(os.path.join(generated_path, str(k)+'.wav')):
                id2text[k] = v[0]
            else:
                # look for variations k_0, k_1, ... and check whether they exist
                for i in range(0, 10):  # assuming no more than 10 variations
                    if os.path.exists(os.path.join(generated_path, str(k)+'_'+str(i)+'.wav')):
                        new_key = str(k) + '_' + str(i)
                        id2text[new_key] = v[0]
    print('length of id2text after checking wav files:', len(id2text))
    # keep only the ids whose wav files exist
    new_id2text = {}
    for id in id2text.keys():
        file_path = os.path.join(generated_path, str(id)+'.wav')
        if os.path.exists(file_path):
            new_id2text[id] = id2text[id]
        else:
            print(f"Warning: {file_path} does not exist, skipping this id.")
    print('length of new_id2text:', len(new_id2text))

    """
    IMPORTANT: the audios in generated_path should have the same ids as in id2text.
    For musiccaps, you can load id2text as above and each generated_path audio file
    corresponds to a prompt (text description) in musiccaps. Files are named with ids, as follows:
    - your_model_outputs_folder/_-kssA-FOzU.wav
    - your_model_outputs_folder/_0-2meOf9qY.wav
    - your_model_outputs_folder/_1woPC5HWSg.wav
    ...
    - your_model_outputs_folder/ZzyWbehtt0M.wav
    """

    clp = clap_score(new_id2text, generated_path, audio_files_extension='.wav')
    print('CLAP score (cosine similarity):', clp)
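A minimal usage sketch for `clap_score()` outside the `__main__` driver (the directory, filenames, and prompts below are made up; ids must match the audio filenames without the extension):

```python
from clap_score import clap_score

id2text = {
    'song_a': 'a slow, melancholic piano ballad',
    'song_b': 'an energetic rock track with electric guitar',
}
# expects ./generated/song_a.wav and ./generated/song_b.wav to exist
score = clap_score(id2text, './generated', audio_files_extension='.wav')
print('CLAP score:', float(score))
```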
6
SongEval/config.yaml
Normal file
6
SongEval/config.yaml
Normal file
@ -0,0 +1,6 @@
generator:
  _target_: model.Generator
  in_features: 1024
  ffd_hidden_size: 4096
  num_classes: 5
  attn_layer_num: 4
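The `_target_` key follows the Hydra instantiation convention; a sketch of how such a config is typically turned into a model object (assuming Hydra and OmegaConf are installed and `model.Generator` accepts these keyword arguments):

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load('config.yaml')
# equivalent to model.Generator(in_features=1024, ffd_hidden_size=4096,
#                               num_classes=5, attn_layer_num=4)
generator = instantiate(cfg.generator)
```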
456
SongEval/controlability.py
Normal file
456
SongEval/controlability.py
Normal file
@ -0,0 +1,456 @@
import json


generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
test_set_json = "dataset/midicaps/train.json"

generated_eval_json_path = f"{generate_path}/eval.json"
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"

# 1. Read the test set and build a mapping from prompt to entry
with open(test_set_json, 'r') as f:
    test_set = []
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line.strip())
        test_set.append(item)
prompt2item = {item['caption']: item for item in test_set if item['test_set'] is True}
print(f"Number of prompts in test set: {len(prompt2item)}")

# 2. Read name2prompt.jsonl and build a mapping from name to prompt
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r') as f:
    for line in f:
        obj = json.loads(line)
        name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and len(v) > 0})

# 3. Read eval.json
with open(generated_eval_json_path, 'r') as f:
    eval_items = []
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line.strip())
        eval_items.append(item)

# 4. For each name, look up its prompt, make sure the prompt is in the test set,
#    then find the matching entry in eval.json
results = []
# turn the names of eval_items into relative names
for item in eval_items:
    item['name'] = item['name'].split('/')[-1]  # name is a path; keep the last component
    # drop everything after the second underscore
    if '_' in item['name']:
        item['name'] = item['name'].split('.')[0].split('_')[0] + '_' + item['name'].split('.')[0].split('_')[1]
    # print(f"Processed eval item name: {item['name']}")

for name, prompt in name2prompt.items():
    if prompt not in prompt2item:
        print(f"Prompt not found in test set: {prompt}")
        continue
    # find the matching entry in eval.json (assumes eval.json has a 'name' field)
    eval_entry = next((item for item in eval_items if item.get('name') == name), None)
    if eval_entry is None:
        print(f"Eval entry not found for name: {name}")
        continue
    # the ground-truth entry
    original_entry = prompt2item[prompt]
    results.append({
        'name': name,
        'prompt': prompt,
        'eval_entry': eval_entry,
        'original_entry': original_entry
    })
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")

def calculate_TBT_score(results):
    """
    Tempo Bin with Tolerance (TBT): The predicted bpm falls into the ground
    truth tempo bin or a neighboring one.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'tempo' in eval_entry and 'tempo' in original_entry:
            eval_tempo = eval_entry['tempo'][0] if isinstance(eval_entry['tempo'], list) else eval_entry['tempo']
            original_tempo = original_entry['tempo']
            if original_tempo is None or eval_tempo is None:
                continue  # skip entries without a tempo
            # check whether eval_tempo falls inside the tolerance band around original_tempo
            if original_tempo - 10 <= eval_tempo <= original_tempo + 15:
                correct += 1
            total += 1
    TB_score = correct / total if total > 0 else 0
    print(f"TB Score: {TB_score:.4f} (Correct: {correct}, Total: {total})")
    return TB_score

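# --- illustrative check (not part of the original script) --------------------
# The tolerance band above is asymmetric: a ground-truth tempo of 120 bpm
# accepts predictions anywhere in [110, 135].
def _example_tbt_band(eval_tempo, original_tempo=120):
    return original_tempo - 10 <= eval_tempo <= original_tempo + 15
# _example_tbt_band(133) -> True; _example_tbt_band(108) -> False
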
def calculate_CK_score(results):
    """
    Correct Key (CK): The predicted key matches the ground truth key.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' in eval_entry and 'key' in original_entry:
            eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
            eval_key = eval_key if eval_key is not None else "C major"  # default to C major
            original_key = original_entry['key'] if original_entry['key'] is not None else "C major"  # default to C major
            if original_key is None or eval_key is None:
                continue
            if eval_key == original_key:
                correct += 1
            total += 1
    CK_score = correct / total if total > 0 else 0
    print(f"CK Score: {CK_score:.4f} (Correct: {correct}, Total: {total})")
    return CK_score


def calculate_CKD_score(results):
    """
    Correct Key with Duplicates (CKD): The predicted key matches the ground
    truth key or an equivalent key (i.e., a major key and its relative minor).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' in eval_entry and 'key' in original_entry:
            eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
            if eval_key is None:
                eval_key = "C major"  # default to C major
            original_key = original_entry['key'] if original_entry['key'] is not None else "C major"
            if original_key is None or eval_key is None:
                continue  # skip entries without a key
            # accept an exact match or a shared tonic (note: this compares the
            # tonic letter, i.e. parallel keys, rather than relative keys)
            if eval_key == original_key or (eval_key.split(' ')[0] == original_key.split(' ')[0]):
                correct += 1
            total += 1
    CKD_score = correct / total if total > 0 else 0
    print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
    return CKD_score

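# --- illustrative check (not part of the original script) --------------------
# As implemented, CKD's tonic comparison treats *parallel* keys as equivalent:
# 'C minor' vs. 'C major' share the tonic 'C' and count as correct, while the
# relative pair 'A minor' vs. 'C major' does not.
def _example_ckd_tonic_match(eval_key='C minor', original_key='C major'):
    return eval_key == original_key or eval_key.split(' ')[0] == original_key.split(' ')[0]
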
def calculate_CTS_score(results):
    """
    Correct Time Signature (CTS): The predicted time signature matches the
    ground truth time signature.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'time_signature' in eval_entry and 'time_signature' in original_entry:
            eval_time_signature = eval_entry['time_signature'][0] if isinstance(eval_entry['time_signature'], list) else eval_entry['time_signature']
            original_time_signature = original_entry['time_signature']
            if original_time_signature is None or eval_time_signature is None:
                continue  # skip entries without a time signature
            if eval_time_signature == original_time_signature:
                correct += 1
            else:
                # check for equivalent meters (e.g., 2/4 and 4/4)
                eval_numerator, eval_denominator = map(int, eval_time_signature.split('/'))
                original_numerator, original_denominator = map(int, original_time_signature.split('/'))
                if (eval_numerator == original_numerator and eval_denominator == original_denominator) or \
                   (eval_numerator * 2 == original_numerator and eval_denominator == original_denominator):
                    correct += 1
            total += 1
    CTS_score = correct / total if total > 0 else 0
    print(f"CTS Score: {CTS_score:.4f} (Correct: {correct}, Total: {total})")
    return CTS_score


def calculate_ECM_score(results):
    """
    Exact Chord Match (ECM): The predicted chord sequence matches the ground
    truth exactly in terms of order, chord root, and chord type, with tolerance
    for missing and excess chord instances.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'chord_summary' in eval_entry and 'chord_summary' in original_entry:
            eval_chord_summary = eval_entry['chord_summary'][0] if isinstance(eval_entry['chord_summary'], list) else eval_entry['chord_summary']
            original_chord_summary = original_entry['chord_summary']
            if original_chord_summary is None or eval_chord_summary is None:
                continue
            # both are lists of chord strings; compare them in order
            if eval_chord_summary == original_chord_summary:
                correct += 1
            total += 1
    ECM_score = correct / total if total > 0 else 0
    print(f"ECM Score: {ECM_score:.4f} (Correct: {correct}, Total: {total})")
    return ECM_score


def calculate_CMO_score(results):
    """
    Chord Match in any Order (CMO): The portion of the predicted chord sequence
    matching the ground truth chord root and type, in any order.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'chords' in eval_entry and 'chord_summary' in original_entry:
            eval_chords_seq = eval_entry['chords']
            # drop the confidence score from eval_chords_seq
            if isinstance(eval_chords_seq, list) and len(eval_chords_seq) > 0 and isinstance(eval_chords_seq[0], list):
                eval_chords_seq = [chord[0] for chord in eval_chords_seq]
            original_chord_summary = original_entry['chord_summary']
            if original_chord_summary is None or eval_chords_seq is None:
                continue
            # check whether every ground-truth chord appears among the predictions
            eval_chords_set = set(eval_chords_seq)             # e.g. from [['C', 0.46], ['G', 2.88]]
            original_chord_set = set(original_chord_summary)   # e.g. ['G', 'C']
            if original_chord_set.issubset(eval_chords_set):
                correct += 1
            else:
                if original_chord_set == eval_chords_set:
                    correct += 1
            total += 1
    CMO_score = correct / total if total > 0 else 0
    print(f"CMO Score: {CMO_score:.4f} (Correct: {correct}, Total: {total})")
    return CMO_score

def calculate_CI_score(results):
    """
    Correct Instrument (CI): The predicted instruments match the ground truth
    instruments (every ground-truth instrument must appear among the predictions).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
            eval_instrument = eval_entry['mapped_instruments_summary']
            original_instrument = original_entry['instrument_summary']
            if original_instrument is None or eval_instrument is None:
                continue
            # check whether eval_instrument covers original_instrument
            if isinstance(eval_instrument, list):
                eval_instrument_set = set(eval_instrument)
                original_instrument_set = set(original_instrument)
                if original_instrument_set.issubset(eval_instrument_set):
                    correct += 1
            else:
                if eval_instrument == original_instrument:
                    correct += 1
            total += 1
    CI_score = correct / total if total > 0 else 0
    print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_score


def calculate_CI_top1_score(results):
    """
    Correct Instrument Top-1 (CI_top1): At least one ground truth instrument
    appears among the predicted instruments.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
            eval_instrument = eval_entry['mapped_instruments_summary']
            original_instrument = original_entry['instrument_summary']
            if original_instrument is None or eval_instrument is None:
                continue
            # check whether at least one ground-truth instrument appears among the predictions
            if isinstance(eval_instrument, list):
                eval_instrument_set = set(eval_instrument)
                original_instrument_set = set(original_instrument)
                for inst in original_instrument_set:
                    if inst in eval_instrument_set:
                        correct += 1
                        break
            else:
                if eval_instrument == original_instrument:
                    correct += 1
            total += 1
    CI_top1_score = correct / total if total > 0 else 0
    print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_top1_score


def calculate_CG_score(results):
    """
    Correct Genre (CG): The predicted genres match the ground truth genres
    (every ground-truth genre must appear among the predictions).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'genre' in eval_entry and 'genre' in original_entry:
            eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
            original_genre = original_entry['genre']
            if original_genre is None or eval_genre is None:
                continue
            # check whether eval_genre covers original_genre
            if isinstance(eval_genre, list):
                eval_genre_set = set(eval_genre)
                original_genre_set = set(original_genre)
                if original_genre_set.issubset(eval_genre_set):
                    correct += 1
            else:
                if eval_genre == original_genre:
                    correct += 1
            total += 1
    CG_score = correct / total if total > 0 else 0
    print(f"CG Score: {CG_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_score


def calculate_CG_top1_score(results):
    """
    Correct Genre Top-1 (CG_top1): At least one ground truth genre appears
    among the predicted genres.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'genre' in eval_entry and 'genre' in original_entry:
            eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
            original_genre = original_entry['genre']
            if original_genre is None or eval_genre is None:
                continue
            # check whether at least one ground-truth genre appears among the predictions
            if isinstance(eval_genre, list):
                eval_genre_set = set(eval_genre)
                original_genre_set = set(original_genre)
                for gen in original_genre_set:
                    if gen in eval_genre_set:
                        correct += 1
                        break
            else:
                if eval_genre == original_genre:
                    correct += 1
            total += 1
    CG_top1_score = correct / total if total > 0 else 0
    print(f"CG Top-1 Score: {CG_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_top1_score

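# --- illustrative check (not part of the original script) --------------------
# The list-valued metrics above use subset matching: full credit only when
# every ground-truth label appears among the predictions, while the *_top1
# variants accept a single overlapping label.
def _example_subset_vs_top1():
    eval_set = {'piano', 'drums', 'bass'}
    original_set = {'piano', 'violin'}
    full_match = original_set.issubset(eval_set)            # False -> CI/CG/CM miss
    top1_match = any(x in eval_set for x in original_set)   # True  -> *_top1 hit
    return full_match, top1_match
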
def calculate_CM_score(results):
|
||||||
|
"""
|
||||||
|
• Correct Mood (CM): The predicted mood matches the ground truth mood.
|
||||||
|
"""
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
for result in results:
|
||||||
|
eval_entry = result['eval_entry']
|
||||||
|
original_entry = result['original_entry']
|
||||||
|
if 'mood' in eval_entry and 'mood' in original_entry:
|
||||||
|
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
|
||||||
|
original_mood = original_entry['mood']
|
||||||
|
if original_mood is None or eval_mood is None:
|
||||||
|
continue
|
||||||
|
# 检查 eval_mood 是否包含 original_mood
|
||||||
|
if isinstance(eval_mood, list):
|
||||||
|
eval_mood_set = set(eval_mood)
|
||||||
|
original_mood_set = set(original_mood)
|
||||||
|
if original_mood_set.issubset(eval_mood_set):
|
||||||
|
correct += 1
|
||||||
|
else:
|
||||||
|
if eval_mood == original_mood:
|
||||||
|
correct += 1
|
||||||
|
total += 1
|
||||||
|
CM_score = correct / total if total > 0 else 0
|
||||||
|
print(f"CM Score: {CM_score:.4f} (Correct: {correct}, Total: {total})")
|
||||||
|
return CM_score
|
||||||
|
|
||||||
|
def calculate_CM_top1_score(results):
    """
    • Correct Mood Top-1 (CM_top1): at least one ground-truth mood appears
      among the predicted moods (exact match when the prediction is a single label).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mood' in eval_entry and 'mood' in original_entry:
            eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
            original_mood = original_entry['mood']
            if original_mood is None or eval_mood is None:
                continue
            # Check whether eval_mood contains at least one element of original_mood
            if isinstance(eval_mood, list):
                eval_mood_set = set(eval_mood)
                original_mood_set = set(original_mood)
                for mood in original_mood_set:
                    if mood in eval_mood_set:
                        correct += 1
                        break
            else:
                if eval_mood == original_mood:
                    correct += 1
            total += 1
    CM_top1_score = correct / total if total > 0 else 0
    print(f"CM Top-1 Score: {CM_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top1_score

def calculate_CM_top3_score(results):
    """
    • Correct Mood Top-3 (CM_top3): when there are at most three ground-truth moods,
      all of them must appear among the predicted moods; when there are more than
      three, at least three of them must appear.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mood' in eval_entry and 'mood' in original_entry:
            eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
            original_mood = original_entry['mood']
            if original_mood is None or eval_mood is None:
                continue
            # Check whether eval_mood contains at least 3 elements of original_mood
            if isinstance(eval_mood, list):
                eval_mood_set = set(eval_mood)
                original_mood_set = set(original_mood)
                if len(original_mood_set) <= 3 and original_mood_set.issubset(eval_mood_set):
                    correct += 1
                elif len(original_mood_set) > 3:
                    match_num = sum(1 for mood in original_mood_set if mood in eval_mood_set)
                    if match_num >= 3:
                        correct += 1
            else:
                if eval_mood == original_mood:
                    correct += 1
            total += 1
    CM_top3_score = correct / total if total > 0 else 0
    print(f"CM Top-3 Score: {CM_top3_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top3_score

def calculate_all_scores(results):
    """
    Calculate all scores and return them as a dictionary.
    """
    scores = {
        'TBT_score': calculate_TBT_score(results),
        'CK_score': calculate_CK_score(results),
        'CKD_score': calculate_CKD_score(results),
        'CTS_score': calculate_CTS_score(results),
        'ECM_score': calculate_ECM_score(results),
        'CMO_score': calculate_CMO_score(results),
        'CI_score': calculate_CI_score(results),
        'CI_top1_score': calculate_CI_top1_score(results),
        'CG_score': calculate_CG_score(results),
        'CG_top1_score': calculate_CG_top1_score(results),
        'CM_score': calculate_CM_score(results),
        'CM_top1_score': calculate_CM_top1_score(results),
        'CM_top3_score': calculate_CM_top3_score(results)
    }
    return scores

if __name__ == "__main__":
    # assumes `results` and `generate_path` are defined earlier in this script
    scores = calculate_all_scores(results)
    print("All Scores:")
    for score_name, score_value in scores.items():
        print(f"{score_name}: {score_value:.4f}")

    # Save the results to a JSON file
    output_file = f"{generate_path}/results.json"
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)
    print(f"Results saved to {output_file}")
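A minimal sketch of the input these scorers appear to expect. The shape of `results` is inferred from the field accesses above (the `[0]` indexing followed by the isinstance-list check suggests each entry's tag list is nested one level); the values are illustrative, not taken from the repository:

# Hypothetical entries, shaped to match the accesses in the scorers above.
results = [
    {
        'eval_entry': {'genre': [['classical', 'soundtrack']], 'mood': [['calm', 'epic', 'soft']]},
        'original_entry': {'genre': ['classical'], 'mood': ['calm', 'soft']},
    },
]
cg_top1 = calculate_CG_top1_score(results)  # 1.0: 'classical' appears among the predicted genres
cm = calculate_CM_score(results)            # 1.0: {'calm', 'soft'} is a subset of the predicted moods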
103
SongEval/ebr.py
Normal file
@ -0,0 +1,103 @@
import argparse
import glob
import os
import pandas as pd
import muspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


def compute_midi_metrics(file_path):
    """Compute the music metrics of a single MIDI file."""
    try:
        music = muspy.read(file_path)
        scale_consistency = muspy.scale_consistency(music)
        pitch_entropy = muspy.pitch_entropy(music)
        pitch_class_entropy = muspy.pitch_class_entropy(music)
        empty_beat_rate = muspy.empty_beat_rate(music)
        groove_consistency = muspy.groove_consistency(music, 12)
        metrics = {
            'scale_consistency': scale_consistency,
            'pitch_entropy': pitch_entropy,
            'pitch_class_entropy': pitch_class_entropy,
            'empty_beat_rate': empty_beat_rate,
            'groove_consistency': groove_consistency,
            'filename': os.path.basename(file_path)
        }
        return metrics
    except Exception as e:
        print(f"Error while processing {os.path.basename(file_path)}: {str(e)}")
        return None


def compute_directory_metrics(directory_path, num_workers=8):
    """Compute the music metrics of all MIDI files in a directory (multithreaded)."""
    midi_files = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            if file.lower().endswith(('.mid', '.midi')):
                midi_files.append(os.path.join(root, file))
    if not midi_files:
        print("No MIDI files found in the directory or its subfolders")
        return None

    all_metrics = []
    average_metrics = {
        'scale_consistency': 0,
        'pitch_entropy': 0,
        'pitch_class_entropy': 0,
        'empty_beat_rate': 0,
        'groove_consistency': 0
    }
    current_num = 0
    total_scale_consistency = 0
    total_pitch_entropy = 0
    total_pitch_class_entropy = 0
    total_empty_beat_rate = 0
    total_groove_consistency = 0
    print(f"Processing directory: {directory_path}")
    print(f"Found {len(midi_files)} MIDI files:")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(compute_midi_metrics, midi_file): midi_file for midi_file in midi_files}
        for future in tqdm(as_completed(futures), total=len(midi_files), desc="Processing"):
            metrics = future.result()

            if metrics is not None:
                current_num += 1
                total_scale_consistency += metrics['scale_consistency']
                total_pitch_entropy += metrics['pitch_entropy']
                total_pitch_class_entropy += metrics['pitch_class_entropy']
                total_empty_beat_rate += metrics['empty_beat_rate']
                total_groove_consistency += metrics['groove_consistency']
                # Running averages over the files processed so far
                average_metrics['scale_consistency'] = total_scale_consistency / current_num
                average_metrics['pitch_entropy'] = total_pitch_entropy / current_num
                average_metrics['pitch_class_entropy'] = total_pitch_class_entropy / current_num
                average_metrics['empty_beat_rate'] = total_empty_beat_rate / current_num
                average_metrics['groove_consistency'] = total_groove_consistency / current_num
                print("current_metrics:", metrics)

                all_metrics.append(metrics)

    if not all_metrics:
        print("All files failed to process")
        return None

    df = pd.DataFrame(all_metrics)
    output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
    df.to_csv(output_csv, index=False)
    avg_metrics = df.mean(numeric_only=True)
    return df, avg_metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute music metrics for all MIDI files in a directory")
    parser.add_argument("path", type=str, help="Path to a directory containing MIDI files")
    parser.add_argument("--threads", type=int, default=1, help="Number of worker threads (default: 1)")
    args = parser.parse_args()

    if not os.path.isdir(args.path):
        print(f"Error: path '{args.path}' does not exist or is not a directory")
    else:
        # compute_directory_metrics returns None on failure, so unpack only on success
        ret = compute_directory_metrics(args.path, num_workers=args.threads)
        if ret is not None:
            result, averages = ret
            print("\nDone! Results saved to midi_metrics_report.csv")
            print("\nAverage metric values:")
            print(averages.to_string())
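A usage sketch for the script above; the directory path is illustrative:

# Walks the directory tree, scores every .mid/.midi with muspy, and writes
# midi_metrics_report.csv into the directory:
python SongEval/ebr.py generated_midi/ --threads 8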
150
SongEval/eval.py
Normal file
@ -0,0 +1,150 @@
import glob
import os
import json
import librosa
import numpy as np
import torch
import argparse
from muq import MuQ
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm


class Synthesizer(object):

    def __init__(self,
                 checkpoint_path,
                 input_path,
                 output_dir,
                 use_cpu: bool = False):

        self.checkpoint_path = checkpoint_path
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.device = torch.device('cuda') if (torch.cuda.is_available() and (not use_cpu)) else torch.device('cpu')

    @torch.no_grad()
    def setup(self):
        train_config = OmegaConf.load(os.path.join(os.path.dirname(self.checkpoint_path), '../config.yaml'))
        model = instantiate(train_config.generator).to(self.device).eval()
        state_dict = load_file(self.checkpoint_path, device="cpu")
        model.load_state_dict(state_dict, strict=False)

        self.model = model
        self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
        self.muq = self.muq.to(self.device).eval()
        self.result_dict = {}

    @torch.no_grad()
    def synthesis(self):
        if os.path.isfile(self.input_path):
            if self.input_path.endswith(('.wav', '.mp3')):
                lines = [self.input_path]
            else:
                with open(self.input_path, "r") as f:
                    lines = [line for line in f]
            input_files = [{
                "input_path": line.strip(),
            } for line in lines]
            print(f"input filelist: {self.input_path}")
        elif os.path.isdir(self.input_path):
            input_files = [{
                "input_path": file,
            } for file in glob.glob(os.path.join(self.input_path, '*')) if file.lower().endswith(('.wav', '.mp3'))]
        else:
            raise ValueError(f"input_path {self.input_path} is not a file or directory")

        for input in tqdm(input_files):
            try:
                self.handle(**input)
            except Exception as e:
                print(e)
                continue
        # add average (guard against the case where no file was scored)
        if self.result_dict:
            avg_values = {}
            for key in self.result_dict[list(self.result_dict.keys())[0]].keys():
                avg_values[key] = round(np.mean([self.result_dict[fid][key] for fid in self.result_dict]), 4)
            self.result_dict['average'] = avg_values
        # save result
        with open(os.path.join(self.output_dir, "result.json"), "w") as f:
            json.dump(self.result_dict, f, indent=4, ensure_ascii=False)

    @torch.no_grad()
    def handle(self, input_path):
        fid = os.path.basename(input_path).split('.')[0]
        if input_path.endswith('.npy'):
            input = np.load(input_path)

            # check ssl
            if len(input.shape) == 3 and input.shape[0] != 1:
                print('ssl_shape error', input_path)
                return
            if np.isnan(input).any():
                print('ssl nan', input_path)
                return

            input = torch.from_numpy(input).to(self.device)
            if len(input.shape) == 2:
                input = input.unsqueeze(0)

        if input_path.endswith(('.wav', '.mp3')):
            wav, sr = librosa.load(input_path, sr=24000)
            audio = torch.tensor(wav).unsqueeze(0).to(self.device)
            output = self.muq(audio, output_hidden_states=True)
            input = output["hidden_states"][6]

        values = {}
        scores_g = self.model(input).squeeze(0)
        values['Coherence'] = round(scores_g[0].item(), 4)
        values['Musicality'] = round(scores_g[1].item(), 4)
        values['Memorability'] = round(scores_g[2].item(), 4)
        values['Clarity'] = round(scores_g[3].item(), 4)
        values['Naturalness'] = round(scores_g[4].item(), 4)

        self.result_dict[fid] = values
        # release references before the next file (only names guaranteed to exist here)
        del input, scores_g, values


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input_path",
        type=str,
        required=True,
        help="Input audio: path to a single file, a text file listing audio paths, or a directory of audio files."
    )
    parser.add_argument(
        "-o", "--output_dir",
        type=str,
        required=True,
        help="Output directory for generated results (will be created if it doesn't exist)."
    )
    parser.add_argument(
        "--use_cpu",
        action='store_true',
        help="Force CPU mode even if a GPU is available.",
    )

    args = parser.parse_args()

    ckpt_path = "ckpt/model.safetensors"

    synthesizer = Synthesizer(checkpoint_path=ckpt_path,
                              input_path=args.input_path,
                              output_dir=args.output_dir,
                              use_cpu=args.use_cpu)

    synthesizer.setup()

    synthesizer.synthesis()
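For reference, a plausible invocation of the evaluator above (paths are illustrative; note the checkpoint path is hard-coded to ckpt/model.safetensors):

# Scores every .wav/.mp3 in the input directory and writes result.json with per-file
# Coherence/Musicality/Memorability/Clarity/Naturalness scores plus an 'average' entry:
python SongEval/eval.py -i rendered_audio/ -o songeval_results/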
404
SongEval/generate-batch_easy.py
Normal file
@ -0,0 +1,404 @@
import sys
import os
from pathlib import Path
from multiprocessing import Process, set_start_method
import torch
import argparse
from omegaconf import OmegaConf
import json
from collections import defaultdict

from Amadeus.evaluation_utils import (
    wandb_style_config_to_omega_config,
    prepare_model_and_dataset_from_config,
    get_best_ckpt_path_and_config,
    Evaluator
)
from transformers import T5Tokenizer, T5EncoderModel

from Amadeus import model_zoo
from Amadeus.symbolic_encoding import data_utils
from Amadeus.model_zoo import AmadeusModel
from Amadeus.symbolic_encoding.data_utils import TuneCompiler
from Amadeus.symbolic_encoding.compile_utils import shift_and_pad
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from Amadeus.symbolic_encoding import decoding_utils
from Amadeus.train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold for the chosen sampling method",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-num_samples",
        type=int,
        default=30,
        help="number of samples to generate",
    )
    parser.add_argument(
        "-num_target_measure",
        type=int,
        default=4,
        help="number of target measures for conditioned generation",
    )
    parser.add_argument(
        "-choose_selected_tunes",
        action='store_true',
        help="generate samples from selected tunes, only for SOD dataset",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=1024,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-num_processes",
        type=int,
        default=2,
        help="number of processes to use",
    )
    parser.add_argument(
        "-gpu_ids",
        type=str,
        default="0,5",
        help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "-prompt",
        type=str,
        default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
        help="prompt for generation, only used for conditioned generation",
    )
    parser.add_argument(
        "-prompt_file",
        type=str,
        default="dataset/midicaps/train.json",
        help="file containing prompts for text-conditioned generation",
    )
    return parser

def load_resources(wandb_exp_dir, device):
    """Load model and dataset resources for a process"""
    wandb_dir = Path('wandb')
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)

    # Load checkpoint to specified device
    ckpt = torch.load(ckpt_path, map_location=device)
    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    model = torch.compile(model)  # torch.compile returns the compiled module; keep the result
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare dataset for prompts
    condition_list = [x[1] for x in test_set.data_list]
    dataset_for_prompt = []
    for i in range(len(condition_list)):
        condition = test_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
        dataset_for_prompt.append((condition, condition_list[i]))

    return config, model, dataset_for_prompt, vocab

def conditioned_worker(process_idx, gpu_id, args, data_slice):
    """Worker process for conditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources with proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create output directory with process index
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    # Process assigned data slice
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
        batch_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_prompt(
            batch_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length
        )

def generate_samples_unconditioned(config, vocab, model, device, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    encoding_scheme = config.nn_params.encoding_scheme

    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    try:
        in_beat_resolution = in_beat_resolution_dict[config.dataset]
    except KeyError:
        in_beat_resolution = 4  # Default resolution if dataset is not found
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    for i in range(num_samples):
        generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))

def generate_samples_with_text_prompt(config, vocab, model, device, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    encoding_scheme = config.nn_params.encoding_scheme
    tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
    encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').to(device)
    print(f"Using T5EncoderModel for text prompt: {prompt}")
    context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**context).last_hidden_state
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    try:
        in_beat_resolution = in_beat_resolution_dict[config.dataset]
    except KeyError:
        in_beat_resolution = 4  # Default resolution if dataset is not found
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
    # Open the jsonl file and count the number of lines to determine the current index
    jsonl_path = save_dir / "name2prompt.jsonl"
    if jsonl_path.exists():
        with open(jsonl_path, 'r') as f:
            current_idx = sum(1 for _ in f)
    else:
        current_idx = 0

    name = f"prompt_{current_idx}"
    name2prompt_dict = defaultdict(list)
    name2prompt_dict[name].append(prompt)
    with open(jsonl_path, 'a') as f:
        f.write(json.dumps(name2prompt_dict) + '\n')
    decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))

def unconditioned_worker(process_idx, gpu_id, args, num_samples):
    """Worker process for unconditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources with proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create output directory with process index
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate assigned number of samples
    batch_dir = base_path
    generate_samples_unconditioned(
        config,
        vocab,
        model,
        device,  # pass the device explicitly so the arguments line up with the function signature
        batch_dir,
        num_samples,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
        uid=f"{process_idx}"
    )

def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources with proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create output directory with process index
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate samples for each assigned prompt
    batch_dir = base_path
    for idx, tune_name in enumerate(data_slice):
        print(f"Process {process_idx} generating samples for tune: {tune_name}")
        generate_samples_with_text_prompt(
            config,
            vocab,
            model,
            device,
            batch_dir,
            prompt=tune_name,
            first_pred_feature=config.data_params.first_pred_feature,
            sampling_method=args.sampling_method,
            threshold=args.threshold,
            temperature=args.temperature,
            generation_length=args.generate_length,
            uid=f"{process_idx}_{idx}"
        )

def main():
    # use spawn method for multiprocessing
    set_start_method('spawn', force=True)
    args = get_argument_parser().parse_args()
    gpu_ids = list(map(int, args.gpu_ids.split(',')))

    # Validate GPU availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    if len(gpu_ids) == 0:
        raise ValueError("At least one GPU must be specified")

    # Validate process count
    if args.num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if len(gpu_ids) < args.num_processes:
        print(f"Warning: More processes ({args.num_processes}) than GPUs ({len(gpu_ids)}), some GPUs will be shared")

    # Prepare data slices for processes
    processes = []
    try:
        if args.generation_type == 'conditioned':
            # Prepare selected tunes
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")

            # Load test set to get selected tunes (dummy load to get dataset info)
            dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            _, test_set, _ = prepare_model_and_dataset_from_config(
                wandb_dir / "files" / "config.yaml",
                wandb_dir / "files" / "metadata.json",
                wandb_dir / "files" / "vocab.json"
            )

            if args.choose_selected_tunes and test_set.dataset == 'SOD':
                selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
                                  "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
            else:
                selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]

            # Split selected data across processes
            selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
            chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes

            for i in range(args.num_processes):
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, len(selected_data))
                data_slice = selected_data[start_idx:end_idx]

                if not data_slice:
                    continue

                gpu_id = gpu_ids[i % len(gpu_ids)]
                p = Process(
                    target=conditioned_worker,
                    args=(i, gpu_id, args, data_slice)
                )
                processes.append(p)
                p.start()

        elif args.generation_type == 'unconditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes

            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                p = Process(
                    target=unconditioned_worker,
                    args=(i, gpu_id, args, samples)
                )
                processes.append(p)
                p.start()
        elif args.generation_type == 'text-conditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes
            # Load prompts from file
            prompt_name_list = []
            with open(args.prompt_file, 'r') as f:
                for line in f:
                    if not line.strip():
                        continue
                    prompt_data = json.loads(line.strip())
                    prompt_text = prompt_data['caption']
                    if prompt_data['test_set'] is True:
                        prompt_name_list.append(prompt_text)
                    print("length of prompt_name_list:", len(prompt_name_list))
                    if len(prompt_name_list) >= args.num_samples:
                        print(f"Reached the limit of {args.num_samples} prompts.")
                        break
            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                # Split prompt names across processes
                start_idx = i * (len(prompt_name_list) // args.num_processes)
                end_idx = (i + 1) * (len(prompt_name_list) // args.num_processes)
                data_slice = prompt_name_list[start_idx:end_idx]

                p = Process(
                    target=text_conditioned_worker,
                    args=(i, gpu_id, args, samples, data_slice)
                )
                processes.append(p)
                p.start()
        # Wait for all processes to complete
        for p in processes:
            p.join()

    except Exception as e:
        print(f"Error in main process: {str(e)}")
        for p in processes:
            p.terminate()
        raise


if __name__ == "__main__":
    main()
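Finally, a usage sketch for the batch generator (the experiment directory name is illustrative; flags follow the parser defined above):

# Text-conditioned generation split across two worker processes on GPUs 0 and 1:
python SongEval/generate-batch_easy.py -wandb_exp_dir run-20240101_000000-abc123 \
    -generation_type text-conditioned -num_samples 30 -num_processes 2 -gpu_ids 0,1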
Some files were not shown because too many files have changed in this diff