first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

BIN  Amadeus/.DS_Store vendored Normal file (binary file not shown)

0  Amadeus/__init__.py Normal file (empty file)

56  Amadeus/catsample.py Normal file

@@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F


def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
    # relaxed (or straight-through, if hard=True) sampling from a categorical distribution
    logits = categorical_probs.clamp(min=1e-9).log()
    return F.gumbel_softmax(logits, hard=hard)


def sample_categorical(categorical_probs, method="hard"):
    if method == "hard":
        # Gumbel-max trick: dividing the probabilities by exponential noise and taking the
        # argmax is equivalent to sampling from the categorical distribution
        gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
        return (categorical_probs / gumbel_norm).argmax(dim=-1)
    else:
        raise ValueError(f"Method {method} for sampling categorical variables is not valid.")


def direct_sampling(logits):
    probs = logits.softmax(dim=-1)
    index = sample_categorical(probs.to(torch.float32))
    return index


def top_p_sampling(logits, p=0.9):
    # nucleus sampling: keep the smallest set of tokens whose cumulative probability exceeds p
    probs = logits.softmax(dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    probs.masked_fill_(indices_to_remove, 0)
    probs /= probs.sum(dim=-1).unsqueeze(-1)
    index = sample_categorical(probs.to(torch.float32))
    return index


def top_k_sampling(logits, k=400):
    # keep only the k highest-scoring tokens, then sample among them
    top_k_values, top_k_indices = torch.topk(logits, int(k))
    top_k_probs = top_k_values.softmax(dim=-1)
    index = sample_categorical(top_k_probs.to(torch.float32))
    index = top_k_indices[torch.arange(index.size(0)), index]
    return index


def sample_with_strategy(update_logits, strategy, para=None):
    if strategy == "direct":
        return direct_sampling(update_logits)
    elif strategy == "top_p":
        return top_p_sampling(update_logits, para)
    elif strategy == "top_k":
        return top_k_sampling(update_logits, para)
    else:
        raise ValueError(f"Strategy {strategy} is not valid.")

533  Amadeus/evaluation_utils.py Normal file

@@ -0,0 +1,533 @@
from collections import defaultdict
from typing import Union
from math import log
from omegaconf import DictConfig
from pathlib import Path
import pickle
import json
import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel
from . import model_zoo
from .symbolic_encoding import data_utils
from .model_zoo import AmadeusModel
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.compile_utils import shift_and_pad
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding import decoding_utils
from .train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab
def wandb_style_config_to_omega_config(wandb_conf):
# remove wandb-related config
for wandb_key in ["wandb_version", "_wandb"]:
if wandb_key in wandb_conf:
del wandb_conf[wandb_key] # wandb-related config should not be overridden
# remove unnecessary wrapper fields such as 'desc' and 'value'
for key in wandb_conf:
if isinstance(wandb_conf[key], dict) and 'value' in wandb_conf[key]:
wandb_conf[key] = wandb_conf[key]['value']
# handle entries that are not plain dicts but still wrap their content in a 'value' field
try:
if 'value' in wandb_conf[key]:
wandb_conf[key] = wandb_conf[key]['value']
except TypeError:
pass
return wandb_conf
def get_dir_from_wandb_by_code(wandb_dir: Path, code:str) -> Path:
for dir in wandb_dir.iterdir():
if dir.name.endswith(code):
return dir
print(f'No such code in wandb_dir: {code}')
return None
def get_best_ckpt_path_and_config(wandb_dir, code):
dir = get_dir_from_wandb_by_code(wandb_dir, code)
if dir is None:
raise ValueError('No such code in wandb_dir')
ckpt_dir = dir / 'files' / 'checkpoints'
config_path = dir / 'files' / 'config.yaml'
# print all files in ckpt_dir
vocab_path = next(ckpt_dir.glob('vocab*'))
metadata_path = next(ckpt_dir.glob('*metadata.json'))
# if there is pt file ending with 'last', return it
if len(list(ckpt_dir.glob('*last.pt'))) > 0:
last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
else:
pt_fns = sorted(list(ckpt_dir.glob('*.pt')), key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
last_ckpt_fn = pt_fns[-1]
return last_ckpt_fn, config_path, metadata_path, vocab_path
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
nn_params = config.nn_params
dataset_name = config.dataset
vocab_path = Path(vocab_path)
if 'Encodec' in dataset_name:
encodec_tokens_path = Path(f"dataset/maestro-v3.0.0-encodec_tokens")
encodec_dataset = EncodecDataset(config, encodec_tokens_path, None, None)
vocab_sizes = encodec_dataset.vocab.get_vocab_size()
train_set, valid_set, test_set = encodec_dataset.split_train_valid_test_set()
lm_model:model_zoo.LanguageModelTransformer= getattr(model_zoo, nn_params.model_name)(config, vocab_sizes)
else:
# print(config)
encoding_scheme = config.nn_params.encoding_scheme
num_features = config.nn_params.num_features
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
in_vocab_file_path=vocab_path,
event_data=None,
encoding_scheme=encoding_scheme,
num_features=num_features)
# Initialize symbolic dataset based on dataset name and configuration parameters
symbolic_dataset = getattr(data_utils, dataset_name)(
vocab=vocab,
encoding_scheme=encoding_scheme,
num_features=num_features,
debug=config.general.debug,
aug_type=config.data_params.aug_type,
input_length=config.train_params.input_length,
first_pred_feature=config.data_params.first_pred_feature,
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
for_evaluation=True,
)
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
print(f"---{nn_params.main_decoder}--- is used")
print(f"---{dataset_name}--- is used")
print(f"---{encoding_scheme}--- is used")
split_ratio = config.data_params.split_ratio
# test_set = []
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
# get proper prediction order according to the encoding scheme and target feature in the config
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
# Create the Transformer model based on configuration parameters
AmadeusModel = getattr(model_zoo, nn_params.model_name)(
vocab=symbolic_dataset.vocab,
input_length=config.train_params.input_length,
prediction_order=prediction_order,
input_embedder_name=nn_params.input_embedder_name,
main_decoder_name=nn_params.main_decoder_name,
sub_decoder_name=nn_params.sub_decoder_name,
sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
sub_decoder_enricher_use=nn_params.sub_decoder.feature_enricher_use \
if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use') else False,
dim=nn_params.main_decoder.dim_model,
heads=nn_params.main_decoder.num_head,
depth=nn_params.main_decoder.num_layer,
dropout=nn_params.model_dropout,
)
return AmadeusModel, test_set, symbolic_dataset.vocab
def add_conti_in_valid(tensor, encoding_scheme):
new_target = tensor.clone()
# Assuming tensor shape is [batch, sequence, features]
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0, :] = new_target[:, 0, :] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# TODO: convert hard-coded part
# convert values into False except the 1st and 2nd features
if encoding_scheme == 'nb':
if tensor.shape[2] == 5:
# change beat, instrument
duplicates[:, :, 0] = False
duplicates[:, :, 3] = False
duplicates[:, :, 4] = False
elif tensor.shape[2] == 4:
# change beat
duplicates[:, :, 0] = False
duplicates[:, :, 2] = False
duplicates[:, :, 3] = False
elif tensor.shape[2] == 7:
# change beat, chord, tempo
duplicates[:, :, 0] = False
duplicates[:, :, 4] = False
duplicates[:, :, 5] = False
duplicates[:, :, 6] = False
elif encoding_scheme == 'cp':
if tensor.shape[2] == 5:
# change instrument
duplicates[:, :, 0] = False
duplicates[:, :, 1] = False
duplicates[:, :, 3] = False
duplicates[:, :, 4] = False
elif tensor.shape[2] == 7:
# change chord, tempo
duplicates[:, :, 0] = False
duplicates[:, :, 1] = False
duplicates[:, :, 4] = False
duplicates[:, :, 5] = False
duplicates[:, :, 6] = False
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
# TODO: hard coded
def add_conti(list_of_lists, encoding_scheme):
if encoding_scheme == 'nb':
if len(list_of_lists[0]) == 4:
# type, beat, pitch, duration
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
elif len(list_of_lists[0]) == 5:
# type, beat, instrument, pitch, duration
previous_instrument = None
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
if list_of_lists[i][2] == previous_instrument and previous_instrument != 0:
list_of_lists[i][2] = 'Conti'
else:
previous_instrument = list_of_lists[i][2]
elif len(list_of_lists[0]) == 7:
# type, beat, chord, tempo, pitch, duration, velocity
previous_chord = None
previous_tempo = None
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
if list_of_lists[i][2] == previous_chord and previous_chord != 0:
list_of_lists[i][2] = 'Conti'
elif list_of_lists[i][2] != previous_chord and list_of_lists[i][2] != 0:
previous_chord = list_of_lists[i][2]
if list_of_lists[i][3] == previous_tempo and previous_tempo != 0:
list_of_lists[i][3] = 'Conti'
elif list_of_lists[i][3] != previous_tempo and list_of_lists[i][3] != 0:
previous_tempo = list_of_lists[i][3]
elif encoding_scheme == 'cp':
if len(list_of_lists[0]) == 7:
# type, beat, chord, tempo, pitch, duration, velocity
previous_chord = None
previous_tempo = None
for i in range(0, len(list_of_lists)):
current_chord = list_of_lists[i][2]
current_tempo = list_of_lists[i][3]
if current_chord == previous_chord and current_chord != 0:
list_of_lists[i][2] = 'Conti'
elif current_chord != previous_chord and current_chord != 0:
previous_chord = current_chord
if current_tempo == previous_tempo and current_tempo != 0:
list_of_lists[i][3] = 'Conti'
elif current_tempo != previous_tempo and current_tempo != 0:
previous_tempo = current_tempo
if len(list_of_lists[0]) == 5:
# type, beat, instrument, pitch, duration
previous_instrument = None
for i in range(0, len(list_of_lists)):
current_instrument = list_of_lists[i][2]
if current_instrument == previous_instrument and current_instrument != 0:
list_of_lists[i][2] = 'Conti'
elif current_instrument != previous_instrument and current_instrument != 0:
previous_instrument = current_instrument
return list_of_lists
class Evaluator:
def __init__(self,
config: DictConfig,
model:AmadeusModel,
test_set:TuneCompiler,
vocab: LangTokenVocab,
device:str='cuda',
batch_size:int=16):
self.config = config
self.device = device
self.vocab = vocab
self.model = model
self.model.eval()
self.model.to(device)
self.test_set = test_set
self.input_len = config.train_params.input_length
self.loss_by_class = {key:[] for key in self.vocab.feature_list}
self.count_by_class = {key:0 for key in self.vocab.feature_list}
self.batch_size = batch_size
self.is_multiclass = config.nn_params.encoding_scheme in ('nb', 'cp')
self.first_pred_feature = self.config.data_params.first_pred_feature
self.neglect_keywords = ['SSS', 'SSN', 'Conti', 'Metrical', 'Note']
self.valid_item_prob = []
# we don't use focal loss on evaluation
self.focal_alpha = 1
self.focal_gamma = 0
def save_results(self, save_fn):
# convert the accumulated loss_by_class / count_by_class values to CPU tensors
for key in self.loss_by_class.keys():
self.loss_by_class[key] = torch.tensor(self.loss_by_class[key]).cpu()
self.count_by_class[key] = torch.tensor(self.count_by_class[key]).cpu()
torch.save({'loss_by_class':self.loss_by_class, 'count_by_class':self.count_by_class}, save_fn)
@torch.inference_mode()
def get_perplexity(self,less_than=256):
for data in tqdm(self.test_set.data_list, desc='Cal over dataset', position=0):
data_tensor = torch.LongTensor(data[0])
if self.config.nn_params.encoding_scheme == 'nb':
data_tensor = shift_and_pad(data_tensor, self.first_pred_feature)
data_tensor = data_tensor[:-1]
x_seg = data_tensor[:-1].unsqueeze(0)
y_seg = data_tensor[1:].unsqueeze(0)
self._cal_initial_seg(x_seg, y_seg)
if x_seg.shape[1] > self.input_len:
cat_logits = []
cat_y = []
cat_mask_indices = []
batch_x = x_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
batch_y = y_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
if self.is_multiclass:
batch_x = batch_x.transpose(1,2)
batch_y = batch_y.transpose(1,2)
for batch_start_idx in tqdm(range(0, min(batch_x.shape[0], less_than), self.batch_size), desc='In piece iter', position=1, leave=False):
x = batch_x[batch_start_idx:batch_start_idx+self.batch_size]
y = batch_y[batch_start_idx:batch_start_idx+self.batch_size]
logits, y,mask_indices = self._cal_following_seg(x, y)
cat_logits.append(logits)
cat_y.append(y)
cat_mask_indices.append(mask_indices)
if self.is_multiclass:
cat_dict = {}
for key in self.vocab.feature_list:
cat_dict[key] = torch.cat([logits_dict[key] for logits_dict in cat_logits], dim=0)
cat_logits = cat_dict
else:
cat_logits = torch.cat(cat_logits, dim=0)
cat_y = torch.cat(cat_y, dim=0)
mask_indices = torch.cat(cat_mask_indices, dim=0)
if self.is_multiclass:
self._update_loss_for_multi_class(cat_logits, cat_y,mask_indices)
else:
cat_prob = torch.nn.functional.softmax(cat_logits, dim=-1)
pt = cat_prob[torch.arange(cat_prob.shape[0]), cat_y]
# focal_loss = -self.focal_alpha * (1-pt)**self.focal_gamma * torch.log(pt) # [batch_size*seq_len]
loss = -torch.log(pt)
self._update_loss_for_single_class(loss, cat_y)
@torch.inference_mode()
def _update_loss_for_single_class(self, neg_log_prob:torch.Tensor, y:torch.Tensor):
for key in self.vocab.feature_list:
feature_mask = self.vocab.total_mask[key].to(y.device) # [vocab_size,]
mask_for_target = feature_mask[y] # [b*t]
normal_loss_seq_by_class = neg_log_prob[mask_for_target==1]
if mask_for_target.sum().item() != 0:
self.loss_by_class[key] += normal_loss_seq_by_class.tolist()
self.count_by_class[key] += mask_for_target.sum().item()
@torch.inference_mode()
def _update_loss_for_multi_class(self, logits_dict:dict, tgt:torch.Tensor, mask_indices:torch.Tensor=None):
correct_token_prob = []
for index, key in enumerate(self.vocab.feature_list):
feat_tgt = tgt[:,index]
logit_values = logits_dict[key]
prob_values = torch.nn.functional.softmax(logit_values, dim=-1)
# gather the probability assigned to the correct (target) sub-token
correct_token_prob.append(prob_values[torch.arange(prob_values.shape[0]), feat_tgt])
correct_token_prob = torch.stack(correct_token_prob, dim=1)
# tgt = reverse_shift_and_pad_for_tensor(tgt, self.first_pred_feature)
y_decoded = self.vocab.decode(tgt)
y_decoded = add_conti(y_decoded, self.config.nn_params.encoding_scheme)
# correct_token_prob = reverse_shift_and_pad_for_tensor(correct_token_prob, self.first_pred_feature)
num_notes = logits_dict['pitch'].shape[0]
cum_prob = 1
max_num = mask_indices.size(0)
if max_num != num_notes:
print("number of mask rows and number of notes differ:", max_num, num_notes)
for idx in range(max_num):
token = y_decoded[idx]
valid_mask = mask_indices[idx, :]
token_prob = correct_token_prob[idx].tolist()
for j, key in enumerate(self.vocab.feature_list):
cur_feature = token[j]
whether_predicted = valid_mask[j]
# clamp cur_prob to avoid log(0) when the predicted probability is 0
cur_prob = max(token_prob[j], 1e-10)
if cur_feature == 0: # ignore padded token
continue
if not whether_predicted: # skip sub-tokens that were provided rather than predicted
continue
if cur_feature in self.neglect_keywords:
cum_prob *= cur_prob
continue
if self.config.nn_params.encoding_scheme == 'cp' and 'time_signature' in cur_feature:
cum_prob *= cur_prob
continue
if self.config.nn_params.encoding_scheme == 'cp' and 'Bar' in cur_feature:
cum_prob = 1
continue
self.valid_item_prob.append([cur_feature, cur_prob, cur_prob*cum_prob])
pt = cur_prob*cum_prob
loss = -log(pt)
self.loss_by_class[key].append(loss)
self.count_by_class[key] += 1
cum_prob = 1
@torch.inference_mode()
def _cal_initial_seg(self, x_seg, y_seg):
x, y = x_seg[:, :self.input_len].to(self.device), y_seg[:, :self.input_len].to(self.device)
mask_indices = torch.ones_like(y).bool().to(self.device).flatten(0,1)
if self.config.use_diff is True:
logits,(mask_indices,_) = self.model(x, y)
else:
logits = self.model(x, y)
y = y.flatten(0,1)
if self.is_multiclass:
for key in logits.keys():
feat_tensor = logits[key].flatten(0,1)
logits[key] = feat_tensor
self._update_loss_for_multi_class(logits, y, mask_indices)
else:
prob = torch.nn.functional.softmax(logits, dim=-1)
prob = prob.flatten(0,1)
pt = prob[torch.arange(len(y)), y]
loss = -torch.log(pt)
self._update_loss_for_single_class(loss, y)
@torch.inference_mode()
def _cal_following_seg(self, x:torch.Tensor, y:torch.Tensor):
x, y = x.to(self.device), y.to(self.device)
mask_indices = torch.ones_like(y).bool().to(self.device)
if self.config.use_diff is True:
logits,(mask_indices,_) = self.model(x, y)
else:
logits = self.model(x, y)
y = y[:, -1:].flatten(0,1).cpu()
mask_indices = mask_indices.reshape(x.shape)[:,-1:].flatten(0,1).cpu()
if self.is_multiclass:
logits_dict = {}
for key in self.vocab.feature_list:
logits_dict[key] = logits[key][:, -1:].flatten(0,1).cpu()
return logits_dict, y,mask_indices
else:
logits = logits[:, -1:].flatten(0,1).cpu()
return logits, y,mask_indices
def prepare_prompt_and_ground_truth(self, save_dir, num_target_samples, num_target_measures):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
for i, (tuneidx, tune_name) in enumerate(self.test_set):
ground_truth_sample = tuneidx
try:
decoder(ground_truth_sample, output_path=str(save_dir / f"{i}_{tune_name}_gt.mid"))
except Exception:
print(f"Error in generating {i}_{tune_name}_gt.mid")
prompt = self.model.decoder._prepare_inference(start_token=self.model.decoder.net.start_token, manual_seed=0, condition=tuneidx, num_target_measures=num_target_measures)
try:
decoder(prompt, output_path=str(save_dir / f"{i}_{tune_name}_prompt.mid"))
except Exception:
print(f"Error in generating {i}_{tune_name}_prompt.mid")
if i == num_target_samples:
break
def generate_samples_with_prompt(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0,generation_length=3072):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
tuneidx = tuneidx.cuda()
generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
for i in range(num_samples):
generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
def generate_samples_with_text_prompt(self, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(self.device)
print(f"Using T5EncoderModel for text prompt: {prompt}")
context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(self.device)
context = encoder(**context).last_hidden_state
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
# Open the jsonl file and count the number of lines to determine the current index
jsonl_path = save_dir / "name2prompt.jsonl"
if jsonl_path.exists():
with open(jsonl_path, 'r') as f:
current_idx = sum(1 for _ in f)
else:
current_idx = 0
name = f"prompt_{current_idx}"
name2prompt_dict = defaultdict(list)
name2prompt_dict[name].append(prompt)
with open(jsonl_path, 'a') as f:
f.write(json.dumps(name2prompt_dict) + '\n')
decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
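For reference, a hedged sketch of how the per-feature negative log-likelihoods accumulated by Evaluator.get_perplexity and written out by save_results could be turned into perplexity numbers; the result file name below is a placeholder, not something defined in this module:

import math
import torch

results = torch.load("eval_results.pt")  # placeholder path, as passed to Evaluator.save_results
for feature, losses in results['loss_by_class'].items():
    losses = torch.as_tensor(losses, dtype=torch.float)
    if losses.numel() == 0:
        continue
    nll = losses.mean().item()  # mean negative log-likelihood for this feature
    print(f"{feature}: NLL = {nll:.4f}, PPL = {math.exp(nll):.2f}")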

512  Amadeus/model_zoo.py Normal file

@@ -0,0 +1,512 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import time
import json
from . import transformer_utils
from . import sub_decoder_zoo
from x_transformers.x_transformers import LayerIntermediates, AbsolutePositionalEmbedding
from data_representation.vocab_utils import LangTokenVocab
import os
class AmadeusModelWrapper(nn.Module):
def __init__(
self,
*,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
'''
This class wraps the three main components of the AmadeusModel model,
which are the input embedding layer, the main transformer decoder, and the sub-decoder.
'''
super().__init__()
self.vocab = vocab
self.vocab_size = vocab.get_vocab_size()
self.start_token = vocab.sos_token if hasattr(vocab, 'sos_token') else None
self.end_token = vocab.eos_token if hasattr(vocab, 'eos_token') else None
self.input_length = input_length
self.prediction_order = prediction_order
self._get_input_embedder(input_embedder_name, vocab, dropout, dim)
self._get_main_decoder(main_decoder_name, input_length, dim, heads, depth, dropout)
self._get_sub_decoder(sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout)
self.bos_token_hidden = None
def _get_input_embedder(self, input_embedder_name, vocab, dropout, dim):
self.emb_dropout = nn.Dropout(dropout)
self.input_embedder = getattr(transformer_utils, input_embedder_name)(
vocab=vocab,
dim_model=dim
)
def _get_main_decoder(self, main_decoder_name, input_length, dim, heads, depth, dropout):
self.pos_enc = AbsolutePositionalEmbedding(dim, input_length)
self.main_norm = nn.LayerNorm(dim)
self.main_decoder = getattr(transformer_utils, main_decoder_name)(
dim=dim,
depth=depth,
heads=heads,
dropout=dropout
)
def _get_sub_decoder(self, sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout):
self.sub_decoder = getattr(sub_decoder_zoo, sub_decoder_name)(
prediction_order=prediction_order,
vocab=vocab,
dim=dim,
sub_decoder_depth=sub_decoder_depth,
heads=heads,
dropout=dropout,
sub_decoder_enricher_use=sub_decoder_enricher_use
)
@property
def device(self):
return next(self.parameters()).device
def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
embedding = self.input_embedder(input_seq) + self.pos_enc(input_seq)
embedding = self.emb_dropout(embedding)
hidden_vec,layer_inter = self.main_decoder(embedding,train=True, context=context) # B x T x d_model
hidden_vec = self.main_norm(hidden_vec)
input_dict = {'hidden_vec':hidden_vec, 'input_seq': input_seq, 'target': target, 'bos_token_hidden': self.bos_token_hidden}
logits = self.sub_decoder(input_dict)
# pick the intermediate layer closest to one third of the total decoder depth
num_layers = len(layer_inter.layer_hiddens)
idx = round(num_layers / 3)
idx = min(max(idx, 0), num_layers - 1)
input_dict['hidden_vec'] = layer_inter.layer_hiddens[idx]
return logits, input_dict
class AmadeusModelAutoregressiveWrapper(nn.Module):
def __init__(self, net:AmadeusModelWrapper):
'''
Initializes an autoregressive wrapper around the AmadeusModelWrapper,
which allows sequential token generation.
Arguments:
- net: The nested music transformer model that performs the token generation.
'''
super().__init__()
self.net = net
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
def _prepare_inference(self, start_token, manual_seed, condition=None, num_target_measures=4):
'''
Prepares the initial tokens for autoregressive inference. If a manual seed is provided,
it sets the seed for reproducibility. If a condition is given, it selects a subset of
the tokens based on certain criteria related to the encoding scheme.
Arguments:
- start_token: The token that represents the start of a sequence.
- manual_seed: A seed value for reproducibility in inference (if greater than 0).
- condition: An optional tensor used for conditional generation, which helps select a
portion of the input tokens based on the encoding scheme.
Returns:
- total_out: A tensor containing the initial tokens for inference, padded to ensure compatibility
with the model.
'''
if manual_seed > 0:
torch.manual_seed(manual_seed)
total_out = []
if condition is None:
# Use the start token if no condition is given
total_out.extend(start_token)
else:
# Extract the portion of the sequence depending on encoding scheme (remi, cp, or nb)
if self.net.vocab.encoding_scheme == 'remi':
type_boundaries = self.net.vocab.remi_vocab_boundaries_by_key['type']
# vocab idx -> 0:SOS, 1:EOS, 2:Bar_without_time_signature, ... where_type_ends:Bar_time_signature_end, ...
measure_bool = (2 <= condition) & (condition < type_boundaries[1]) # between Bar_ts_start and Bar_ts_end
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'cp':
# find the start and end of the measure
beat_event2idx = self.net.vocab.event2idx['beat']
for event, idx in beat_event2idx.items():
if event == 0:
continue
if event == 'Bar':
start_idx = idx
elif event.startswith('Beat'):
end_idx = idx
break
measure_bool = (condition[:,1] >= start_idx) & (condition[:,1] < end_idx) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
# measure_bool = (condition[:,1] == 1) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'nb':
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
if conditional_input_len == 0:
conditional_input_len = 50
selected_tokens = condition[:conditional_input_len].tolist()
total_out.extend(selected_tokens)
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
return total_out
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
'''
Runs one step of autoregressive generation by taking the input sequence, embedding it,
passing it through the main decoder, and generating logits and a sampled token.
Arguments:
- input_seq: The input sequence tensor to be embedded and processed.
- cache: Optional cache for attention mechanisms to avoid recomputation.
- sampling_method: Sampling strategy used to select the next token.
- threshold: Optional threshold value for sampling methods that require it.
- temperature: Controls the randomness of predictions (higher temperature increases randomness).
Returns:
- logits: The predicted logits for the next token.
- sampled_token: The token sampled from the logits.
- intermidiates: Intermediate states from the main decoder, useful for caching.
'''
embedding = self.net.input_embedder(input_seq) + self.net.pos_enc(input_seq)
embedding = self.net.emb_dropout(embedding)
# Run through the main decoder and normalize
hidden_vec, intermidiates = self.net.main_decoder(embedding, cache,context_embedding=context) # B x T x d_model
hidden_vec = self.net.main_norm(hidden_vec)
hidden_vec = hidden_vec[:, -1:] # Keep only the last time step
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
# Generate the next token
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
return logits, sampled_token, intermidiates, hidden_vec
def _update_total_out(self, total_out, sampled_token):
'''
Updates the output sequence with the newly sampled token. Depending on the encoding scheme,
it either appends the token directly or processes feature-based sampling.
Arguments:
- total_out: The tensor containing the previously generated tokens.
- sampled_token: The newly generated token to be appended.
Returns:
- total_out: Updated output tensor with the newly generated token.
- sampled_token: The processed sampled token.
'''
if self.net.vocab.encoding_scheme == 'remi':
# For remi encoding, directly append the sampled token
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
else:
# Handle other encoding schemes by concatenating features
sampled_token_list = []
for key in self.net.vocab.feature_list:
sampled_token_list.append(sampled_token[key])
sampled_token = torch.cat(sampled_token_list, dim=-1)
if len(sampled_token.shape) == 2:
# sampled_token already has a leading batch dimension ([B, num_features])
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=1)
else:
total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
return total_out, sampled_token
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
Arguments:
- manual_seed: A seed value for reproducibility in inference.
- max_seq_len: The maximum length of the generated sequence.
- condition: An optional conditioning sequence to start generation from.
- sampling_method: The method used to sample the next token (e.g., greedy, top-k).
- threshold: Optional threshold for sampling (used in methods like top-p sampling).
- temperature: Controls the randomness of the token sampling process.
- batch_size: The number of sequences to generate in parallel.
Returns:
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
# Continue generating tokens until the maximum sequence length is reached
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
bos_hidden_vec = None
hidden_vec_list = []
token_time_list = []
while total_out.shape[1] < max_seq_len:
pbar.update(1)
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
time_start = time.time()
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
token_time_list.append(time_end - time_start)
if bos_hidden_vec is None:
bos_hidden_vec = hidden_vec
hidden_vec_list.append(hidden_vec)
# Update attention cache to handle autoregressive generation
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
# Update the generated output with the new token
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
# Stop if the end token is reached
if sampled_token.tolist() == self.net.end_token[0]:
break
# append hidden_vec to pkl
# save_path = 'hidden/diffnoaug_hidden_vec.pt'
# save_time_path = 'hidden/diff_noaug_token_time.json'
# if os.path.exists(save_path):
# # Load existing list and append
# hidden_vec_all = torch.load(save_path, map_location="cpu")
# hidden_vec_all.extend(hidden_vec_list)
# torch.save(hidden_vec_all, save_path)
# else:
# torch.save(hidden_vec_list, save_path)
# if os.path.exists(save_time_path):
# # Load existing list and append
# token_time_all = json.load(open(save_time_path, 'r'))
# token_time_all = token_time_all['token_time_list']
# token_time_all.extend(token_time_list)
# average_time = sum(token_time_all) / len(token_time_all)
# data = {
# 'average_time': average_time,
# 'token_time_list': token_time_all
# }
# json.dump(data, open(save_time_path, 'w'), indent=4)
# else:
# average_time = sum(token_time_list) / len(token_time_list)
# data = {
# 'average_time': average_time,
# 'token_time_list': token_time_list
# }
# json.dump(data, open(save_time_path, 'w'), indent=4)
return total_out
def generate_batch(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
Arguments:
- manual_seed: A seed value for reproducibility in inference.
- max_seq_len: The maximum length of the generated sequence.
- condition: An optional conditioning sequence to start generation from.
- sampling_method: The method used to sample the next token (e.g., greedy, top-k).
- threshold: Optional threshold for sampling (used in methods like top-p sampling).
- temperature: Controls the randomness of the token sampling process.
- batch_size: The number of sequences to generate in parallel.
Returns:
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# total_out (1,1,num) -> (bs,1,num)
total_out = total_out.repeat(batch_size, 1, 1)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
# Continue generating tokens until the maximum sequence length is reached
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
while total_out.shape[1] < max_seq_len:
pbar.update(1)
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
_, sampled_token, cache, _ = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
# Update attention cache to handle autoregressive generation
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
# Update the generated output with the new token
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
# Stop if the end token is reached
if sampled_token.tolist() == self.net.end_token[0]:
break
return total_out
class AmadeusModel(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
'''
This class combines the wrapper classes and initializes the full AmadeusModel, which can
perform autoregressive sequence generation for symbolic music.
Arguments:
- vocab: Vocabulary used for tokenization of the symbolic music data.
- input_length: Length of the input sequence in tokens.
- prediction_order: Order in which sub-token features are predicted; used for the compound shift.
- input_embedder_name: Name of the input embedding module (e.g., one-hot or learned embeddings).
- main_decoder_name: Name of the main transformer decoder that produces the hidden representations for compound tokens.
- sub_decoder_name: Name of the sub-decoder that decodes the sub-tokens inside each compound token from those hidden states.
- sub_decoder_depth: Depth (number of layers) of the sub-decoder.
- sub_decoder_enricher_use: Whether to use an additional enricher module in the sub-decoder to refine representations.
- dim: Dimensionality of the model (hidden size of the transformer layers).
- heads: Number of attention heads in the transformer layers.
- depth: Number of layers in the main decoder.
- dropout: Dropout rate for all layers in the model.
'''
super().__init__()
decoder = AmadeusModelWrapper(
vocab=vocab,
input_length=input_length,
prediction_order=prediction_order,
input_embedder_name=input_embedder_name,
main_decoder_name=main_decoder_name,
sub_decoder_name=sub_decoder_name,
sub_decoder_depth=sub_decoder_depth,
sub_decoder_enricher_use=sub_decoder_enricher_use,
dim=dim,
heads=heads,
depth=depth,
dropout=dropout
)
self.decoder = AmadeusModelAutoregressiveWrapper(
net=decoder
)
def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
return self.decoder(input_seq, target, context=context)
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
if batch_size == 1:
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
else:
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
class AmadeusModel4Encodec(AmadeusModel):
def __init__(
self,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
super().__init__(
vocab=vocab,
input_length=input_length,
prediction_order=prediction_order,
input_embedder_name=input_embedder_name,
main_decoder_name=main_decoder_name,
sub_decoder_name=sub_decoder_name,
sub_decoder_depth=sub_decoder_depth,
sub_decoder_enricher_use=sub_decoder_enricher_use,
dim=dim,
heads=heads,
depth=depth,
dropout=dropout
)
def _prepare_inference(self, start_token, manual_seed, condition=None):
if manual_seed > 0:
torch.manual_seed(manual_seed)
total_out = []
if condition is None:
total_out.extend(start_token)
else:
if self.decoder.net.vocab.encoding_scheme == 'remi':
selected_tokens = condition[:1500].tolist()
else:
selected_tokens = condition[:500].tolist()
total_out.extend(selected_tokens)
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.decoder.net.device)
return total_out
def _update_total_out(self, total_out, sampled_token):
if self.decoder.net.vocab.encoding_scheme == 'remi':
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
else:
sampled_token_list = []
for key in self.decoder.net.vocab.feature_list:
sampled_token_list.append(sampled_token[key])
sampled_token = torch.cat(sampled_token_list, dim=-1) # B(1) x num_features
total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
return total_out, sampled_token
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1):
embedding = self.decoder.net.input_embedder(input_seq) + self.decoder.net.pos_enc(input_seq)
embedding = self.decoder.net.emb_dropout(embedding)
hidden_vec, intermidiates = self.decoder.net.main_decoder(embedding, cache) # B x T x d_model
hidden_vec = self.decoder.net.main_norm(hidden_vec)
hidden_vec = hidden_vec[:, -1:] # B x 1 x d_model
input_dict = {'hidden_vec':hidden_vec, 'input_seq': input_seq, 'target': None}
if self.decoder.net.vocab.encoding_scheme == 'remi':
feature_class_idx = (input_seq.shape[1] - 1) % 4
feature_type = self.decoder.net.vocab.feature_list[feature_class_idx]
logits, sampled_token = self.decoder.net.sub_decoder.run_one_step(input_dict, sampling_method, threshold, temperature, feature_type)
else:
logits, sampled_token = self.decoder.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
return logits, sampled_token, intermidiates
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, sampling_method=None, threshold=None, temperature=1):
total_out = self._prepare_inference(self.decoder.net.start_token, manual_seed, condition)
if condition is not None:
_, _, cache = self._run_one_step(total_out[:, -self.decoder.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature)
else:
cache = LayerIntermediates()
while total_out.shape[1] < max_seq_len:
input_tensor = total_out[:, -self.decoder.net.input_length:]
_, sampled_token, cache = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.decoder.net.input_length - 1):, :] for t in inter.cached_kv] # B x num_heads x T x d_head
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
if sampled_token.tolist() == self.decoder.net.end_token[0]:
break
return total_out
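A self-contained illustration (with dummy tensors rather than the real x_transformers cache objects) of the sliding-window trim applied to inter.cached_kv inside generate: only the most recent input_length - 1 timesteps of each cached key/value tensor are kept, so the window is full again once the newly sampled token is appended.

import torch

input_length = 8  # illustrative context window
# dummy cached key/value tensors shaped [batch, heads, timesteps, head_dim]
cached_kv = [torch.randn(1, 4, 12, 16), torch.randn(1, 4, 12, 16)]
# same trimming expression as in AmadeusModelAutoregressiveWrapper.generate
cached_kv = [t[..., -(input_length - 1):, :] for t in cached_kv]
print([tuple(t.shape) for t in cached_kv])  # [(1, 4, 7, 16), (1, 4, 7, 16)]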

168  Amadeus/sampling_utils.py Normal file

@@ -0,0 +1,168 @@
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Create an empty tensor to hold the new logits
new_logits = logits.clone()
# Use the sorted indices to place the '-inf' in the original places
indices_to_remove = sorted_indices[sorted_indices_to_remove]
new_logits[..., indices_to_remove] = float('-inf')
return new_logits
# reference: https://github.com/cimeister/typical-sampling
def typical_sampling(logits, thres=0.99):
# calculate entropy
normalized = torch.nn.functional.log_softmax(logits, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = logits.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < thres).sum(dim=-1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(-1, last_ind.view(-1, 1, 1))
# if self.min_tokens_to_keep > 1:
# # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
# sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
scores = logits.masked_fill(indices_to_remove, float("-inf"))
return scores
def add_gumbel_noise(logits, temperature):
'''
The Gumbel max is a method for sampling categorical distributions.
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
Thus, we use float64.
'''
if temperature == 0:
return logits
logits = logits.to(torch.float64)
noise = torch.rand_like(logits, dtype=torch.float64)
gumbel_noise = (- torch.log(noise)) ** temperature
return logits.exp() / gumbel_noise
# reference: https://github.com/john-hewitt/truncation-sampling
def eta_sampling(logits, epsilon) -> torch.FloatTensor:
probabilities = logits.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probabilities).entropy()
new_epsilon = min(epsilon, torch.sqrt(torch.tensor(epsilon))*torch.exp(-entropy))
indices_to_remove = probabilities < new_epsilon
max_word = torch.argmax(logits, dim=-1)
indices_to_remove[..., max_word.squeeze()] = 0
new_scores = logits.masked_fill(indices_to_remove, float("-inf"))
return new_scores
def sample(logits, sampling_method, threshold, temperature):
"""Sample from the logits with a specific sampling strategy."""
if sampling_method == "top_p":
probs = F.softmax(top_p_sampling(logits, thres=threshold) / temperature, dim=-1)
elif sampling_method == "typical":
probs = F.softmax(typical_sampling(logits, thres=threshold) / temperature, dim=-1)
elif sampling_method == "eta":
probs = F.softmax(eta_sampling(logits, epsilon=threshold) / temperature, dim=-1)
else:
probs = F.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs[-1,-1,:], 1)
def sample_with_prob(logits, sampling_method, threshold, temperature):
"""Sample from the logits with a specific sampling strategy and return the token and its probability."""
# temporarily apply the sampling method to logits
logits = logits / temperature
# logits = add_gumbel_noise(logits, temperature)
if sampling_method == "top_p":
modified_logits = top_p_sampling(logits, thres=threshold)
elif sampling_method == "typical":
modified_logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
modified_logits = eta_sampling(logits, epsilon=threshold)
else:
modified_logits = logits # otherwise use the raw logits
# compute probabilities (temperature was already applied above)
probs = F.softmax(modified_logits, dim=-1)
# take the distribution at the last position
probs_last = probs[-1, -1, :]
# sample a token and look up its probability
sampled_token = torch.multinomial(probs_last, num_samples=1)
prob_value = probs_last[sampled_token]
return sampled_token, prob_value.squeeze()
def top_p_sampling_fast(logits, thres=0.9):
"""
logits: Tensor of shape [B, L, V]
Returns: logits with low-prob tokens masked as -inf, shape [B, L, V]
"""
# Step 1: sort logits and get indices
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) # [B, L, V]
# Step 2: compute cumulative probs
probs = F.softmax(sorted_logits, dim=-1) # [B, L, V]
cum_probs = torch.cumsum(probs, dim=-1) # [B, L, V]
# Step 3: mask tokens beyond cumulative threshold
sorted_mask = cum_probs > thres
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
sorted_mask[..., 0] = False # always keep at least one token
# Step 4: scatter back to original order
# Create mask of same shape as logits, default False
mask = torch.zeros_like(logits, dtype=torch.bool) # [B, L, V]
mask = mask.scatter(-1, sorted_indices, sorted_mask)
# Step 5: mask logits
logits = logits.masked_fill(mask, float('-inf')) # final masked logits
return logits
def sample_with_prob_fast(logits, sampling_method="top_p", threshold=0.9, temperature=1.0, mask_indices=None):
"""
logits: [B*T, num_sub_tokens, vocab_size]
mask_indices: mask indicating which tokens to sample, shape = [B*T, num_sub_tokens]
"""
if temperature != 1.0:
logits = logits / temperature
if sampling_method == "top_p":
logits = top_p_sampling_fast(logits, thres=threshold) # should support batch
elif sampling_method == "typical":
logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
logits = eta_sampling(logits, epsilon=threshold)
# else: keep logits as-is
probs = torch.softmax(logits, dim=-1) # [B*T, num_sub_tokens, vocab_size]
B, L, V = probs.shape
probs_flat = probs.view(-1, V) # [(B*T * num_sub_tokens), V]
# torch.multinomial cannot handle 3D input in one call, so flatten before sampling
sampled = torch.multinomial(probs_flat, num_samples=1) # [(B*T * num_sub_tokens), 1]
sampled = sampled.view(B, L) # [B*T, num_sub_tokens]
sampled_probs = torch.gather(probs, 2, sampled.unsqueeze(-1)).squeeze(-1) # [B*T, num_sub_tokens]
return sampled, sampled_probs
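A small runnable check of sample_with_prob_fast on dummy logits; the shapes follow the docstring ([B*T, num_sub_tokens, vocab_size]) and the concrete numbers are arbitrary:

import torch
from Amadeus.sampling_utils import sample_with_prob_fast

torch.manual_seed(0)
logits = torch.randn(6, 4, 100)  # e.g. 6 positions, 4 sub-tokens each, vocabulary of 100
sampled, sampled_probs = sample_with_prob_fast(logits, sampling_method="top_p", threshold=0.9, temperature=1.0)
print(sampled.shape, sampled_probs.shape)  # both torch.Size([6, 4])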

@@ -0,0 +1,228 @@
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout):
super().__init__()
self.out_size = out_size
self.layer = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Dropout(dropout),
nn.ReLU(),
nn.Linear(hidden_size, out_size)
)
def forward(self, x):
return self.layer(x)
class extendedMLP(nn.Module):
def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
super().__init__()
self.input_size = in_size
self.layers = nn.ModuleList()
if num_layers == 1:
# Only one layer
self.layers.append(nn.Linear(in_size, out_size))
return
elif num_layers > 1:
# First layer
self.layers.append(nn.Linear(in_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Intermediate layers
if num_layers > 2:
for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
self.layers.append(nn.Linear(hidden_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Last layer
self.layers.append(nn.Linear(hidden_size, out_size))
else:
raise ValueError("num_layers should be a positive integer")
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class multiMLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
super().__init__()
self.out_size = out_size
self.pred_order = pred_order # stored so forward() can look up the index of the requested feature
self.layer = nn.ModuleList([MLP(in_size, out_size, hidden_size, dropout) for _ in pred_order])
def forward(self, x, choice):
'''
x: B x T x d_model
choice: token type from self.pred_order (str or list of str)
'''
if isinstance(choice, str):
idx = self.pred_order.index(choice)
return self.layer[idx](x)
elif len(choice) > 1 and not isinstance(choice, str):
raise ValueError("multiMLP doesn't support parallel prediction")
class ResidualLayerNormModule(nn.Module):
def __init__(self, submodule: nn.Module):
super().__init__()
self.submodule = submodule
if submodule.__class__.__name__ == 'MultiheadAttention':
self.layer_norm = nn.LayerNorm(self.submodule.embed_dim)
else:
self.layer_norm = nn.LayerNorm(self.submodule.input_size)
def forward_attention(self, q, k, v, attn_mask, type):
attn_output, _ = self.submodule(q, k, v, attn_mask=attn_mask, need_weights=False, average_attn_weights=False)
return self.layer_norm(attn_output + q)
def forward_mlp(self, x):
return self.layer_norm(self.submodule(x) + x)
class MultiProj_hidden2logit(nn.Module):
def __init__(self, dim, vocab_sizes):
super().__init__()
self.layers = nn.ModuleDict({
f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
})
def forward(self, hidden_vec, feature):
logit = self.layers[f"layer_{feature}"](hidden_vec)
return logit
class MultiProj_catvec2hidden(nn.Module):
def __init__(self, config, par_pred_keys, seq_pred_keys):
super().__init__()
'''
This class is used in SQstyleEachEmbStrategy
par_pred_keys: list of independent features(These tokens are predicted in parallel)
seq_pred_keys: list of sequential features(These tokens are predicted sequentially)
'''
net_param = config.nn_params
self.d_model = net_param.model.d_model
independent_emb_size = 0
for key in par_pred_keys:
independent_emb_size += net_param.emb[key]
self.layers = nn.ModuleDict({
'layer_independent': nn.Linear(self.d_model + independent_emb_size, self.d_model),
**{f"layer_{key}": nn.Linear(self.d_model + net_param.emb[key], self.d_model) for key in seq_pred_keys}
})
self.par_pred_keys = par_pred_keys
self.seq_pred_keys = seq_pred_keys
self.dropout = nn.Dropout(0.1)
self.relu = nn.ReLU()
def forward(self, x, choice):
'''
x: B x T x (d_model + emb_size)
choice: key type (str or list of str)
'''
if isinstance(choice, str): # single key
assert choice in self.seq_pred_keys
output = self.layers[f"layer_{choice}"](x)
return self.relu(self.dropout(output))
elif len(choice) > 1 and not isinstance(choice, str): # multiple keys, parallel
assert choice == self.par_pred_keys # the order of choice should be the same as the order of self.par_pred_keys
output = self.layers['layer_independent'](x)
return self.relu(self.dropout(output))
def mask_tensor(tensor, mask_rate=0.15):
# Get the size of the tensor
batch_size, seq_len, dim = tensor.size()
# Calculate the total number of elements and the number to mask
total_elements = batch_size * seq_len
num_to_mask = int(total_elements * mask_rate)
# Create a 1D binary mask where 1 indicates that element will be masked.
# Start by creating a tensor of zeros with length equal to the total number of elements.
mask = torch.zeros(total_elements).to(tensor.device)
# Set `num_to_mask` random indices to 1 (masking)
indices_to_mask = torch.randperm(total_elements)[:num_to_mask]
mask[indices_to_mask] = 1
# Reshape the mask to match the original tensor's shape
mask = mask.reshape(batch_size, seq_len)
mask = mask.unsqueeze(2) # B x T x 1
masked_tensor = tensor * (mask == 0).float() # B x T x d_model
return masked_tensor
def generate_causality_mask_on_window(size, window_size):
mask = torch.zeros((size, size))
for i in range(size):
mask[i, i+window_size:] = 1
return mask.bool()
# generate boolean mask, if the value is 1 or true, it means the value is masked
# considers BOS token and mask margin
def generate_CA_mask(tgt_len, memory_len, mask_margin=0):
mask = torch.triu(torch.ones((tgt_len, memory_len)), diagonal=mask_margin+1)
return mask.bool()
# generate boolean mask, if the value is 1 or true, it means the value is masked
def generate_SA_mask(tgt_len):
mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1)
return mask.bool()
def generate_none_causality_mask(tgt_len, memory_len):
mask = torch.zeros((tgt_len, memory_len))
return mask.bool()
class DecoderLayer(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
class TransformerLayer(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.self_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
# self attention
attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')
input_dict['input_seq'] = attn_output
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
class FeatureEnricher(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory}
'''
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], None, type='feature_enrichment')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
return output_dict

1280
Amadeus/sub_decoder_zoo.py Normal file

File diff suppressed because it is too large

View File

View File

@ -0,0 +1,46 @@
from sf2utils.sf2parse import Sf2File
def print_sorted_presets(sf2_path):
presets_info = []
with open(sf2_path, 'rb') as f:
sf2 = Sf2File(f)
for preset in sf2.presets:
try:
# Try to read the attributes directly
name = getattr(preset, 'name', '???').strip('\x00')
bank = getattr(preset, 'bank', None)
program = getattr(preset, 'preset', None)
# If not available, try to get them from sub-attributes
if bank is None or program is None:
for attr in dir(preset):
attr_value = getattr(preset, attr)
if hasattr(attr_value, 'bank') and hasattr(attr_value, 'preset'):
bank = attr_value.bank
program = attr_value.preset
name = getattr(attr_value, 'name', name).strip('\x00')
break
# Collect valid results
if bank is not None and program is not None:
presets_info.append((program, bank, name))
except Exception as e:
print(f"Error reading preset: {e}")
# Sort by program in ascending order (to sort by bank, then program, use sorted(..., key=lambda x: (x[1], x[0])))
presets_info.sort(key=lambda x: x[0])
# Print the results
print(f"{'Program':<8} {'Bank':<6} {'Preset Name'}")
print("-" * 40)
for program, bank, name in presets_info:
print(f"{program:<8} {bank:<6} {name}")
# DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# Replace with the path to your .sf2 file
sf2_path = "/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2"
print_sorted_presets(sf2_path)

View File

@ -0,0 +1,94 @@
import random
from typing import Union
import torch
class Augmentor:
def __init__(
self,
vocab,
aug_type:Union[str, None],
input_length:int
):
self.vocab = vocab
self.aug_type = aug_type
self.input_length = input_length
self.feature_list = vocab.feature_list
self.num_features = len(self.feature_list)
self.encoding_scheme = vocab.encoding_scheme
self.pitch_idx = self.feature_list.index('pitch')
if 'chord' in self.feature_list:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
# the pitch vocab has an ignore token at index 0
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0
pitch_segment = segment[pitch_mask[:,self.pitch_idx], self.pitch_idx]
# check if tensor is empty
if pitch_segment.numel() == 0:
shift = 0
else:
lowest_pitch = max(12, torch.min(pitch_segment))
highest_pitch = min(119, torch.max(pitch_segment))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) > 11)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) < 120)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
else: # remi
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
segment_pitch = segment * segment_pitch_mask
segment_pitch = segment_pitch[segment_pitch != 0]
# check if tensor is empty
if segment_pitch.numel() == 0:
shift = 0
else:
lower_bound = torch.argwhere(mask_for_pitch == 1)[0].item()
upper_bound = torch.argwhere(mask_for_pitch == 1)[-1].item()
lowest_pitch = max(lower_bound, torch.min(segment_pitch))
highest_pitch = min(upper_bound, torch.max(segment_pitch))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) >= lower_bound)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) <= upper_bound)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
return shift
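# Illustrative example for the cp/nb branch: if the segment's pitch indices span
# 14..118, then lowest_pitch=14 allows a downward shift of at most 2 (14-2=12 > 11)
# and highest_pitch=118 allows an upward shift of at most 1 (118+1 < 120), so the
# shift is drawn uniformly from [-2, 1].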
# TODO: arrange hard coded part
def __call__(self, segment):
'''
input_tensor is segments of x, y
for transformer_xl, the shape of x, y is [max_num_segments, input_length, num_features]
so we need to change the shape of x, y to [max_num_segments*input_length, num_features]
'''
if self.aug_type == 'random':
shift = self._get_shift(segment)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
# pitch augmentation
segment_pitch_mask = segment != 0
new_segment = segment.clone()
new_segment[segment_pitch_mask[:,self.pitch_idx], self.pitch_idx] += shift
if 'chord' in self.feature_list:
# chord augmentation
segment_chord_mask = (segment[:,self.chord_idx] != 0) & (segment[:,self.chord_idx] != 1)
new_segment[segment_chord_mask, self.chord_idx] = (((new_segment[segment_chord_mask, self.chord_idx]-2) % 12) + shift ) % 12 + ((new_segment[segment_chord_mask, self.chord_idx]-2) // 12) * 12 + 2
segment = new_segment
else: # remi
# choose a random integer between -5 and 6
# shifts of -6 and +6 give the same result, so the range is limited to [-5, 6]
# pitch augmentation
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
new_segment = segment.clone()
new_segment_valid = (new_segment + shift) * segment_pitch_mask
new_segment = new_segment * (1 - segment_pitch_mask) + new_segment_valid
if 'chord' in self.feature_list:
# chord augmentation
mask_for_chord = self.vocab.total_mask['chord'].clone().to(segment.device)
chord_n_n_idx = torch.argwhere(mask_for_chord == 1)[-1].item()
mask_for_chord[chord_n_n_idx] = 0
start_idx_chord = self.vocab.remi_vocab_boundaries_by_key['chord'][0]
segment_chord_mask = mask_for_chord[segment]
new_segment_valid = ((((new_segment - start_idx_chord) % 12 + shift) % 12) + ((new_segment - start_idx_chord) // 12) * 12 + start_idx_chord) * segment_chord_mask
new_segment = new_segment * (1 - segment_chord_mask) + new_segment_valid
segment = new_segment
return segment

View File

@ -0,0 +1,207 @@
import random
from collections import defaultdict
import torch
import numpy as np
import random
def reverse_shift_and_pad(tune_in_idx, slice_boundary=4):
new_lst = [curr_elems[:slice_boundary] + next_elems[slice_boundary:] for curr_elems, next_elems in zip(tune_in_idx, tune_in_idx[1:])]
return new_lst
def reverse_shift_and_pad_for_tensor(tensor, first_pred_feature):
'''
tensor: [batch_size x seq_len x feature_size]
'''
if first_pred_feature == 'type':
return tensor
if tensor.shape[-1] == 8:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
elif tensor.shape[-1] == 7:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'pitch':4, 'duration':5, 'velocity':6}
elif tensor.shape[-1] == 5:
slice_boundary_dict = {'type':0, 'beat':1, 'instrument':2, 'pitch':3, 'duration':4}
elif tensor.shape[-1] == 4:
slice_boundary_dict = {'type':0, 'beat':1, 'pitch':2, 'duration':3}
slice_boundary = slice_boundary_dict[first_pred_feature]
new_tensor = torch.zeros_like(tensor)
new_tensor[..., :, :slice_boundary] = tensor[..., :, :slice_boundary]
new_tensor[..., :-1, slice_boundary:] = tensor[..., 1:, slice_boundary:]
return new_tensor
def shift_and_pad(tune_in_idx, first_pred_feature):
if first_pred_feature == 'type':
return tune_in_idx
if len(tune_in_idx[0]) == 8:
slice_boundary_dict = {'type':0, 'beat':-7, 'chord':-6, 'tempo':-5, 'instrument':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 7:
slice_boundary_dict = {'type':0, 'beat':-6, 'chord':-5, 'tempo':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 5:
slice_boundary_dict = {'type':0, 'beat':-4, 'instrument':-3, 'pitch':-2, 'duration':-1}
elif len(tune_in_idx[0]) == 4:
slice_boundary_dict = {'type':0, 'beat':-3, 'pitch':-2, 'duration':-1}
slice_boundary = slice_boundary_dict[first_pred_feature]
# Prepend a zero-padded row at the beginning; SOS and EOS tokens are not shifted
padded_tune_in_idx = torch.cat([torch.zeros(1, len(tune_in_idx[0]), dtype=torch.long), tune_in_idx], dim=0)
new_tensor = torch.zeros_like(padded_tune_in_idx)
new_tensor[:, slice_boundary:] = padded_tune_in_idx[:, slice_boundary:]
new_tensor[:-1, :slice_boundary] = padded_tune_in_idx[1:, :slice_boundary]
return new_tensor
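# Illustrative example (nb encoding, 8 sub-tokens, first_pred_feature='pitch'):
# slice_boundary is -3, so after shifting, row i pairs the type/beat/chord/tempo/
# instrument sub-tokens of note i with the pitch/duration/velocity sub-tokens of
# note i-1; a zero row is prepended, so the sequence grows by one step.
# reverse_shift_and_pad_for_tensor undoes this re-alignment on generated tensors.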
class VanillaTransformer_compiler():
def __init__(
self,
data_list,
augmentor,
eos_token,
input_length,
first_pred_feature,
encoding_scheme
):
self.data_list = data_list
self.augmentor = augmentor
self.eos_token = eos_token
self.input_length = input_length
self.first_pred_feature = first_pred_feature
self.encoding_scheme = encoding_scheme
def make_segments(self, data_type):
segments = []
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
# all encoding schemes build the EOS token the same way
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segments.append([segment, mask])
segment2tune_name.append(tune_name)
assert len(segment) == self.input_length+1
# Randomly choose the start point for the next segment, which is in the range of half of the current segment to the end of the current segment
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# if text-controlled, we only use the first segment
# add the last segment
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
segments.append([segment, mask])
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
assert len(segment) == self.input_length+1
return segments, tune_name2segment, segment2tune_name
def make_segments_iters(self, data_type):
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
# shuffle the data_list
if data_type == 'train':
random.shuffle(self.data_list)
print("length of data_list:", len(self.data_list))
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
# all encoding schemes build the EOS token the same way
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# break
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1

File diff suppressed because it is too large

View File

@ -0,0 +1,404 @@
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
class MuteWarn:
def __enter__(self):
self._init_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._init_stdout
def save_score_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
with MuteWarn():
convert = converter.parse(midi_fn)
convert.write('musicxml.png', fp=file_name)
def save_pianoroll_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
midi_obj_muspy = muspy.read_midi(midi_fn)
midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
plt.gcf().set_size_inches(20, 10)
plt.savefig(file_name)
plt.close()
def save_wav_from_midi(midi_fn, file_name, qpm=80):
assert isinstance(midi_fn, str)
with MuteWarn():
music = muspy.read_midi(midi_fn)
music.tempos[0].qpm = qpm
music.write_audio(file_name, rate=44100, gain=3)
def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
assert isinstance(midi_fn, str)
fs = FluidSynth(gain=gain)
fs.midi_to_audio(midi_fn, file_name)
class MidiDecoder4REMI:
def __init__(
self,
vocab,
in_beat_resolution,
dataset_name
):
self.vocab = vocab
self.in_beat_resolution = in_beat_resolution
self.dataset_name = dataset_name
if dataset_name == 'SymphonyMIDI':
self.gain = 0.7
elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
self.gain = 1.1
elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
self.gain = 2.5
else:
self.gain = 1.5
def __call__(self, generated_output, output_path=None):
'''
generated_output: a 1D tensor of REMI token indices (a leading batch dimension of size 1 is squeezed if present)
'''
idx2event = self.vocab.idx2event
if generated_output.dim() == 2:
generated_output = generated_output.squeeze(0)
events = [idx2event[token.item()] for token in generated_output]
midi_obj = miditoolkit.midi.parser.MidiFile()
if 'tempo' not in idx2event.keys():
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
instr_notes_dict = defaultdict(list)
for i in range(len(events)-2):
cur_event = events[i]
# print(cur_event)
name = cur_event.split('_')[0]
attr = cur_event.split('_')
if name == 'Bar':
bar_pos += cur_bar_resol
if 'time' in cur_event:
cur_num, cur_denom = attr[-1].split('/')
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif name == 'Beat':
beat_pos = int(attr[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks
elif name == 'Chord':
chord_text = attr[1] + '_' + attr[2]
midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
elif name == 'Tempo':
midi_obj.tempo_changes.append(
TempoChange(tempo=int(attr[1]), time=cur_pos))
elif name == 'Instrument':
cur_instr = int(attr[1])
else:
if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1] and \
'Note_Velocity' in events[i+2]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = int(events[i+2].split('_')[-1])
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = 90
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
# make subdir
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if output_path contains 'prompt', save the rendered audio under prompt_music
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4CP(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, tempo, chord, beat, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if 'time_signature' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['type']] == 'Metrical':
if 'time_signature' in token_with_7infos[feature2idx['beat']]:
cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['beat']] == 'Bar':
bar_pos += cur_bar_resol
elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# chord and tempo
midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
elif token_with_7infos[feature2idx['type']] == 'Note':
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 80
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
if not os.path.exists(output_music_dir):
os.makedirs(output_music_dir)
midi_obj.dump(output_path)
save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, os.path.join(output_music_dir, os.path.basename(output_path).replace('.mid', '.wav')), gain=self.gain)
return midi_obj
class MidiDecoder4NB(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
bar_pos += cur_bar_resol
elif 'NNN' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
if 'Beat' in str(token_with_7infos[feature2idx['beat']]): # beat_pos is updated only when the beat token is not CONTI
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# update chord and tempo
midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
# note
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1] # duration between 1~192
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 90
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if output_path contains 'prompt', save the rendered audio under prompt_music
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj

View File

@ -0,0 +1,208 @@
import torch
import numpy as np
from collections import Counter
# TODO: refactor hard coded values
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
type_beat_list = []
for token in generated_output:
type_beat_list.append((token[type_idx].item(), token[beat_idx].item())) # type, beat
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for type_beat in type_beat_list:
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
if type_beat[1] <= last_note:
num_backwards_errors += 1
beat_backwards_error_list.append([last_note, type_beat])
else:
last_note = type_beat[1] # update last note
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
last_note = 1
unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
backwards_error_rate = num_backwards_errors / len(type_beat_list)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
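# Note on the NB type codes assumed above (hard-coded, see TODO): 4 marks a
# "same bar, new beat" token and values >= 5 mark "new bar" tokens, so a beat
# index of 0/1 after type 4, or 0 after type >= 5, counts as an unmatched error.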
def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
pitch_idx = feature_list.index('pitch')
duration_idx = feature_list.index('duration')
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for token in generated_output:
if token[type_idx].item() == 2: # Metrical
if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
if token[beat_idx].item() == 1: # new bar
last_note = 1 # last note will be updated in the next token
elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
num_backwards_errors += 1
last_note = token[beat_idx].item() # update last note
beat_backwards_error_list.append([last_note, token])
else:
last_note = token[beat_idx].item() # update last note
if token[type_idx].item() == 3: # Note
if token[beat_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
unmatched_error_rate = num_unmatched_errors / len(generated_output)
backwards_error_rate = num_backwards_errors / len(generated_output)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
generated_output = generated_output.squeeze(0)
# to check duration errors
beat_mask = vocab.total_mask['beat'].to(generated_output.device)
beat_mask_for_target = beat_mask[generated_output]
beat_target = generated_output * beat_mask_for_target
bar_mask = vocab.total_mask['type'].to(generated_output.device)
bar_mask_for_target = bar_mask[generated_output]
bar_target = (generated_output+1) * bar_mask_for_target # the Bar token has index 0 in the remi vocab, so add 1 so it survives the nonzero filter below
target = beat_target + bar_target
target = target[target!=0]
# collect beats in between bars(idx=1)
num_backwards_errors = 0
collected_beats = []
total_beats = 0
for token in target:
if token == 1 or 3 <= token <= 26: # Bar_None, or Bar_time_signature
collected_beats_tensor = torch.tensor(collected_beats)
diff = torch.diff(collected_beats_tensor)
num_error_beats = torch.where(diff<=0)[0].shape[0]
num_backwards_errors += num_error_beats
collected_beats = []
else:
collected_beats.append(token.item())
total_beats += 1
if total_beats != 0:
backwards_error_rate = num_backwards_errors / total_beats
else:
backwards_error_rate = 0
# print(f"error rate in beat backwards: {backwards_error_rate}")
return {'beat_backwards_error': backwards_error_rate}
def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_pred_beat_idx = pred_beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_pred_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() >= 1: # beat
last_note = input_beat_idx.item()
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_beat_idx = beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() == 1: # bar
last_note = 1
elif input_beat_idx.item() >= 2: # new beat
last_note = input_beat_idx.item()
# check errors
if type_beat[0] == 2: # Metrical
if type_beat[1] == 0: # ignore
num_unmatched_errors += 1
elif type_beat[1] >= 2: # new beat
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] == 3: # Note
if type_beat[1] != 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
orign_beat_prob = prob_dict['beat'] # b x t x vocab_size
arranged_beat_prob = arranged_prob_dict['beat'] # b x t x vocab_size
# calculate similarity between original beat prob and arranged beat prob
origin_beat_token = torch.argmax(orign_beat_prob, dim=-1) * mask # b x t
arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask # b x t
num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask==0)
num_beat = torch.sum(mask==1)
beat_sim = (num_same_beat / num_beat).item() # scalar
# apply mask, shape of mask: b x t
orign_beat_prob = orign_beat_prob * mask.unsqueeze(-1) # b x t x vocab_size
arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)
# calculate cosine similarity between original beat prob and arranged beat prob
orign_beat_prob = orign_beat_prob.flatten(0,1) # (b*t) x vocab_size
arranged_beat_prob = arranged_beat_prob.flatten(0,1) # (b*t) x vocab_size
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
beat_cos_sim = cos(orign_beat_prob, arranged_beat_prob) # (b*t)
# exclude invalid tokens, zero padding tokens
beat_cos_sim = beat_cos_sim[mask.flatten().bool()] # num_valid_tokens
beat_cos_sim = torch.mean(beat_cos_sim).item() # scalar
return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}
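# The two scores above compare the beat head's output before and after rearranging
# the prediction order: beat_sim is exact argmax agreement over unmasked positions,
# while beat_cos_sim is the mean cosine similarity of the probability vectors.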
def get_gini_coefficient(generated_output):
if len(generated_output.shape) == 3:
generated_output = generated_output.squeeze(0).tolist()
gen_list = [tuple(x) for x in generated_output]
else:
gen_list = generated_output.squeeze(0).tolist()
counts = Counter(gen_list).values()
sorted_counts = sorted(counts)
n = len(sorted_counts)
cumulative_counts = np.cumsum(sorted_counts)
cumulative_proportion = cumulative_counts / cumulative_counts[-1]
lorenz_area = sum(cumulative_proportion[:-1]) / n # Exclude the last element
equality_area = 0.5 # The area under line of perfect equality
gini = (equality_area - lorenz_area) / equality_area
return gini
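# Worked example (illustrative): for token counts [1, 1, 2] the sorted cumulative
# proportions are [0.25, 0.50, 1.00]; lorenz_area = (0.25 + 0.50) / 3 = 0.25 and
# gini = (0.5 - 0.25) / 0.5 = 0.5, so a more repetitive output yields a larger value.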

View File

@ -0,0 +1,78 @@
import argparse
import os
import subprocess
from pydub import AudioSegment
'''
This file is a modified version of midi2audio.py from https://github.com/bzamecnik/midi2audio
Author: Bohumír Zámečník (@bzamecnik)
License: MIT, see the LICENSE file
'''
__all__ = ['FluidSynth']
DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# DEFAULT_SAMPLE_RATE = 16000
# DEFAULT_GAIN = 0.20
class FluidSynth():
def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
self.sample_rate = sample_rate
self.sound_font = os.path.expanduser(sound_font)
self.gain = gain
def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
if verbose:
stdout = None
else:
stdout = subprocess.DEVNULL
# Convert MIDI to WAV
subprocess.call(
['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
stdout=stdout
)
# Convert WAV to MP3
# mp3_path = audio_file.replace('.wav', '.mp3')
# AudioSegment.from_wav(audio_file).export(mp3_path, format="mp3")
# # Delete the temporary WAV file
# os.remove(audio_file)
def play_midi(self, midi_file):
subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])
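# Usage sketch (paths are hypothetical): FluidSynth(gain=1.0).midi_to_audio('in.mid', 'out.wav')
# shells out to the external `fluidsynth` binary, which must be installed and on PATH.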
def parse_args(allow_synth=True):
parser = argparse.ArgumentParser(description='Convert MIDI to audio via FluidSynth')
parser.add_argument('midi_file', metavar='MIDI', type=str)
if allow_synth:
parser.add_argument('audio_file', metavar='AUDIO', type=str, nargs='?')
parser.add_argument('-s', '--sound-font', type=str,
default=DEFAULT_SOUND_FONT,
help='path to a SF2 sound font (default: %s)' % DEFAULT_SOUND_FONT)
parser.add_argument('-r', '--sample-rate', type=int, nargs='?',
default=DEFAULT_SAMPLE_RATE,
help='sample rate in Hz (default: %s)' % DEFAULT_SAMPLE_RATE)
return parser.parse_args()
def main(allow_synth=True):
args = parse_args(allow_synth)
fs = FluidSynth(args.sound_font, args.sample_rate)
if allow_synth and args.audio_file:
fs.midi_to_audio(args.midi_file, args.audio_file)
else:
fs.play_midi(args.midi_file)
def main_play():
"""
A method for the `midiplay` entry point. It omits the audio file from args.
"""
main(allow_synth=False)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,65 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
- nn_params: nb8_embSum_diff_t2m_150M_finetunning
# - nn_params: nb8_embSum_diff_t2m_150M_pretraining
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
# - nn_params: nb8_embSum_subFeedForward
# - nn_params: nb8_embSum_diff
# nn_params: nb8_SA_diff
# - nn_params: nb8_embSum_diff_main12head16dim512_ave
# - nn_params: nb8_embSum_NMT_main12_head_16_dim512
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
# captions_path: dataset/symphonyNet/syd-caption.json
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True: use diffusion in the sub-decoder
diff_steps: 8 # number of diffusion steps
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5
train_params:
device: cuda
batch_size: 3
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 5000 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.00005
decay_step_rate: 0.8 # the learning rate reaches its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 # number of warmup steps
max_lr: 0.00015
gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
# Distributed Data Parallel
world_size: 5 # 0 means no distributed training
gradient_accumulation_steps: 4 # 1 means no gradient accumulation
inference_params:
num_uncond_generation: 1 # number of unconditional generation
num_cond_generation: 3 # number of conditional generation
data_params:
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
split_ratio: 0.998 # train-validation-test split ratio
aug_type: pitch # random, null | pitch and chord augmentation type
general:
debug: False
make_log: True # True, False | update the log file in wandb online to your designated project and entity
infer_and_log: True # True, False | inference and log the results
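# Derived values for this config (illustrative arithmetic): with batch_size 3,
# world_size 5 and gradient_accumulation_steps 4, the effective batch size is
# 3 * 5 * 4 = 60 sequences per update, and with decay_step_rate 0.8 the learning-rate
# schedule reaches its minimum around 0.8 * 300000 = 240000 iterations.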

View File

@ -0,0 +1,54 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff
- nn_params: nb8_embSum_subFeedForward
# - nn_params: nb8_SA_diff
# - nn_params: nb8_embSum_diff_main12head16dim512_ave
# - nn_params: nb8_embSum_NMT_main12_head_16_dim512
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True: use diffusion in the sub-decoder
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5
diff_steps: 8 # number of diffusion steps
train_params:
device: cuda
batch_size: 8
grad_clip: 1.0
num_iter: 25000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 10 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 500 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.0001
decay_step_rate: 0.4 # the learning rate reaches its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 # number of warmup steps
max_lr: 0.00015
gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
# Distributed Data Parallel
world_size: 5 # 0 means no distributed training
gradient_accumulation_steps: 1 # 1 means no gradient accumulation
inference_params:
num_uncond_generation: 1 # number of unconditional generation
num_cond_generation: 3 # number of conditional generation
data_params:
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
split_ratio: 0.99 # train-validation-test split ratio
aug_type: null # random, null | pitch and chord augmentation type
general:
debug: False
make_log: True # True, False | update the log file in wandb online to your designated project and entity
infer_and_log: True # True, False | inference and log the results

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SelfAttentionEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 6
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: AverageEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 2
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 6
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerFinetuningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPrefixDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 5
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 7
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16

View File

@ -0,0 +1,17 @@
program: train.py
method: grid
metric:
name: valid.total
goal: minimize
parameters:
train_params.batch_size:
values: [8]
train_params.focal_gamma:
values: [0, 1]
nn_params.main_decoder.input_length:
values: [8192]
command:
- python3
- ${program}
- ${args_no_hyphens}
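# Typical launch flow (illustrative; uses the standard wandb CLI, the file name is an assumption):
#   wandb sweep sweep_config.yaml              # registers the sweep and prints a sweep ID
#   wandb agent <entity>/<project>/<sweep_id>  # starts a worker that runs train.py with sampled params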

428
Amadeus/train_utils.py Normal file
View File

@ -0,0 +1,428 @@
import math
from numpy import mask_indices
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
from collections import defaultdict
import torch.nn.functional as F
def add_conti_for_single_feature(tensor):
new_target = tensor.clone()
# Assuming tensor shape is [batch, sequence, features]
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0] = new_target[:, 0] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
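# Illustrative example (derived from the function above): along the sequence dimension,
# a single feature column [3, 3, 5, 5, 5] becomes [3, 9999, 5, 9999, 9999]; repeated
# ("continued") values are replaced with the sentinel 9999 so they can be excluded
# when logging per-feature losses.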
def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
feature_prediction_order_dict = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
if encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
default_prediction_order = feature_prediction_order_dict[num_features]
prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
else:
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'nb':
assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
default_prediction_order = feature_prediction_order_dict[num_features]
# Reorganize the prediction order based on the target_feature
target_index = default_prediction_order.index(target_feature)
prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]
return prediction_order
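# Illustrative example: with encoding_scheme='nb', num_features=5 and target_feature='pitch',
# the default order
#   ['type', 'beat', 'instrument', 'pitch', 'duration']
# is rotated so that prediction starts from the target feature:
#   ['pitch', 'duration', 'type', 'beat', 'instrument']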
########################### Loss function ################################
class NLLLoss4REMI():
def __init__(
self,
focal_alpha:float,
focal_gamma:float,
):
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss_seq = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss, loss_seq
def __call__(self, logits, shifted_tgt, mask, vocab):
if vocab is not None:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
loss_by_class_normal = defaultdict(float)
shifted_tgt_with_mask = shifted_tgt * mask # [b, t]
answers_idx = shifted_tgt_with_mask.flatten(0,1) # [b*t]
for feature in vocab.feature_list:
feature_mask = vocab.total_mask[feature].to(answers_idx.device) # [327,]
mask_for_target = feature_mask[answers_idx] # [b*t]
normal_loss_seq_by_class = loss_seq * mask_for_target
if mask_for_target.sum().item() != 0:
loss_by_class_normal[feature+'_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
return loss, loss_by_class_normal
else:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
return loss, None
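# Note (illustrative): the focal term -alpha * (1 - pt)**gamma * log(pt) reduces to the plain
# negative log-likelihood when focal_alpha=1 and focal_gamma=0; larger gamma progressively
# down-weights tokens the model already predicts confidently.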
class NLLLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
probs = logits.softmax(dim=-1)
if ignore_token is not None and conti_token is not None:
target_conti = add_conti_for_single_feature(target) # [batch_size*seq_len]
valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size*seq_len]
elif ignore_token is not None and conti_token is None:
valid_mask = (target != ignore_token)
elif ignore_token is None and conti_token is None:
valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target] # [batch_size*seq_len]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * total_mask # [batch_size*seq_len]
loss = loss.sum() / total_mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits_dict, shifted_tgt, mask, valid):
train_loss_list = []
log_loss_dict_normal = {}
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
elif key == 'beat':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
elif key == 'chord' or key == 'tempo' or key == 'instrument':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
def dispersive_loss(z, tau=0.5, eps=1e-8):
"""使用余弦距离的Dispersive Loss实现"""
B = z.size(0)
# 计算余弦相似度矩阵 [B, B]
z_norm = torch.nn.functional.normalize(z, p=2, dim=1) # 向量归一化
sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1)) # 余弦相似度
# 转换为余弦距离 (1 - 相似度),排除对角线
mask = 1 - torch.eye(B, device=z.device)
cos_dist = (1 - sim_matrix) * mask
# 计算分散性损失与L2版本相同
exp_term = torch.exp(-cos_dist / tau)
mean_exp = exp_term.sum() / (B * (B - 1) + eps)
loss = -torch.log(mean_exp + eps)
return loss
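# Minimal usage sketch (illustrative; mirrors the call in DiffusionLoss4CompoundToken below):
#   feat = hidden_vec.mean(dim=1)            # [batch, dim], pooled over the time axis
#   reg = dispersive_loss(feat, tau=0.5)     # scalar regularizer added to the main loss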
class DiffusionLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask,mask_indices, p_mask):
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
if mask.ndim == 2:
mask = mask.flatten(0, 1)
# datatype of logits, target, mask_indices, p_mask should be the same
token_loss = F.cross_entropy(
logits[mask_indices], # index logits directly at the masked positions
target[mask_indices],
reduction='none'
) / p_mask[mask_indices]
loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
if ignore_token is not None and conti_token is not None:
target_conti = add_conti_for_single_feature(target) # [batch_size*seq_len]
valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size*seq_len]
elif ignore_token is not None and conti_token is None:
valid_mask = (target != ignore_token)
elif ignore_token is None and conti_token is None:
valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
token_loss = F.cross_entropy(
logits[mask_indices], # index logits directly at the masked positions
target[mask_indices],
reduction='none'
) / p_mask[mask_indices]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
if input_dict is not None:
hidden_vec = input_dict['hidden_vec'] # [bs, seq_len, dim]
feat = hidden_vec.mean(dim=1) # [bs, dim]
disp_loss = dispersive_loss(feat, tau=tau) # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'beat':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'chord' or key == 'tempo' or key == 'instrument':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if disp_loss is not None:
total_loss = total_loss + lambda_weight * disp_loss
log_loss_dict_normal['dispersion'] = disp_loss.item()
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
class EncodecFlattenLoss():
def __init__(self, feature_list):
self.feature_list = feature_list
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss_seq = -torch.log(pt) # [batch_size*seq_len]
loss_seq = loss_seq * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits, shifted_tgt, mask):
loss = self.get_nll_loss(logits, shifted_tgt, mask)
return loss
class EncodecMultiClassLoss(EncodecFlattenLoss):
def __init__(self, feature_list):
super().__init__(feature_list)
def __call__(self, logits_dict, shifted_tgt, mask):
train_loss_list = []
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
total_loss = sum(train_loss_list) / len(train_loss_list)
return total_loss
########################### Learning rate Scheduler ################################
'''
This scheduler is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is a cosine annealing scheduler with warm restarts, extended with two features: a linear warm-up phase and a maximum learning rate that decays after each restart.
'''
class CosineAnnealingWarmUpRestarts(_LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
if T_up < 0 or not isinstance(T_up, int):
raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
self.T_0 = T_0
self.T_mult = T_mult
self.base_eta_max = eta_max
self.eta_max = eta_max
self.T_up = T_up
self.T_i = T_0
self.gamma = gamma
self.cycle = 0
self.T_cur = last_epoch
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.T_cur == -1:
return self.base_lrs
elif self.T_cur < self.T_up:
return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.cycle += 1
self.T_cur = self.T_cur - self.T_i
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
else:
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
self.cycle = epoch // self.T_0
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.cycle = n
self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch
self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CosineLRScheduler(_LRScheduler):
"""Cosine LR scheduler.
Args:
optimizer (Optimizer): Torch optimizer.
warmup_steps (int): Number of warmup steps.
total_steps (int): Total number of steps.
lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
cycle_length (float): Cycle length.
"""
def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
self.warmup_steps = warmup_steps
assert self.warmup_steps >= 0
self.total_steps = total_steps
assert self.total_steps >= 0
self.lr_min_ratio = lr_min_ratio
self.cycle_length = cycle_length
super().__init__(optimizer)
def _get_sched_lr(self, lr: float, step: int):
if step < self.warmup_steps:
lr_ratio = step / self.warmup_steps
lr = lr_ratio * lr
elif step <= self.total_steps:
s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
(1. + math.cos(math.pi * s / self.cycle_length))
lr = lr_ratio * lr
else:
lr_ratio = self.lr_min_ratio
lr = lr_ratio * lr
return lr
def get_lr(self):
return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
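# Minimal usage sketch (illustrative; variable names are assumptions):
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = CosineLRScheduler(optimizer, total_steps=100_000, warmup_steps=1_000)
#   for _ in range(100_000):
#       optimizer.step()
#       scheduler.step()   # advances last_epoch, i.e. one call per training step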
class DispersiveLoss(nn.Module):
def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
super().__init__()
self.loss_type = loss_type
self.tau = tau
self.lambda_weight = lambda_weight
def forward(self, features, diffusion_loss):
"""
features: batch feature matrix of shape [batch_size, feature_dim]
diffusion_loss: the original diffusion loss
"""
batch_size = features.size(0)
# compute the pairwise distance matrix
if self.loss_type == 'infonce_l2':
# squared L2 distance
dist_matrix = torch.cdist(features, features, p=2) ** 2
# dispersion loss
exp_dist = torch.exp(-dist_matrix / self.tau)
disp_loss = torch.log(exp_dist.mean())
elif self.loss_type == 'hinge':
# hinge loss, assuming a margin epsilon=1.0
dist_matrix = torch.cdist(features, features, p=2)
disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
elif self.loss_type == 'covariance':
# covariance loss
normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
# penalize squared off-diagonal covariance entries
mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool)
disp_loss = (cov_matrix[mask] ** 2).mean()
else:
raise ValueError("Unsupported loss type")
# total loss = diffusion loss + lambda * dispersion loss
total_loss = diffusion_loss + self.lambda_weight * disp_loss
return total_loss, disp_loss
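# Minimal usage sketch (illustrative): wrap a diffusion objective with the dispersive
# regularizer; `features` is any pooled per-sample representation of shape [batch, dim].
#   criterion = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.5)
#   total_loss, disp_loss = criterion(features, diffusion_loss)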

File diff suppressed because it is too large

View File

@ -0,0 +1,949 @@
import torch
import torch.nn as nn
from x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
from transformers import T5EncoderModel
from data_representation.vocab_utils import LangTokenVocab
class PosEncoding(nn.Module):
def __init__(self, emb_size, max_t):
super().__init__()
self.emb_size = emb_size
self.max_t = max_t
self.register_buffer('encoding', self._prepare_emb())
def _prepare_emb(self):
dim_axis = 10000**(torch.arange(self.emb_size//2) * 2 / self.emb_size) # 10000 ** (exponents spaced evenly in [0, 1) across half the embedding dim)
timesteps = torch.arange(self.max_t)
pos_enc_in = timesteps.unsqueeze(1) / dim_axis.unsqueeze(0)
pos_enc_sin = torch.sin(pos_enc_in) # sin/cos pairs at geometrically spaced frequencies give every position a distinct encoding
pos_enc_cos = torch.cos(pos_enc_in)
pos_enc = torch.stack([pos_enc_sin, pos_enc_cos], dim=-1).reshape([self.max_t, self.emb_size])
return pos_enc
def forward(self, x):
return self.encoding[x]
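# Minimal usage sketch (illustrative): look up fixed sinusoidal encodings by position index.
#   pos_enc = PosEncoding(emb_size=512, max_t=4096)
#   positions = torch.arange(seq_len)        # [T] position indices (seq_len is an assumption)
#   x = token_emb + pos_enc(positions)       # [B, T, 512] + [T, 512] broadcasts over the batch dim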
class ResidualLayerNormModule(nn.Module):
def __init__(self, submodule):
super().__init__()
self.submodule = submodule
self.layer_norm = nn.LayerNorm(self.submodule.input_size)
def forward(self, x, mask=None, y=None):
if y is not None:
res_x = self.submodule(x, y, mask)
elif mask is not None:
res_x = self.submodule(x, mask)
else:
res_x = self.submodule(x)
x = x + res_x
return self.layer_norm(x)
class SingleEmbedding(nn.Module):
def __init__(
self,
vocab,
dim_model,
):
'''
Embedding layer for REMI
'''
super().__init__()
vocab_size = vocab.get_vocab_size()
self.embedding = nn.Embedding(vocab_size, dim_model)
def forward(self, x):
return self.embedding(x)
class MultiEmbedding(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int,
):
super().__init__()
'''
Embedding layer for compound tokens
'''
self.vocab_size = vocab.get_vocab_size()
self.feature_list = vocab.feature_list
self.dim_model = dim_model
self.layers = []
self._make_emb_layers()
self._init_params()
self._make_emb_boundaries_by_key()
def _init_params(self):
# apply kaiming init
for layer in self.layers:
if isinstance(layer, nn.Embedding):
nn.init.kaiming_normal_(layer.weight)
def _make_emb_layers(self):
vocab_sizes = [self.vocab_size[key] for key in self.feature_list]
self.embedding_sizes = [self.dim_model for _ in self.feature_list]
for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
if embedding_size != 0:
self.layers.append(nn.Embedding(vocab_size, embedding_size))
self.layers = nn.ModuleList(self.layers)
def _make_emb_boundaries_by_key(self):
'''
This function returns dict of boundaries for each embedding layer
'''
self.emb_boundary_by_key = {}
start_idx = 0
for key, emb_size in zip(self.feature_list, self.embedding_sizes):
if emb_size != 0:
self.emb_boundary_by_key[key] = (start_idx, start_idx + emb_size)
start_idx += emb_size
def forward(self, x):
emb = torch.cat([module(x[..., i]) for i, module in enumerate(self.layers)], dim=-1)
return emb
def __len__(self):
return len(self.layers)
def get_emb_by_key(self, key, token):
layer_idx = self.feature_list.index(key)
return self.layers[layer_idx](token)
class SummationEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
def forward(self, seq):
emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
output = torch.sum(stacked_emb, dim=2) # B x T x emb_size
return output
class AverageEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
def forward(self, seq):
emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
output = torch.mean(stacked_emb, dim=2) # B x T x emb_size
return output
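# Illustrative note: both embedders map compound tokens [B, T, num_features] to one vector per
# position; SummationEmbedder sums the per-feature embeddings, AverageEmbedder takes their mean.
#   embedder = SummationEmbedder(vocab, dim_model=512)
#   out = embedder(tokens)   # tokens: LongTensor [B, T, num_features] -> out: [B, T, 512]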
class SelfAttentionEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
self.dropout = 0.1
self.transformer_encoder = Encoder(
dim = dim_model,
depth = 1,
heads = 8,
attn_dropout = self.dropout,
ff_dropout = self.dropout,
attn_flash = True)
self.cls_embedding = nn.Parameter(torch.zeros(1, 1, self.dim_model), requires_grad=True)
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff()
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn()
def _add_dropout_after_attn(self):
for layer in self.transformer_encoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(self.dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(self.dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self):
for layer in self.transformer_encoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(self.dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_encoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def _apply_window_on_input_vec(self, embeddings):
window_size = 1
zero_vec = torch.zeros(embeddings.shape[0], window_size-1, embeddings.shape[2], embeddings.shape[3]).to(embeddings.device) # B x (window_size-1) x num_features x emb_size
window_applied_input_vec = torch.cat([zero_vec, embeddings], dim=1) # B x (T+window_size-1) x num_features x emb_size
window_applied_input_vec = window_applied_input_vec.unfold(1, window_size, 1) # B x T x window_size x emb_size x num_features
window_applied_input_vec = window_applied_input_vec.transpose(3, 4) # B x T x window_size x num_features x emb_size
window_applied_input_vec = window_applied_input_vec.reshape(embeddings.shape[0]*embeddings.shape[1], -1, embeddings.shape[3]) # (B*T) x (num_features*window_size) x emb_size
return window_applied_input_vec
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # (num_features*window_size+1)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x (num_features*window_size+1)
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x (num_features*window_size+1) x emb_size
return tgt_pos
def forward(self, input_tokens):
'''
input_tokens: B x T x num_features
'''
# prepare input vector
emb_list = [module(input_tokens[..., i]) for i, module in enumerate(self.layers)] # B x T x 1 x emb_size
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
# apply window
stacked_emb = self._apply_window_on_input_vec(stacked_emb)
# add CLS
cls = self.cls_embedding.repeat(stacked_emb.shape[0], 1, 1) # (B*T) x 1 x emb_size
input_emb = torch.cat([stacked_emb, cls], dim=1) # (B*T) x (num_features*window_size+1) x emb_size
output = self.transformer_encoder(input_emb) # (B*T) x (num_features*window_size+1) x emb_size
# extract CLS
output = output[:, -1, :].reshape((input_tokens.shape[0], input_tokens.shape[1], -1)) # B x T x emb_size
return output
class RVQMultiEmbedding(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__()
self.vocab_size = vocab.get_vocab_size()
self.dim_model = dim_model
self.features = vocab.feature_list
self.layers = []
self._make_emb_layers()
def _make_emb_layers(self):
vocab_sizes = [self.vocab_size[key] for key in self.features]
self.embedding_sizes = [self.dim_model for _ in self.features]
for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
if embedding_size != 0:
self.layers.append(nn.Embedding(vocab_size, embedding_size))
self.layers = nn.ModuleList(self.layers)
def forward(self, x):
embeddings = torch.zeros(x.shape[0], x.shape[1], self.dim_model).to(x.device)
emb_list = [module(x[:, (idx+1)%4::4]) for idx, module in enumerate(self.layers)]
for idx, emb in enumerate(emb_list):
embeddings[:, (idx+1)%4::4] = emb
return embeddings
def get_emb_by_key(self, key:str, token:torch.Tensor):
layer_idx = self.features.index(key)
return self.layers[layer_idx](token)
class XtransformerDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
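# Minimal usage sketch (illustrative): the main decoder operates on already-embedded sequences.
#   decoder = XtransformerDecoder(dim=512, depth=6, heads=8, dropout=0.1)
#   hidden = decoder(seq_emb)                              # [B, T, 512], plain forward pass
#   hidden, intermediates = decoder(seq_emb, train=True)   # also returns per-layer hiddens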
class XtransformerCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the cross-attend decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for the cross-attend decoder'
assert attention_mask is not None, 'attention_mask should be provided for the cross-attend decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class XtransformerLargeCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the cross-attend decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for the cross-attend decoder'
assert attention_mask is not None, 'attention_mask should be provided for the cross-attend decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class NewCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False,
use_rmsnorm=True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the cross-attend decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for the cross-attend decoder'
assert attention_mask is not None, 'attention_mask should be provided for the cross-attend decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class NewCrossAttendwithRoPEDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for the cross-attend decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for the cross-attend decoder'
assert attention_mask is not None, 'attention_mask should be provided for the cross-attend decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class XtransformerPrefixDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = PrefixDecoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None):
assert context is not None, 'context should be provided for prefix decoder'
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerPretrainingDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None, context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerFinetuningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
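# Prefix conditioning: the frozen T5 text embeddings are prepended to the token-embedding sequence so that the causal decoder can attend to the text prompt at every step; the prompt positions are sliced off again before the hidden states are returned.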
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
class XtransformerLargeFinetuningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
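# Minimal usage sketch for the text-conditioned decoders (values below are illustrative, not taken from the
# training config). `dim` must match the text encoder's hidden size, since the prompt embeddings are
# concatenated with the token embeddings: 768 for flan-t5-base (XtransformerFinetuningDecoder) and
# 1024 for flan-t5-large (XtransformerLargeFinetuningDecoder).
#   from transformers import T5Tokenizer
#   tok = T5Tokenizer.from_pretrained('google/flan-t5-base')
#   ctx = tok("a calm piano piece", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
#   dec = XtransformerFinetuningDecoder(dim=768, depth=8, heads=8, dropout=0.1)
#   seq_emb = torch.randn(1, 16, 768)                  # already-embedded music tokens, B x T x dim
#   hidden, _ = dec(seq_emb, train=True, context=ctx)  # hidden: B x T x dim (prompt positions trimmed)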

BIN
SongEval/.DS_Store vendored Normal file

Binary file not shown.

201
SongEval/LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

88
SongEval/README.md Normal file
View File

@ -0,0 +1,88 @@
# 🎵 SongEval: A Benchmark Dataset for Song Aesthetics Evaluation
[![Hugging Face Dataset](https://img.shields.io/badge/HuggingFace-Dataset-blue)](https://huggingface.co/datasets/ASLP-lab/SongEval)
[![Arxiv Paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2505.10793)
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
This repository provides a **trained aesthetic evaluation toolkit** based on [SongEval](https://huggingface.co/datasets/ASLP-lab/SongEval), the first large-scale, open-source dataset for human-perceived song aesthetics. The toolkit enables **automatic scoring of generated songs** across five perceptual aesthetic dimensions aligned with professional musician judgments.
---
## 🌟 Key Features
- 🧠 **Pretrained neural models** for perceptual aesthetic evaluation
- 🎼 Predicts **five aesthetic dimensions**:
- Overall Coherence
- Memorability
- Naturalness of Vocal Breathing and Phrasing
- Clarity of Song Structure
- Overall Musicality
<!-- - 🧪 Supports **batch evaluation** for model benchmarking -->
- 🎧 Accepts **full-length songs** (vocals + accompaniment) as input
- ⚙️ Simple inference interface
---
## 📦 Installation
Clone the repository and install dependencies:
```bash
git clone https://github.com/ASLP-lab/SongEval.git
cd SongEval
pip install -r requirements.txt
```
## 🚀 Quick Start
- Evaluate a single audio file:
```bash
python eval.py -i /path/to/audio.mp3 -o /path/to/output
```
- Evaluate a list of audio files:
```bash
python eval.py -i /path/to/audio_list.txt -o /path/to/output
```
- Evaluate all audio files in a directory:
```bash
python eval.py -i /path/to/audio_directory -o /path/to/output
```
- Force evaluation on CPU (⚠️ CPU evaluation may be significantly slower):
```bash
python eval.py -i /path/to/audio.wav -o /path/to/output --use_cpu
```
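Scores are written to `result.json` in the output directory. Below is a minimal sketch for reading them, assuming the per-file keys produced by `eval.py` in this repository (Coherence, Musicality, Memorability, Clarity, Naturalness) plus an `average` entry; the path is a placeholder:
```python
import json

# hypothetical output directory, the same one passed to eval.py via -o
with open("/path/to/output/result.json") as f:
    results = json.load(f)

for fid, scores in results.items():
    if fid == "average":
        continue
    print(fid, scores["Musicality"], scores["Coherence"])

print("dataset averages:", results["average"])
```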
## 🙏 Acknowledgement
This project is mainly organized by the audio, speech and language processing lab [(ASLP@NPU)](http://www.npu-aslp.org/).
We sincerely thank the **Shanghai Conservatory of Music** for their expert guidance on music theory, aesthetics, and annotation design.
We also thank AISHELL for helping with the organization of the song annotations.
<p align="center"> <img src="assets/logo.png" alt="Shanghai Conservatory of Music Logo"/> </p>
## 📑 License
This project is released under the CC BY-NC-SA 4.0 license.
You are free to use, modify, and build upon it for non-commercial purposes, with attribution.
## 📚 Citation
If you use this toolkit or the SongEval dataset, please cite the following:
```
@article{yao2025songeval,
title = {SongEval: A Benchmark Dataset for Song Aesthetics Evaluation},
author = {Yao, Jixun and Ma, Guobin and Xue, Huixin and Chen, Huakang and Hao, Chunbo and Jiang, Yuepeng and Liu, Haohe and Yuan, Ruibin and Xu, Jin and Xue, Wei and others},
journal = {arXiv preprint arXiv:2505.10793},
year={2025}
}
```

BIN
SongEval/assets/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1016 KiB

184
SongEval/clap_score.py Normal file
View File

@ -0,0 +1,184 @@
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln
# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
return (x / 32767.0).astype(np.float32)
def float32_to_int16(x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)
def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_speech_audioset_epoch_15_esc_89.98.pt'):
"""
Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
the LAION-CLAP audio embedding of the generated audio. LAION-CLAP: https://github.com/LAION-AI/CLAP
This evaluation script assumes that audio_path files are identified with the ids in id2text.
clap_score() evaluates all ids in id2text.
GPU-based computation.
Select one of the following models from https://github.com/LAION-AI/CLAP:
- music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
- music_audioset_epoch_15_esc_90.14.pt
- music_speech_epoch_15_esc_89.25.pt
- 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs)
Params:
-- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
-- audio_path: path where the generated audio files to evaluate are available.
-- audio_files_extension: files extension (default .wav) in eval_path.
-- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt').
Returns:
-- LAION-CLAP score (average cosine similarity across all ids)
"""
# load model
if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
clap_path = 'load/clap_score/music_speech_audioset_epoch_15_esc_89.98.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
clap_path = 'load/clap_score/music_audioset_epoch_15_esc_90.14.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
clap_path = 'load/clap_score/music_speech_epoch_15_esc_89.25.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == '630k-audioset-fusion-best.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
clap_path = 'load/clap_score/630k-audioset-fusion-best.pt'
model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
else:
raise ValueError('clap_model not implemented')
# download clap_model if not already downloaded
if not os.path.exists(clap_path):
print('Downloading ', clap_model, '...')
os.makedirs(os.path.dirname(clap_path), exist_ok=True)
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(clap_path, 'wb') as file:
with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
for data in response.iter_content(chunk_size=8192):
file.write(data)
progress_bar.update(len(data))
# fixing LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
pkg = load_state_dict(clap_path)
pkg.pop('text_branch.embeddings.position_ids', None)
model.model.load_state_dict(pkg)
model.eval()
if not os.path.isdir(audio_path):
raise ValueError('audio_path does not exist')
if id2text:
print('[EXTRACTING TEXT EMBEDDINGS] ')
batch_size = 64
text_emb = {}
for i in tqdm(range(0, len(id2text), batch_size)):
batch_ids = list(id2text.keys())[i:i+batch_size]
batch_texts = [id2text[id] for id in batch_ids]
with torch.no_grad():
embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
for id, emb in zip(batch_ids, embeddings):
text_emb[id] = emb
else:
raise ValueError('Must specify id2text')
print('[EVALUATING GENERATIONS] ', audio_path)
score = 0
count = 0
for id in tqdm(id2text.keys()):
file_path = os.path.join(audio_path, str(id)+audio_files_extension)
with torch.no_grad():
audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000
audio = pyln.normalize.peak(audio, -1.0)
audio = audio.reshape(1, -1) # unsqueeze (1,T)
audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True)
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
score += cosine_sim
count += 1
return score / count if count > 0 else 0
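# Minimal usage sketch (ids, prompts, and folder below are hypothetical; audio files must be named <id>.wav):
#   id2text = {'sample_0': 'a calm piano piece', 'sample_1': 'an upbeat rock track with drums'}
#   score = clap_score(id2text, 'generated_audio/', audio_files_extension='.wav')
#   print('LAION-CLAP score:', float(score))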
if __name__ == "__main__":
import pandas as pd
import json
import argparse
parser = argparse.ArgumentParser(description='Compute CLAP score for generated audio files.')
parser.add_argument('--clap_model', type=str, default='630k-audioset-fusion-best.pt',
help='CLAP model to use for evaluation. Options: music_speech_audioset_epoch_15_esc_89.98.pt, music_audioset_epoch_15_esc_90.14.pt, music_speech_epoch_15_esc_89.25.pt, 630k-audioset-fusion-best.pt (default: 630k-audioset-fusion-best.pt)')
parser.add_argument('--root_path', type=str, default='../wandb/run-20250627_172105-xpe7nh5n-worseInstr/generated_samples_text_conditioned_top_p_threshold_0.99_temperature_1.15_8',
help='Path to the directory containing generated audio files and id2text mapping.')
args = parser.parse_args()
clap_model = args.clap_model
root_path = args.root_path
json_file_path = os.path.join(root_path, 'name2prompt.jsonl')
generated_path = os.path.join(root_path, 'prompt_music')
if not os.path.exists(generated_path):
generated_path = root_path # if there is no 'prompt_music' subfolder, use root_path directly
with open(json_file_path, 'r') as f:
id2text_dict = {}
for line in f:
item = json.loads(line)
for k, v in item.items():
id2text_dict[k] = v[0]
print('length of id2text:', len(id2text_dict))
# id2text = {k+'_1': v[0] for k, v in id2text_dict.items()} # assuming each key has a list of prompts, we take the first one
id2text ={}
for k, v in id2text_dict.items():
if isinstance(v, list):
id2text[k] = v[0]
# check whether k exists as a wav file
if os.path.exists(os.path.join(generated_path, str(k)+'.wav')):
id2text[k] = v[0]
else:
# find k_*, k_1, k_2, ... and check if they exist
for i in range(0, 10): # assuming no more than 10 variations
if os.path.exists(os.path.join(generated_path, str(k)+'_'+str(i)+'.wav')):
new_key = str(k) + '_' + str(i)
id2text[new_key] = v[0]
print('length of id2text after checking wav files:', len(id2text))
# check that the wav file actually exists
new_id2text = {}
for id in id2text.keys():
file_path = os.path.join(generated_path, str(id)+'.wav')
if os.path.exists(file_path):
new_id2text[id] = id2text[id]
else:
print(f"Warning: {file_path} does not exist, skipping this id.")
print('length of new_id2text:', len(new_id2text))
"""
IMPORTANT: the audios in generated_path should have the same ids as in id2text.
For musiccaps, you can load id2text as above and each generated_path audio file
corresponds to a prompt (text description) in musiccaps. Files are named with ids, as follows:
- your_model_outputs_folder/_-kssA-FOzU.wav
- your_model_outputs_folder/_0-2meOf9qY.wav
- your_model_outputs_folder/_1woPC5HWSg.wav
...
- your_model_outputs_folder/ZzyWbehtt0M.wav
"""
clp = clap_score(new_id2text, generated_path, audio_files_extension='.wav')
print('CLAP score (cosine similarity):', clp)

6
SongEval/config.yaml Normal file
View File

@ -0,0 +1,6 @@
generator:
_target_: model.Generator
in_features: 1024
ffd_hidden_size: 4096
num_classes: 5
attn_layer_num: 4
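# the generator scores 1024-dim MuQ features and outputs the 5 aesthetic dimensions used in eval.py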

456
SongEval/controlability.py Normal file
View File

@ -0,0 +1,456 @@
import json
generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
test_set_json = "dataset/midicaps/train.json"
generated_eval_json_path = f"{generate_path}/eval.json"
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"
# 1. Read the test set and build a mapping from prompt (caption) to entry
with open(test_set_json, 'r') as f:
test_set =[]
for line in f:
if not line.strip():
continue
item = json.loads(line.strip())
test_set.append(item)
prompt2item = {item['caption']: item for item in test_set if item['test_set'] is True}
print(f"Number of prompts in test set: {len(prompt2item)}")
# 2. Read name2prompt.jsonl and build a mapping from generated-file name to prompt
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r') as f:
for line in f:
obj = json.loads(line)
name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and len(v) > 0})
# 3. Read eval.json
with open(generated_eval_json_path, 'r') as f:
eval_items = []
for line in f:
if not line.strip():
continue
item = json.loads(line.strip())
eval_items.append(item)
# 4. For each name, look up its prompt, make sure the prompt appears in the test set, then find the matching entry in eval.json
results = []
# turn the name of eval_items into relative name
for item in eval_items:
item['name'] = item['name'].split('/')[-1] # name is assumed to be a path; keep only the last component
# keep only the part of the name up to the second underscore
if '_' in item['name']:
item['name'] = item['name'].split('.')[0].split('_')[0] + '_' + item['name'].split('.')[0].split('_')[1]
# print(f"Processed eval item name: {item['name']}")
for name, prompt in name2prompt.items():
if prompt not in prompt2item:
print(f"Prompt not found in test set: {prompt}")
continue
# 找到 eval.json 里对应的条目(假设 eval.json 里有 name 字段)
eval_entry = next((item for item in eval_items if item.get('name') == name), None)
if eval_entry is None:
print(f"Eval entry not found for name: {name}")
continue
# 原始条目
original_entry = prompt2item[prompt]
results.append({
'name': name,
'prompt': prompt,
'eval_entry': eval_entry,
'original_entry': original_entry
})
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")
def calculate_TBT_score(results):
"""
• Tempo Bin with Tolerance (TBT): The predicted bpm falls into the ground truth tempo bin or
a neighboring one.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'tempo' in eval_entry and 'tempo' in original_entry:
eval_tempo = eval_entry['tempo'][0] if isinstance(eval_entry['tempo'], list) else eval_entry['tempo']
original_tempo = original_entry['tempo']
if original_tempo is None or eval_tempo is None:
continue # skip if the original entry has no tempo
# check whether eval_tempo falls within the tolerance window around original_tempo
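# NOTE: the asymmetric -10/+15 BPM window below is a loose approximation of 'falls into the same or a neighboring tempo bin'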
if original_tempo - 10 <= eval_tempo <= original_tempo + 15:
correct += 1
total += 1
TB_score = correct / total if total > 0 else 0
print(f"TB Score: {TB_score:.4f} (Correct: {correct}, Total: {total})")
return TB_score
def calculate_CK_score(results):
"""
• Correct Key (CK): The predicted key matches the ground truth key.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'key' in eval_entry and 'key' in original_entry:
eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
eval_key = eval_key if eval_key is not None else "C major" # default to C major
original_key = original_entry['key'] if original_entry['key'] is not None else "C major" # default to C major
if original_key is None or eval_key is None:
continue
if eval_key == original_key:
correct += 1
total += 1
CK_score = correct / total if total > 0 else 0
print(f"CK Score: {CK_score:.4f} (Correct: {correct}, Total: {total})")
return CK_score
def calculate_CKD_score(results):
"""
Correct Key with Duplicates (CKD): The predicted key matches the ground truth key or an equivalent key (i.e., a major key and its relative minor).
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'key' in eval_entry and 'key' in original_entry:
eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
if eval_key is None:
eval_key = "C major" # 默认值为 C 大调
original_key = original_entry['key'] if original_entry['key'] is not None else "C major"
if original_key is None or eval_key is None:
continue # skip if the original entry has no key
# count a match if eval_key equals original_key or shares the same tonic (e.g. C major / C minor)
if eval_key == original_key or (eval_key.split(' ')[0] == original_key.split(' ')[0]):
correct += 1
total += 1
CKD_score = correct / total if total > 0 else 0
print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
return CKD_score
def calculate_CTS_score(results):
"""
• Correct Time Signature (CTS): The predicted time signature matches the ground truth time signature.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'time_signature' in eval_entry and 'time_signature' in original_entry:
eval_time_signature = eval_entry['time_signature'][0] if isinstance(eval_entry['time_signature'], list) else eval_entry['time_signature']
original_time_signature = original_entry['time_signature']
if original_time_signature is None or eval_time_signature is None:
continue # skip if the original entry has no time signature
if eval_time_signature == original_time_signature:
correct += 1
else:
# otherwise accept meters that are equivalent up to doubling the numerator (e.g. 2/4 vs 4/4)
eval_numerator, eval_denominator = map(int, eval_time_signature.split('/'))
original_numerator, original_denominator = map(int, original_time_signature.split('/'))
if (eval_numerator == original_numerator and eval_denominator == original_denominator) or \
(eval_numerator * 2 == original_numerator and eval_denominator == original_denominator):
correct += 1
total += 1
CTS_score = correct / total if total > 0 else 0
print(f"CTS Score: {CTS_score:.4f} (Correct: {correct}, Total: {total})")
return CTS_score
def calculate_ECM_score(results):
"""
Exact Chord Match (ECM): The predicted
chord sequence matches the ground truth exactly
in terms of order, chord root, and chord type, with
tolerance for missing and excess chord instances.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'chord_summary' in eval_entry and 'chord_summary' in original_entry:
eval_chord_summary = eval_entry['chord_summary'][0] if isinstance(eval_entry['chord_summary'], list) else eval_entry['chord_summary']
original_chord_summary = original_entry['chord_summary']
if original_chord_summary is None or eval_chord_summary is None:
continue
# check whether eval_chord_summary matches original_chord_summary; both are lists of chord strings
if eval_chord_summary == original_chord_summary:
correct += 1
total += 1
ECM_score = correct / total if total > 0 else 0
print(f"ECM Score: {ECM_score:.4f} (Correct: {correct}, Total: {total})")
return ECM_score
def calculate_CMO_score(results):
"""
• Chord Match in any Order (CMO): The portion of predicted chord sequence matching the
ground truth chord root and type, in any order
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'chords' in eval_entry and 'chord_summary' in original_entry:
eval_chords_seq = eval_entry['chords']
# remove the confidence score from eval_chords_seq
if isinstance(eval_chords_seq, list) and len(eval_chords_seq) > 0 and isinstance(eval_chords_seq[0], list):
eval_chords_seq = [chord[0] for chord in eval_chords_seq]
original_chord_summary = original_entry['chord_summary']
if original_chord_summary is None or eval_chords_seq is None:
continue
# check whether eval_chords_seq covers original_chord_summary; both are lists
eval_chords_set = set(eval_chords_seq) # [['C', 0.464399092], ['G', 2.879274376]]
original_chord_set = set(original_chord_summary) # ['G', 'C']
if original_chord_set.issubset(eval_chords_set):
correct += 1
else:
if original_chord_set == eval_chords_set:
correct += 1
total += 1
CMO_score = correct / total if total > 0 else 0
print(f"CMO Score: {CMO_score:.4f} (Correct: {correct}, Total: {total})")
return CMO_score
def calculate_CI_score(results):
"""
•Correct Instrument (CI): The predicted instrument matches the ground truth instrument.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(eval_entry['mapped_instruments_summary'], list) else eval_entry['mapped_instruments']
original_instrument = original_entry['instrument_summary']
if original_instrument is None or eval_instrument is None:
continue
# check whether eval_instrument covers all of original_instrument
if isinstance(eval_instrument, list):
eval_instrument_set = set(eval_instrument)
original_instrument_set = set(original_instrument)
if original_instrument_set.issubset(eval_instrument_set):
correct += 1
else:
if eval_instrument == original_instrument:
correct += 1
total += 1
CI_score = correct / total if total > 0 else 0
print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
return CI_score
def calculate_CI_top1_score(results):
"""
•Correct Instrument Top-1 (CI_top1): The predicted instrument matches the ground truth instrument
or is one of the top 3 predicted instruments.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(eval_entry['mapped_instruments_summary'], list) else eval_entry['mapped_instruments']
original_instrument = original_entry['instrument_summary']
if original_instrument is None or eval_instrument is None:
continue
# check whether eval_instrument contains at least one element of original_instrument
if isinstance(eval_instrument, list):
eval_instrument_set = set(eval_instrument)
original_instrument_set = set(original_instrument)
for inst in original_instrument_set:
if inst in eval_instrument_set:
correct += 1
break
else:
if eval_instrument == original_instrument:
correct += 1
total += 1
CI_top1_score = correct / total if total > 0 else 0
print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CI_top1_score
def calculate_CG_score(results):
"""
• Correct Genre (CG): The predicted genre matches the ground truth genre.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'genre' in eval_entry and 'genre' in original_entry:
eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
original_genre = original_entry['genre']
if original_genre is None or eval_genre is None:
continue
# check whether eval_genre covers all of original_genre
if isinstance(eval_genre, list):
eval_genre_set = set(eval_genre)
original_genre_set = set(original_genre)
if original_genre_set.issubset(eval_genre_set):
correct += 1
else:
if eval_genre == original_genre:
correct += 1
total += 1
CG_score = correct / total if total > 0 else 0
print(f"CG Score: {CG_score:.4f} (Correct: {correct}, Total: {total})")
return CG_score
def calculate_CG_top1_score(results):
"""
• Correct Genre Top-1 (CG_top1): The predicted genre matches the ground truth genre or is one of the top 3 predicted genres.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'genre' in eval_entry and 'genre' in original_entry:
eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
original_genre = original_entry['genre']
if original_genre is None or eval_genre is None:
continue
# check whether eval_genre contains at least one element of original_genre
if isinstance(eval_genre, list):
eval_genre_set = set(eval_genre)
original_genre_set = set(original_genre)
for gen in original_genre_set:
if gen in eval_genre_set:
correct += 1
break
else:
if eval_genre == original_genre:
correct += 1
total += 1
CG_top1_score = correct / total if total > 0 else 0
print(f"CG Top-1 Score: {CG_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CG_top1_score
def calculate_CM_score(results):
"""
• Correct Mood (CM): The predicted mood matches the ground truth mood.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether eval_mood covers all of original_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
if original_mood_set.issubset(eval_mood_set):
correct += 1
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_score = correct / total if total > 0 else 0
print(f"CM Score: {CM_score:.4f} (Correct: {correct}, Total: {total})")
return CM_score
def calculate_CM_top1_score(results):
"""
• Correct Mood Top-1 (CM_top1): The predicted mood matches the ground truth mood or is one of the top 3 predicted moods.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether eval_mood contains at least one element of original_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
for mood in original_mood_set:
if mood in eval_mood_set:
correct += 1
break
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_top1_score = correct / total if total > 0 else 0
print(f"CM Top-1 Score: {CM_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CM_top1_score
def calculate_CM_top3_score(results):
"""
• Correct Mood Top-3 (CM_top3): The predicted mood matches the ground truth mood or is one of the top 3 predicted moods.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether at least 3 elements of original_mood appear in eval_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
if len(original_mood_set) <= 3 and original_mood_set.issubset(eval_mood_set):
correct += 1
elif len(original_mood_set) > 3:
match_num = sum(1 for mood in original_mood_set if mood in eval_mood_set)
if match_num >= 3:
correct += 1
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_top3_score = correct / total if total > 0 else 0
print(f"CM Top-3 Score: {CM_top3_score:.4f} (Correct: {correct}, Total: {total})")
return CM_top3_score
def calculate_all_scores(results):
"""
Calculate all scores and return them as a dictionary.
"""
scores = {
'TBT_score': calculate_TBT_score(results),
'CK_score': calculate_CK_score(results),
'CKD_score': calculate_CKD_score(results),
'CTS_score': calculate_CTS_score(results),
'ECM_score': calculate_ECM_score(results),
'CMO_score': calculate_CMO_score(results),
'CI_score': calculate_CI_score(results),
'CI_top1_score': calculate_CI_top1_score(results),
'CG_score': calculate_CG_score(results),
'CG_top1_score': calculate_CG_top1_score(results),
'CM_score': calculate_CM_score(results),
'CM_top1_score': calculate_CM_top1_score(results),
'CM_top3_score': calculate_CM_top3_score(results)
}
return scores
if __name__ == "__main__":
scores = calculate_all_scores(results)
print("All Scores:")
for score_name, score_value in scores.items():
print(f"{score_name}: {score_value:.4f}")
# Save the results to a JSON file
output_file = f"{generate_path}/results.json"
with open(output_file, 'w') as f:
json.dump(scores, f, indent=4)
print(f"Results saved to {output_file}")

103
SongEval/ebr.py Normal file
View File

@ -0,0 +1,103 @@
import argparse
import glob
import os
import pandas as pd
import muspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
def compute_midi_metrics(file_path):
"""计算单个MIDI文件的音乐指标"""
try:
music = muspy.read(file_path)
scale_consistency = muspy.scale_consistency(music)
pitch_entropy = muspy.pitch_entropy(music)
pitch_class_entropy = muspy.pitch_class_entropy(music)
empty_beat_rate = muspy.empty_beat_rate(music)
groove_consistency = muspy.groove_consistency(music, 12)
metrics = {
'scale_consistency': scale_consistency,
'pitch_entropy': pitch_entropy,
'pitch_class_entropy': pitch_class_entropy,
'empty_beat_rate': empty_beat_rate,
'groove_consistency': groove_consistency,
'filename': os.path.basename(file_path)
}
return metrics
except Exception as e:
print(f"处理文件 {os.path.basename(file_path)} 时出错: {str(e)}")
return None
def compute_directory_metrics(directory_path, num_workers=8):
"""计算目录下所有MIDI文件的音乐指标多线程加速"""
midi_files = []
for root, _, files in os.walk(directory_path):
for file in files:
if file.lower().endswith(('.mid', '.midi')):
midi_files.append(os.path.join(root, file))
if not midi_files:
print("目录及子文件夹中未找到MIDI文件")
return None
all_metrics = []
average_metrics = {
'scale_consistency': 0,
'pitch_entropy': 0,
'pitch_class_entropy': 0,
'empty_beat_rate': 0,
'groove_consistency': 0
}
current_num = 0
total_scale_consistency = 0
total_pitch_entropy = 0
total_pitch_class_entropy = 0
total_empty_beat_rate = 0
total_groove_consistency = 0
print(f"正在处理目录: {directory_path}")
print(f"发现 {len(midi_files)} 个MIDI文件:")
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(compute_midi_metrics, midi_file): midi_file for midi_file in midi_files}
for future in tqdm(as_completed(futures), total=len(midi_files), desc="Processing"):
metrics = future.result()
if metrics is not None:
current_num += 1
total_scale_consistency += metrics['scale_consistency']
total_pitch_entropy += metrics['pitch_entropy']
total_pitch_class_entropy += metrics['pitch_class_entropy']
total_empty_beat_rate += metrics['empty_beat_rate']
total_groove_consistency += metrics['groove_consistency']
average_metrics['scale_consistency'] = total_scale_consistency / current_num
average_metrics['pitch_entropy'] = total_pitch_entropy / current_num
average_metrics['pitch_class_entropy'] = total_pitch_class_entropy / current_num
average_metrics['empty_beat_rate'] = total_empty_beat_rate / current_num
average_metrics['groove_consistency'] = total_groove_consistency / current_num
print("current_metrics:", metrics)
all_metrics.append(metrics)
if not all_metrics:
print("所有文件处理失败")
return None
df = pd.DataFrame(all_metrics)
output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
df.to_csv(output_csv, index=False)
avg_metrics = df.mean(numeric_only=True)
return df, avg_metrics
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="计算目录下所有MIDI文件的音乐指标")
parser.add_argument("path", type=str, help="包含MIDI文件的目录路径")
parser.add_argument("--threads", type=int, default=1, help="线程数默认8")
args = parser.parse_args()
if not os.path.isdir(args.path):
print(f"错误: 路径 '{args.path}' 不存在或不是目录")
else:
result, averages = compute_directory_metrics(args.path, num_workers=args.threads)
if result is not None:
print("\n计算完成! 结果已保存到 midi_metrics_report.csv")
print("\n平均指标值:")
print(averages.to_string())

150
SongEval/eval.py Normal file
View File

@ -0,0 +1,150 @@
import glob
import os
import json
import librosa
import numpy as np
import torch
import argparse
from muq import MuQ
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm
class Synthesizer(object):
def __init__(self,
checkpoint_path,
input_path,
output_dir,
use_cpu: bool = False):
self.checkpoint_path = checkpoint_path
self.input_path = input_path
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
self.device = torch.device('cuda') if (torch.cuda.is_available() and (not use_cpu)) else torch.device('cpu')
@torch.no_grad()
def setup(self):
train_config = OmegaConf.load(os.path.join(os.path.dirname(self.checkpoint_path), '../config.yaml'))
model = instantiate(train_config.generator).to(self.device).eval()
state_dict = load_file(self.checkpoint_path, device="cpu")
model.load_state_dict(state_dict, strict=False)
self.model = model
self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
self.muq = self.muq.to(self.device).eval()
self.result_dcit = {}
@torch.no_grad()
def synthesis(self):
if os.path.isfile(self.input_path):
if self.input_path.endswith(('.wav', '.mp3')):
lines = []
lines.append(self.input_path)
else:
with open(self.input_path, "r") as f:
lines = [line for line in f]
input_files = [{
"input_path": line.strip(),
} for line in lines]
print(f"input filelst: {self.input_path}")
elif os.path.isdir(self.input_path):
input_files = [{
"input_path": file,
}for file in glob.glob(os.path.join(self.input_path, '*')) if file.lower().endswith(('.wav', '.mp3'))]
else:
raise ValueError(f"input_path {self.input_path} is not a file or directory")
for input in tqdm(input_files):
try:
self.handle(**input)
except Exception as e:
print(e)
continue
# add average
avg_values = {}
for key in self.result_dcit[list(self.result_dcit.keys())[0]].keys():
avg_values[key] = round(np.mean([self.result_dcit[fid][key] for fid in self.result_dcit]), 4)
self.result_dcit['average'] = avg_values
# save result
with open(os.path.join(self.output_dir, "result.json") , "w")as f:
json.dump(self.result_dcit, f, indent=4, ensure_ascii=False)
@torch.no_grad()
def handle(self, input_path):
fid = os.path.basename(input_path).split('.')[0]
if input_path.endswith('.npy'):
input = np.load(input_path)
# check ssl
if len(input.shape) == 3 and input.shape[0] != 1:
print('ssl_shape error', input_path)
return
if np.isnan(input).any():
print('ssl nan', input_path)
return
input = torch.from_numpy(input).to(self.device)
if len(input.shape) == 2:
input = input.unsqueeze(0)
if input_path.endswith(('.wav', '.mp3')):
wav, sr = librosa.load(input_path, sr=24000)
audio = torch.tensor(wav).unsqueeze(0).to(self.device)
output = self.muq(audio, output_hidden_states=True)
input = output["hidden_states"][6]
values = {}
scores_g = self.model(input).squeeze(0)
values['Coherence'] = round(scores_g[0].item(), 4)
values['Musicality'] = round(scores_g[1].item(), 4)
values['Memorability'] = round(scores_g[2].item(), 4)
values['Clarity'] = round(scores_g[3].item(), 4)
values['Naturalness'] = round(scores_g[4].item(), 4)
self.result_dcit[fid] = values
# drop references so memory can be reclaimed before the next file
del input, output, scores_g, values,audio, wav, sr
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input_path",
type=str,
required=True,
help="Input audio: path to a single file, a text file listing audio paths, or a directory of audio files."
)
parser.add_argument(
"-o", "--output_dir",
type=str,
required=True,
help="Output directory for generated results (will be created if it doesn't exist)."
)
parser.add_argument(
"--use_cpu",
action="store_true",
help="Force CPU mode even if a GPU is available.",
)
args = parser.parse_args()
ckpt_path = "ckpt/model.safetensors"
synthesizer = Synthesizer(checkpoint_path=ckpt_path,
input_path=args.input_path,
output_dir=args.output_dir,
use_cpu=args.use_cpu)
synthesizer.setup()
synthesizer.synthesis()

View File

@ -0,0 +1,404 @@
import sys
import os
from pathlib import Path
from multiprocessing import Process,set_start_method
import torch
import argparse
from omegaconf import OmegaConf
import json
from collections import defaultdict
from Amadeus.evaluation_utils import (
wandb_style_config_to_omega_config,
prepare_model_and_dataset_from_config,
get_best_ckpt_path_and_config,
Evaluator
)
from transformers import T5Tokenizer, T5EncoderModel
from Amadeus import model_zoo
from Amadeus.symbolic_encoding import data_utils
from Amadeus.model_zoo import AmadeusModel
from Amadeus.symbolic_encoding.data_utils import TuneCompiler
from Amadeus.symbolic_encoding.compile_utils import shift_and_pad
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from Amadeus.symbolic_encoding import decoding_utils
from Amadeus.train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-wandb_exp_dir",
required=True,
type=str,
help="wandb experiment directory",
)
parser.add_argument(
"-generation_type",
type=str,
choices=('conditioned', 'unconditioned', 'text-conditioned'),
default='unconditioned',
help="generation type",
)
parser.add_argument(
"-sampling_method",
type=str,
choices=('top_p', 'top_k'),
default='top_p',
help="sampling method",
)
parser.add_argument(
"-threshold",
type=float,
default=0.99,
help="threshold",
)
parser.add_argument(
"-temperature",
type=float,
default=1.15,
help="temperature",
)
parser.add_argument(
"-num_samples",
type=int,
default=30,
help="number of samples to generate",
)
parser.add_argument(
"-num_target_measure",
type=int,
default=4,
help="number of target measures for conditioned generation",
)
parser.add_argument(
"-choose_selected_tunes",
action='store_true',
help="generate samples from selected tunes, only for SOD dataset",
)
parser.add_argument(
"-generate_length",
type=int,
default=1024,
help="length of the generated sequence",
)
parser.add_argument(
"-num_processes",
type=int,
default=2,
help="number of processes to use",
)
parser.add_argument(
"-gpu_ids",
type=str,
default="0,5",
help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
)
parser.add_argument(
"-prompt",
type=str,
default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
help="prompt for generation, only used for conditioned generation",
)
parser.add_argument(
"-prompt_file",
type=str,
default="dataset/midicaps/train.json",
help="file containing prompts for text-conditioned generation",
)
return parser
def load_resources(wandb_exp_dir, device):
"""Load model and dataset resources for a process"""
wandb_dir = Path('wandb')
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
config = OmegaConf.load(config_path)
config = wandb_style_config_to_omega_config(config)
# Load checkpoint to specified device
ckpt = torch.load(ckpt_path, map_location=device)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
model.load_state_dict(ckpt['model'], strict=False)
model.to(device)
model.eval()
model = torch.compile(model)
print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
# Prepare dataset for prompts
condition_list = [x[1] for x in test_set.data_list]
dataset_for_prompt = []
for i in range(len(condition_list)):
condition = test_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
dataset_for_prompt.append((condition, condition_list[i]))
return config, model, dataset_for_prompt, vocab
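
# Assumed run-directory layout (inferred from the paths used in this script; the actual
# layout is whatever get_best_ckpt_path_and_config expects):
#   wandb/<wandb_exp_dir>/files/config.yaml    # training config saved by wandb
#   wandb/<wandb_exp_dir>/files/metadata.json  # dataset metadata
#   wandb/<wandb_exp_dir>/files/vocab.json     # token vocabulary
#   plus a checkpoint file located by get_best_ckpt_path_and_config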
def conditioned_worker(process_idx, gpu_id, args, data_slice):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_prompt(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length
)
def generate_samples_unconditioned(config, vocab, model, device, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    """Generate num_samples unconditioned samples and write them to save_dir as MIDI files."""
encoding_scheme = config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)
for i in range(num_samples):
generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
def generate_samples_with_text_prompt(config, vocab, model, device, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    """Generate one sample conditioned on a T5-encoded text prompt, log the prompt to name2prompt.jsonl, and write the MIDI to save_dir."""
encoding_scheme = config.nn_params.encoding_scheme
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').to(device)
print(f"Using T5EncoderModel for text prompt: {prompt}")
context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(device)
context = encoder(**context).last_hidden_state
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)
generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
# Open the jsonl file and count the number of lines to determine the current index
jsonl_path = save_dir / "name2prompt.jsonl"
if jsonl_path.exists():
with open(jsonl_path, 'r') as f:
current_idx = sum(1 for _ in f)
else:
current_idx = 0
name = f"prompt_{current_idx}"
name2prompt_dict = defaultdict(list)
name2prompt_dict[name].append(prompt)
with open(jsonl_path, 'a') as f:
f.write(json.dumps(name2prompt_dict) + '\n')
decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
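
# Each call above appends one JSON object per line to <save_dir>/name2prompt.jsonl, mapping the
# generated file's base name to the prompt it was conditioned on. An illustrative record:
#   {"prompt_0": ["With a rhythm of 100 BPM, this classical piece ..."]}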
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
# Generate assigned number of samples
batch_dir = base_path
generate_samples_unconditioned(
config,
vocab,
        model,
        device,
        batch_dir,
num_samples,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}"
)
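
# With the default arguments, process p writes its i-th sample to (path derived from the
# code above, shown here for orientation):
#   wandb/<wandb_exp_dir>/uncond_top_p_t0.99_temp1.15/<p>_<i>.mid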
def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
# Generate assigned number of samples
batch_dir = base_path
for idx, tune_name in enumerate(data_slice):
print(f"Process {process_idx} generating samples for tune: {tune_name}")
generate_samples_with_text_prompt(
config,
vocab,
model,
device,
batch_dir,
prompt=tune_name,
first_pred_feature=config.data_params.first_pred_feature,
sampling_method=args.sampling_method,
threshold=args.threshold,
temperature=args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}_{idx}"
)
def main():
# use spawn method for multiprocessing
set_start_method('spawn', force=True)
args = get_argument_parser().parse_args()
gpu_ids = list(map(int, args.gpu_ids.split(',')))
# Validate GPU availability
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
if len(gpu_ids) == 0:
raise ValueError("At least one GPU must be specified")
# Validate process count
if args.num_processes < 1:
raise ValueError("Number of processes must be at least 1")
if len(gpu_ids) < args.num_processes:
print(f"Warning: More processes ({args.num_processes}) than GPUs ({len(gpu_ids)}), some GPUs will be shared")
# Prepare data slices for processes
processes = []
try:
if args.generation_type == 'conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
            # Load test set to get selected tunes (dummy load to get dataset info),
            # using the same config loading path as load_resources
            dummy_config = wandb_style_config_to_omega_config(OmegaConf.load(wandb_dir / "files" / "config.yaml"))
            _, test_set, _ = prepare_model_and_dataset_from_config(
                dummy_config,
                wandb_dir / "files" / "metadata.json",
                wandb_dir / "files" / "vocab.json"
            )
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
for i in range(args.num_processes):
start_idx = i * chunk_size
end_idx = min((i+1)*chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
if not data_slice:
continue
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=conditioned_worker,
args=(i, gpu_id, args, data_slice)
)
processes.append(p)
p.start()
elif args.generation_type == 'unconditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
p = Process(
target=unconditioned_worker,
args=(i, gpu_id, args, samples)
)
processes.append(p)
p.start()
elif args.generation_type == 'text-conditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
# Load prompts from file
prompt_name_list = []
with open(args.prompt_file, 'r') as f:
for line in f:
if not line.strip():
continue
prompt_data = json.loads(line.strip())
prompt_text = prompt_data['caption']
if prompt_data['test_set'] is True:
prompt_name_list.append(prompt_text)
print("length of prompt_name_list:", len(prompt_name_list))
if len(prompt_name_list) >= args.num_samples:
print(f"Reached the limit of {args.num_samples} prompts.")
break
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
                # Split prompt names across processes (ceil-divide so trailing prompts are not dropped)
                chunk_size = (len(prompt_name_list) + args.num_processes - 1) // args.num_processes
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, len(prompt_name_list))
                data_slice = prompt_name_list[start_idx:end_idx]
                if not data_slice:
                    continue
p = Process(
target=text_conditioned_worker,
args=(i, gpu_id, args, samples, data_slice)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
except Exception as e:
print(f"Error in main process: {str(e)}")
for p in processes:
p.terminate()
raise
if __name__ == "__main__":
main()
