first commit

This commit is contained in:
2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

BIN
Amadeus/.DS_Store vendored Normal file

Binary file not shown.

0
Amadeus/__init__.py Normal file
View File


56
Amadeus/catsample.py Normal file
View File

@@ -0,0 +1,56 @@
import torch
import torch.nn.functional as F
def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
logits = categorical_probs.clamp(min=1e-9).log()
return F.gumbel_softmax(logits, hard=hard)
def sample_categorical(categorical_probs, method="hard"):
if method == "hard":
gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
return (categorical_probs / gumbel_norm).argmax(dim=-1)
else:
raise ValueError(f"Method {method} for sampling categorical variables is not valid.")
def direct_sampling(logits):
probs = logits.softmax(dim=-1)
index = sample_categorical(probs.to(torch.float32))
return index
def top_p_sampling(logits, p=0.9):
probs = logits.softmax(dim=-1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
probs.masked_fill_(indices_to_remove, 0)
probs /= probs.sum(dim=-1).unsqueeze(-1)
index = sample_categorical(probs.to(torch.float32))
return index
def top_k_sampling(logits, k=400):
top_k_values, top_k_indices = torch.topk(logits, int(k))
top_k_probs = top_k_values.softmax(dim=-1)
index = sample_categorical(top_k_probs.to(torch.float32))
index = top_k_indices[torch.arange(index.size(0)), index]
return index
def sample_with_strategy(update_logits, strategy, para = None):
if strategy == "direct":
return direct_sampling(update_logits)
elif strategy == "top_p":
return top_p_sampling(update_logits, para)
elif strategy == "top_k":
return top_k_sampling(update_logits, para)
else:
raise ValueError(f"Strategy {strategy} is not valid.")
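A minimal usage sketch, not part of the file above: the helpers expect 2-D logits of shape [batch, vocab_size], and `para` is the nucleus threshold for "top_p" or the k for "top_k". The tensor sizes and the import path are illustrative assumptions.
import torch
from Amadeus.catsample import sample_with_strategy  # assumed package path

logits = torch.randn(4, 1024)                            # [batch, vocab_size]
idx = sample_with_strategy(logits, "direct")             # LongTensor of shape [4]
idx_p = sample_with_strategy(logits, "top_p", para=0.9)  # nucleus sampling
idx_k = sample_with_strategy(logits, "top_k", para=400)  # top-k sampling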

533
Amadeus/evaluation_utils.py Normal file
View File

@@ -0,0 +1,533 @@
from collections import defaultdict
from typing import Union
from math import log
from omegaconf import DictConfig
from pathlib import Path
import pickle
import json
import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel
from . import model_zoo
from .symbolic_encoding import data_utils
from .model_zoo import AmadeusModel
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.compile_utils import shift_and_pad
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding import decoding_utils
from .train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab
def wandb_style_config_to_omega_config(wandb_conf):
# remove wandb related config
for wandb_key in ["wandb_version", "_wandb"]:
if wandb_key in wandb_conf:
del wandb_conf[wandb_key] # wandb-related config should not be overridden!
# print(wandb_conf)
# remove unnecessary fields such as desc and value
for key in wandb_conf:
# if 'desc' in wandb_conf[key]:
# del wandb_conf[key]['desc']
if isinstance(wandb_conf[key], dict) and 'value' in wandb_conf[key]:
wandb_conf[key] = wandb_conf[key]['value']
# handle the case where a 'value' field is present
try:
if 'value' in wandb_conf[key]:
wandb_conf[key] = wandb_conf[key]['value']
except:
pass
return wandb_conf
def get_dir_from_wandb_by_code(wandb_dir: Path, code:str) -> Path:
for dir in wandb_dir.iterdir():
if dir.name.endswith(code):
return dir
print(f'No such code in wandb_dir: {code}')
return None
def get_best_ckpt_path_and_config(wandb_dir, code):
dir = get_dir_from_wandb_by_code(wandb_dir, code)
if dir is None:
raise ValueError('No such code in wandb_dir')
ckpt_dir = dir / 'files' / 'checkpoints'
config_path = dir / 'files' / 'config.yaml'
# print all files in ckpt_dir
vocab_path = next(ckpt_dir.glob('vocab*'))
metadata_path = next(ckpt_dir.glob('*metadata.json'))
# if there is pt file ending with 'last', return it
if len(list(ckpt_dir.glob('*last.pt'))) > 0:
last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
else:
pt_fns = sorted(list(ckpt_dir.glob('*.pt')), key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
last_ckpt_fn = pt_fns[-1]
return last_ckpt_fn, config_path, metadata_path, vocab_path
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
nn_params = config.nn_params
dataset_name = config.dataset
vocab_path = Path(vocab_path)
if 'Encodec' in dataset_name:
encodec_tokens_path = Path(f"dataset/maestro-v3.0.0-encodec_tokens")
encodec_dataset = EncodecDataset(config, encodec_tokens_path, None, None)
vocab_sizes = encodec_dataset.vocab.get_vocab_size()
train_set, valid_set, test_set = encodec_dataset.split_train_valid_test_set()
lm_model:model_zoo.LanguageModelTransformer= getattr(model_zoo, nn_params.model_name)(config, vocab_sizes)
else:
# print(config)
encoding_scheme = config.nn_params.encoding_scheme
num_features = config.nn_params.num_features
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
in_vocab_file_path=vocab_path,
event_data=None,
encoding_scheme=encoding_scheme,
num_features=num_features)
# Initialize symbolic dataset based on dataset name and configuration parameters
symbolic_dataset = getattr(data_utils, dataset_name)(
vocab=vocab,
encoding_scheme=encoding_scheme,
num_features=num_features,
debug=config.general.debug,
aug_type=config.data_params.aug_type,
input_length=config.train_params.input_length,
first_pred_feature=config.data_params.first_pred_feature,
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
for_evaluation=True,
)
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
print(f"---{nn_params.main_decoder}--- is used")
print(f"---{dataset_name}--- is used")
print(f"---{encoding_scheme}--- is used")
split_ratio = config.data_params.split_ratio
# test_set = []
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
# get proper prediction order according to the encoding scheme and target feature in the config
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
# Create the Transformer model based on configuration parameters
AmadeusModel = getattr(model_zoo, nn_params.model_name)(
vocab=symbolic_dataset.vocab,
input_length=config.train_params.input_length,
prediction_order=prediction_order,
input_embedder_name=nn_params.input_embedder_name,
main_decoder_name=nn_params.main_decoder_name,
sub_decoder_name=nn_params.sub_decoder_name,
sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
sub_decoder_enricher_use=nn_params.sub_decoder.feature_enricher_use \
if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use') else False,
dim=nn_params.main_decoder.dim_model,
heads=nn_params.main_decoder.num_head,
depth=nn_params.main_decoder.num_layer,
dropout=nn_params.model_dropout,
)
return AmadeusModel, test_set, symbolic_dataset.vocab
def add_conti_in_valid(tensor, encoding_scheme):
new_target = tensor.clone()
# Assuming tensor shape is [batch, sequence, features]
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0, :] = new_target[:, 0, :] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# TODO: convert hard-coded part
# convert values into False except the 1st and 2nd features
if encoding_scheme == 'nb':
if tensor.shape[2] == 5:
# change beat, instrument
duplicates[:, :, 0] = False
duplicates[:, :, 3] = False
duplicates[:, :, 4] = False
elif tensor.shape[2] == 4:
# change beat
duplicates[:, :, 0] = False
duplicates[:, :, 2] = False
duplicates[:, :, 3] = False
elif tensor.shape[2] == 7:
# change beat, chord, tempo
duplicates[:, :, 0] = False
duplicates[:, :, 4] = False
duplicates[:, :, 5] = False
duplicates[:, :, 6] = False
elif encoding_scheme == 'cp':
if tensor.shape[2] == 5:
# change instrument
duplicates[:, :, 0] = False
duplicates[:, :, 1] = False
duplicates[:, :, 3] = False
duplicates[:, :, 4] = False
elif tensor.shape[2] == 7:
# change chord, tempo
duplicates[:, :, 0] = False
duplicates[:, :, 1] = False
duplicates[:, :, 4] = False
duplicates[:, :, 5] = False
duplicates[:, :, 6] = False
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
# TODO: hard coded
def add_conti(list_of_lists, encoding_scheme):
if encoding_scheme == 'nb':
if len(list_of_lists[0]) == 4:
# type, beat, pitch, duration
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
elif len(list_of_lists[0]) == 5:
# type, beat, instrument, pitch, duration
previous_instrument = None
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
if list_of_lists[i][2] == previous_instrument and previous_instrument != 0:
list_of_lists[i][2] = 'Conti'
else:
previous_instrument = list_of_lists[i][2]
elif len(list_of_lists[0]) == 7:
# type, beat, chord, tempo, pitch, duration, velocity
previous_chord = None
previous_tempo = None
for i in range(0, len(list_of_lists)):
if list_of_lists[i][0] == 'SSS':
list_of_lists[i][1] = 'Conti'
if list_of_lists[i][2] == previous_chord and previous_chord != 0:
list_of_lists[i][2] = 'Conti'
elif list_of_lists[i][2] != previous_chord and list_of_lists[i][2] != 0:
previous_chord = list_of_lists[i][2]
if list_of_lists[i][3] == previous_tempo and previous_tempo != 0:
list_of_lists[i][3] = 'Conti'
elif list_of_lists[i][3] != previous_tempo and list_of_lists[i][3] != 0:
previous_tempo = list_of_lists[i][3]
elif encoding_scheme == 'cp':
if len(list_of_lists[0]) == 7:
# type, beat, chord, tempo, pitch, duration, velocity
previous_chord = None
previous_tempo = None
for i in range(0, len(list_of_lists)):
current_chord = list_of_lists[i][2]
current_tempo = list_of_lists[i][3]
if current_chord == previous_chord and current_chord != 0:
list_of_lists[i][2] = 'Conti'
elif current_chord != previous_chord and current_chord != 0:
previous_chord = current_chord
if current_tempo == previous_tempo and current_tempo != 0:
list_of_lists[i][3] = 'Conti'
elif current_tempo != previous_tempo and current_tempo != 0:
previous_tempo = current_tempo
if len(list_of_lists[0]) == 5:
# type, beat, instrument, pitch, duration
previous_instrument = None
for i in range(0, len(list_of_lists)):
current_instrument = list_of_lists[i][2]
if current_instrument == previous_instrument and current_instrument != 0:
list_of_lists[i][2] = 'Conti'
elif current_instrument != previous_instrument and current_instrument != 0:
previous_instrument = current_instrument
return list_of_lists
class Evaluator:
def __init__(self,
config: DictConfig,
model:AmadeusModel,
test_set:TuneCompiler,
vocab: LangTokenVocab,
device:str='cuda',
batch_size:int=16):
self.config = config
self.device = device
self.vocab = vocab
self.model = model
self.model.eval()
self.model.to(device)
self.test_set = test_set
self.input_len = config.train_params.input_length
self.loss_by_class = {key:[] for key in self.vocab.feature_list}
self.count_by_class = {key:0 for key in self.vocab.feature_list}
self.batch_size = batch_size
self.is_multiclass = config.nn_params.encoding_scheme in ('nb', 'cp')
self.first_pred_feature = self.config.data_params.first_pred_feature
self.neglect_keywords = ['SSS', 'SSN', 'Conti', 'Metrical', 'Note']
self.valid_item_prob = []
# we don't use focal loss on evaluation
self.focal_alpha = 1
self.focal_gamma = 0
def save_results(self, save_fn):
# move the loss_by_class tensors to cpu
for key in self.loss_by_class.keys():
self.loss_by_class[key] = torch.tensor(self.loss_by_class[key]).cpu()
self.count_by_class[key] = torch.tensor(self.count_by_class[key]).cpu()
torch.save({'loss_by_class':self.loss_by_class, 'count_by_class':self.count_by_class}, save_fn)
@torch.inference_mode()
def get_perplexity(self,less_than=256):
for data in tqdm(self.test_set.data_list, desc='Cal over dataset', position=0):
data_tensor = torch.LongTensor(data[0])
if self.config.nn_params.encoding_scheme == 'nb':
data_tensor = shift_and_pad(data_tensor, self.first_pred_feature)
data_tensor = data_tensor[:-1]
x_seg = data_tensor[:-1].unsqueeze(0)
y_seg = data_tensor[1:].unsqueeze(0)
self._cal_initial_seg(x_seg, y_seg)
if x_seg.shape[1] > self.input_len:
cat_logits = []
cat_y = []
cat_mask_indices = []
batch_x = x_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
batch_y = y_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
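# unfold builds overlapping windows of length input_len with stride 1; only the final position of each window is scored in _cal_following_seg, so every token beyond the initial segment gets a full-context prediction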
if self.is_multiclass:
batch_x = batch_x.transpose(1,2)
batch_y = batch_y.transpose(1,2)
for batch_start_idx in tqdm(range(0, min(batch_x.shape[0], less_than), self.batch_size), desc='In piece iter', position=1, leave=False):
x = batch_x[batch_start_idx:batch_start_idx+self.batch_size]
y = batch_y[batch_start_idx:batch_start_idx+self.batch_size]
logits, y,mask_indices = self._cal_following_seg(x, y)
cat_logits.append(logits)
cat_y.append(y)
cat_mask_indices.append(mask_indices)
if self.is_multiclass:
cat_dict = {}
for key in self.vocab.feature_list:
cat_dict[key] = torch.cat([logits_dict[key] for logits_dict in cat_logits], dim=0)
cat_logits = cat_dict
else:
cat_logits = torch.cat(cat_logits, dim=0)
cat_y = torch.cat(cat_y, dim=0)
mask_indices = torch.cat(cat_mask_indices, dim=0)
if self.is_multiclass:
self._update_loss_for_multi_class(cat_logits, cat_y,mask_indices)
else:
cat_prob = torch.nn.functional.softmax(cat_logits, dim=-1)
pt = cat_prob[torch.arange(cat_prob.shape[0]), cat_y]
# focal_loss = -self.focal_alpha * (1-pt)**self.focal_gamma * torch.log(pt) # [batch_size*seq_len]
loss = -torch.log(pt)
self._update_loss_for_single_class(loss, cat_y)
@torch.inference_mode()
def _update_loss_for_single_class(self, neg_log_prob:torch.Tensor, y:torch.Tensor):
for key in self.vocab.feature_list:
feature_mask = self.vocab.total_mask[key].to(y.device) # [vocab_size,]
mask_for_target = feature_mask[y] # [b*t]
normal_loss_seq_by_class = neg_log_prob[mask_for_target==1]
if mask_for_target.sum().item() != 0:
self.loss_by_class[key] += normal_loss_seq_by_class.tolist()
self.count_by_class[key] += mask_for_target.sum().item()
@torch.inference_mode()
def _update_loss_for_multi_class(self, logits_dict:dict, tgt:torch.Tensor, mask_indices:torch.Tensor=None):
correct_token_prob = []
for index, key in enumerate(self.vocab.feature_list):
feat_tgt = tgt[:,index]
logit_values = logits_dict[key]
prob_values = torch.nn.functional.softmax(logit_values, dim=-1)
# gather the probability assigned to the ground-truth sub-token
correct_token_prob.append(prob_values[torch.arange(prob_values.shape[0]), feat_tgt])
correct_token_prob = torch.stack(correct_token_prob, dim=1)
# tgt = reverse_shift_and_pad_for_tensor(tgt, self.first_pred_feature)
y_decoded = self.vocab.decode(tgt)
y_decoded = add_conti(y_decoded, self.config.nn_params.encoding_scheme)
# correct_token_prob = reverse_shift_and_pad_for_tensor(correct_token_prob, self.first_pred_feature)
num_notes = logits_dict['pitch'].shape[0]
cum_prob = 1
max_num = mask_indices.size(0)
for idx in range(max_num):
if max_num != num_notes:
print("not equal",max_num,num_notes)
token = y_decoded[idx]
valid_mask = mask_indices[idx, :]
token_prob = correct_token_prob[idx].tolist()
for j, key in enumerate(self.vocab.feature_list):
cur_feature = token[j]
whether_predicted = valid_mask[j]
# clamp cur_prob to avoid when cur_prob is 0
cur_prob = max(token_prob[j], 1e-10)
if cur_feature == 0: # ignore token
continue
if not whether_predicted: # skip tokens that were provided rather than predicted
continue
if cur_feature in self.neglect_keywords:
cum_prob *= cur_prob
continue
if self.config.nn_params.encoding_scheme == 'cp' and 'time_signature' in cur_feature:
cum_prob *= cur_prob
continue
if self.config.nn_params.encoding_scheme == 'cp' and 'Bar' in cur_feature:
cum_prob = 1
continue
self.valid_item_prob.append([cur_feature, cur_prob, cur_prob*cum_prob])
pt = cur_prob*cum_prob
loss = -log(pt)
self.loss_by_class[key].append(loss)
self.count_by_class[key] += 1
cum_prob = 1
@torch.inference_mode()
def _cal_initial_seg(self, x_seg, y_seg):
x, y = x_seg[:, :self.input_len].to(self.device), y_seg[:, :self.input_len].to(self.device)
mask_indices = torch.ones_like(y).bool().to(self.device).flatten(0,1)
if self.config.use_diff is True:
logits,(mask_indices,_) = self.model(x, y)
else:
logits = self.model(x, y)
y = y.flatten(0,1)
if self.is_multiclass:
for key in logits.keys():
feat_tensor = logits[key].flatten(0,1)
logits[key] = feat_tensor
self._update_loss_for_multi_class(logits, y, mask_indices)
else:
prob = torch.nn.functional.softmax(logits, dim=-1)
prob = prob.flatten(0,1)
pt = prob[torch.arange(len(y)), y]
loss = -torch.log(pt)
self._update_loss_for_single_class(loss, y)
@torch.inference_mode()
def _cal_following_seg(self, x:torch.Tensor, y:torch.Tensor):
x, y = x.to(self.device), y.to(self.device)
mask_indices = torch.ones_like(y).bool().to(self.device)
if self.config.use_diff is True:
logits,(mask_indices,_) = self.model(x, y)
else:
logits = self.model(x, y)
y = y[:, -1:].flatten(0,1).cpu()
mask_indices = mask_indices.reshape(x.shape)[:,-1:].flatten(0,1).cpu()
if self.is_multiclass:
logits_dict = {}
for key in self.vocab.feature_list:
logits_dict[key] = logits[key][:, -1:].flatten(0,1).cpu()
return logits_dict, y,mask_indices
else:
logits = logits[:, -1:].flatten(0,1).cpu()
return logits, y,mask_indices
def prepare_prompt_and_ground_truth(self, save_dir, num_target_samples, num_target_measures):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
for i, (tuneidx, tune_name) in enumerate(self.test_set):
ground_truth_sample = tuneidx
try:
decoder(ground_truth_sample, output_path=str(save_dir / f"{i}_{tune_name}_gt.mid"))
except Exception as e:
print(f"Error in generating {i}_{tune_name}_gt.mid: {e}")
prompt = self.model.decoder._prepare_inference(start_token=self.model.decoder.net.start_token, manual_seed=0, condition=tuneidx, num_target_measures=num_target_measures)
try:
decoder(prompt, output_path=str(save_dir / f"{i}_{tune_name}_prompt.mid"))
except Exception as e:
print(f"Error in generating {i}_{tune_name}_prompt.mid: {e}")
if i == num_target_samples:
break
def generate_samples_with_prompt(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0,generation_length=3072):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
tuneidx = tuneidx.cuda()
generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
for i in range(num_samples):
generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
def generate_samples_with_text_prompt(self, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(self.device)
print(f"Using T5EncoderModel for text prompt: {prompt}")
context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(self.device)
context = encoder(**context).last_hidden_state
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
# Open the jsonl file and count the number of lines to determine the current index
jsonl_path = save_dir / "name2prompt.jsonl"
if jsonl_path.exists():
with open(jsonl_path, 'r') as f:
current_idx = sum(1 for _ in f)
else:
current_idx = 0
name = f"prompt_{current_idx}"
name2prompt_dict = defaultdict(list)
name2prompt_dict[name].append(prompt)
with open(jsonl_path, 'a') as f:
f.write(json.dumps(name2prompt_dict) + '\n')
decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
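A minimal driver sketch, not part of the file above, showing how the helpers defined here fit together for a perplexity run. The wandb directory, run code, checkpoint state-dict key, output filename, and import path are illustrative assumptions.
from pathlib import Path
import torch
from omegaconf import OmegaConf
from Amadeus.evaluation_utils import (get_best_ckpt_path_and_config,
                                      wandb_style_config_to_omega_config,
                                      prepare_model_and_dataset_from_config,
                                      Evaluator)  # assumed package path

ckpt_fn, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(Path("wandb"), "abc123")
config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
state = torch.load(ckpt_fn, map_location="cpu")
model.load_state_dict(state["model"])  # the state-dict key is an assumption

evaluator = Evaluator(config, model, test_set, vocab, device="cuda", batch_size=16)
evaluator.get_perplexity(less_than=256)
evaluator.save_results("nll_by_class.pt")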

512
Amadeus/model_zoo.py Normal file
View File

@@ -0,0 +1,512 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import time
import json
from . import transformer_utils
from . import sub_decoder_zoo
from x_transformers.x_transformers import LayerIntermediates, AbsolutePositionalEmbedding
from data_representation.vocab_utils import LangTokenVocab
import os
class AmadeusModelWrapper(nn.Module):
def __init__(
self,
*,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
'''
This class wraps the three main components of the AmadeusModel model,
which are the input embedding layer, the main transformer decoder, and the sub-decoder.
'''
super().__init__()
self.vocab = vocab
self.vocab_size = vocab.get_vocab_size()
self.start_token = vocab.sos_token if hasattr(vocab, 'sos_token') else None
self.end_token = vocab.eos_token if hasattr(vocab, 'eos_token') else None
self.input_length = input_length
self.prediction_order = prediction_order
self._get_input_embedder(input_embedder_name, vocab, dropout, dim)
self._get_main_decoder(main_decoder_name, input_length, dim, heads, depth, dropout)
self._get_sub_decoder(sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout)
self.bos_token_hidden = None
def _get_input_embedder(self, input_embedder_name, vocab, dropout, dim):
self.emb_dropout = nn.Dropout(dropout)
self.input_embedder = getattr(transformer_utils, input_embedder_name)(
vocab=vocab,
dim_model=dim
)
def _get_main_decoder(self, main_decoder_name, input_length, dim, heads, depth, dropout):
self.pos_enc = AbsolutePositionalEmbedding(dim, input_length)
self.main_norm = nn.LayerNorm(dim)
self.main_decoder = getattr(transformer_utils, main_decoder_name)(
dim=dim,
depth=depth,
heads=heads,
dropout=dropout
)
def _get_sub_decoder(self, sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout):
self.sub_decoder = getattr(sub_decoder_zoo, sub_decoder_name)(
prediction_order=prediction_order,
vocab=vocab,
dim=dim,
sub_decoder_depth=sub_decoder_depth,
heads=heads,
dropout=dropout,
sub_decoder_enricher_use=sub_decoder_enricher_use
)
@property
def device(self):
return next(self.parameters()).device
def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
embedding = self.input_embedder(input_seq) + self.pos_enc(input_seq)
embedding = self.emb_dropout(embedding)
hidden_vec,layer_inter = self.main_decoder(embedding,train=True, context=context) # B x T x d_model
hidden_vec = self.main_norm(hidden_vec)
input_dict = {'hidden_vec':hidden_vec, 'input_seq': input_seq, 'target': target, 'bos_token_hidden': self.bos_token_hidden}
logits = self.sub_decoder(input_dict)
# pick the intermediate layer closest to one third of the total depth
num_layers = len(layer_inter.layer_hiddens)
idx = round(num_layers / 3)
idx = min(max(idx, 0), num_layers - 1)
input_dict['hidden_vec'] = layer_inter.layer_hiddens[idx]
return logits, input_dict
class AmadeusModelAutoregressiveWrapper(nn.Module):
def __init__(self, net:AmadeusModelWrapper):
'''
Initializes an autoregressive wrapper around the AmadeusModelWrapper,
which allows sequential token generation.
Arguments:
- net: The nested music transformer model that performs the token generation.
'''
super().__init__()
self.net = net
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
def _prepare_inference(self, start_token, manual_seed, condition=None, num_target_measures=4):
'''
Prepares the initial tokens for autoregressive inference. If a manual seed is provided,
it sets the seed for reproducibility. If a condition is given, it selects a subset of
the tokens based on certain criteria related to the encoding scheme.
Arguments:
- start_token: The token that represents the start of a sequence.
- manual_seed: A seed value for reproducibility in inference (if greater than 0).
- condition: An optional tensor used for conditional generation, which helps select a
portion of the input tokens based on the encoding scheme.
Returns:
- total_out: A tensor containing the initial tokens for inference, padded to ensure compatibility
with the model.
'''
if manual_seed > 0:
torch.manual_seed(manual_seed)
total_out = []
if condition is None:
# Use the start token if no condition is given
total_out.extend(start_token)
else:
# Extract the portion of the sequence depending on encoding scheme (remi, cp, or nb)
if self.net.vocab.encoding_scheme == 'remi':
type_boundaries = self.net.vocab.remi_vocab_boundaries_by_key['type']
# vocab idx -> 0:SOS, 1:EOS, 2:Bar_without_time_signature, ... where_type_ends:Bar_time_signature_end, ...
measure_bool = (2 <= condition) & (condition < type_boundaries[1]) # between Bar_ts_start and Bar_ts_end
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'cp':
# find the start and end of the measure
beat_event2idx = self.net.vocab.event2idx['beat']
for event, idx in beat_event2idx.items():
if event == 0:
continue
if event == 'Bar':
start_idx = idx
elif event.startswith('Beat'):
end_idx = idx
break
measure_bool = (condition[:,1] >= start_idx) & (condition[:,1] < end_idx) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
# measure_bool = (condition[:,1] == 1) # measure tokens
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
elif self.net.vocab.encoding_scheme == 'nb':
measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
if conditional_input_len == 0:
conditional_input_len = 50
selected_tokens = condition[:conditional_input_len].tolist()
total_out.extend(selected_tokens)
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
return total_out
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
'''
Runs one step of autoregressive generation by taking the input sequence, embedding it,
passing it through the main decoder, and generating logits and a sampled token.
Arguments:
- input_seq: The input sequence tensor to be embedded and processed.
- cache: Optional cache for attention mechanisms to avoid recomputation.
- sampling_method: Sampling strategy used to select the next token.
- threshold: Optional threshold value for sampling methods that require it.
- temperature: Controls the randomness of predictions (higher temperature increases randomness).
Returns:
- logits: The predicted logits for the next token.
- sampled_token: The token sampled from the logits.
- intermediates: Intermediate states from the main decoder, useful for caching.
'''
embedding = self.net.input_embedder(input_seq) + self.net.pos_enc(input_seq)
embedding = self.net.emb_dropout(embedding)
# Run through the main decoder and normalize
hidden_vec, intermediates = self.net.main_decoder(embedding, cache, context_embedding=context) # B x T x d_model
hidden_vec = self.net.main_norm(hidden_vec)
hidden_vec = hidden_vec[:, -1:] # Keep only the last time step
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
# Generate the next token
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
return logits, sampled_token, intermediates, hidden_vec
def _update_total_out(self, total_out, sampled_token):
'''
Updates the output sequence with the newly sampled token. Depending on the encoding scheme,
it either appends the token directly or processes feature-based sampling.
Arguments:
- total_out: The tensor containing the previously generated tokens.
- sampled_token: The newly generated token to be appended.
Returns:
- total_out: Updated output tensor with the newly generated token.
- sampled_token: The processed sampled token.
'''
if self.net.vocab.encoding_scheme == 'remi':
# For remi encoding, directly append the sampled token
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
else:
# Handle other encoding schemes by concatenating features
sampled_token_list = []
for key in self.net.vocab.feature_list:
sampled_token_list.append(sampled_token[key])
sampled_token = torch.cat(sampled_token_list, dim=-1)
# print(total_out.shape)
if len(sampled_token.shape) == 2:
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=1)
else:
total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
return total_out, sampled_token
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
Arguments:
- manual_seed: A seed value for reproducibility in inference.
- max_seq_len: The maximum length of the generated sequence.
- condition: An optional conditioning sequence to start generation from.
- sampling_method: The method used to sample the next token (e.g., greedy, top-k).
- threshold: Optional threshold for sampling (used in methods like top-p sampling).
- temperature: Controls the randomness of the token sampling process.
- batch_size: The number of sequences to generate in parallel.
Returns:
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
# Continue generating tokens until the maximum sequence length is reached
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
bos_hidden_vec = None
hidden_vec_list = []
token_time_list = []
while total_out.shape[1] < max_seq_len:
pbar.update(1)
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
time_start = time.time()
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
token_time_list.append(time_end - time_start)
if bos_hidden_vec is None:
bos_hidden_vec = hidden_vec
hidden_vec_list.append(hidden_vec)
# Update attention cache to handle autoregressive generation
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
# Update the generated output with the new token
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
# Stop if the end token is reached
if sampled_token.tolist() == self.net.end_token[0]:
break
# append hidden_vec to pkl
# save_path = 'hidden/diffnoaug_hidden_vec.pt'
# save_time_path = 'hidden/diff_noaug_token_time.json'
# if os.path.exists(save_path):
# # Load existing list and append
# hidden_vec_all = torch.load(save_path, map_location="cpu")
# hidden_vec_all.extend(hidden_vec_list)
# torch.save(hidden_vec_all, save_path)
# else:
# torch.save(hidden_vec_list, save_path)
# if os.path.exists(save_time_path):
# # Load existing list and append
# token_time_all = json.load(open(save_time_path, 'r'))
# token_time_all = token_time_all['token_time_list']
# token_time_all.extend(token_time_list)
# average_time = sum(token_time_all) / len(token_time_all)
# data = {
# 'average_time': average_time,
# 'token_time_list': token_time_all
# }
# json.dump(data, open(save_time_path, 'w'), indent=4)
# else:
# average_time = sum(token_time_list) / len(token_time_list)
# data = {
# 'average_time': average_time,
# 'token_time_list': token_time_list
# }
# json.dump(data, open(save_time_path, 'w'), indent=4)
return total_out
def generate_batch(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
Arguments:
- manual_seed: A seed value for reproducibility in inference.
- max_seq_len: The maximum length of the generated sequence.
- condition: An optional conditioning sequence to start generation from.
- sampling_method: The method used to sample the next token (e.g., greedy, top-k).
- threshold: Optional threshold for sampling (used in methods like top-p sampling).
- temperature: Controls the randomness of the token sampling process.
- batch_size: The number of sequences to generate in parallel.
Returns:
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# total_out (1,1,num) -> (bs,1,num)
total_out = total_out.repeat(batch_size, 1, 1)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
# Continue generating tokens until the maximum sequence length is reached
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
while total_out.shape[1] < max_seq_len:
pbar.update(1)
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
_, sampled_token, cache, _ = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
# Update attention cache to handle autoregressive generation
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
# Update the generated output with the new token
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
# Stop if the end token is reached
if sampled_token.tolist() == self.net.end_token[0]:
break
return total_out
class AmadeusModel(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
'''
This class combines the wrapper classes and initializes the full AmadeusModel model,
which can perform autoregressive sequence generation for symbolic music.
Arguments:
- vocab: Vocabulary used for tokenization of the symbolic music data.
- input_length: Length of the input sequence in tokens.
- prediction_order: Order in which features are predicted within a compound token; used for the compound shift.
- input_embedder_name: Name of the input embedding model to be used (e.g., one-hot embedding or learned embeddings).
- main_decoder_name: Name of the main transformer decoder that generates the hidden representations for compound tokens.
- sub_decoder_name: Name of the sub-decoder, which processes the hidden states and decodes the sub-tokens inside the compound tokens.
- sub_decoder_depth: Depth (number of layers) of the sub-decoder.
- sub_decoder_enricher_use: Whether to use an additional enricher module in the sub-decoder to refine representations.
- dim: Dimensionality of the model (hidden size of the transformer layers).
- heads: Number of attention heads in the transformer layers.
- depth: Number of layers in the main decoder.
- dropout: Dropout rate for all layers in the model.
'''
super().__init__()
decoder = AmadeusModelWrapper(
vocab=vocab,
input_length=input_length,
prediction_order=prediction_order,
input_embedder_name=input_embedder_name,
main_decoder_name=main_decoder_name,
sub_decoder_name=sub_decoder_name,
sub_decoder_depth=sub_decoder_depth,
sub_decoder_enricher_use=sub_decoder_enricher_use,
dim=dim,
heads=heads,
depth=depth,
dropout=dropout
)
self.decoder = AmadeusModelAutoregressiveWrapper(
net=decoder
)
def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
return self.decoder(input_seq, target, context=context)
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
if batch_size == 1:
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
else:
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
class AmadeusModel4Encodec(AmadeusModel):
def __init__(
self,
vocab:LangTokenVocab,
input_length:int,
prediction_order:list,
input_embedder_name:str,
main_decoder_name:str,
sub_decoder_name:str,
sub_decoder_depth:int,
sub_decoder_enricher_use:bool,
dim:int,
heads:int,
depth:int,
dropout:float
):
super().__init__(
vocab=vocab,
input_length=input_length,
prediction_order=prediction_order,
input_embedder_name=input_embedder_name,
main_decoder_name=main_decoder_name,
sub_decoder_name=sub_decoder_name,
sub_decoder_depth=sub_decoder_depth,
sub_decoder_enricher_use=sub_decoder_enricher_use,
dim=dim,
heads=heads,
depth=depth,
dropout=dropout
)
def _prepare_inference(self, start_token, manual_seed, condition=None):
if manual_seed > 0:
torch.manual_seed(manual_seed)
total_out = []
if condition is None:
total_out.extend(start_token)
else:
if self.decoder.net.vocab.encoding_scheme == 'remi':
selected_tokens = condition[:1500].tolist()
else:
selected_tokens = condition[:500].tolist()
total_out.extend(selected_tokens)
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.decoder.net.device)
return total_out
def _update_total_out(self, total_out, sampled_token):
if self.decoder.net.vocab.encoding_scheme == 'remi':
total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
else:
sampled_token_list = []
for key in self.decoder.net.vocab.feature_list:
sampled_token_list.append(sampled_token[key])
sampled_token = torch.cat(sampled_token_list, dim=-1) # B(1) x num_features
total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
return total_out, sampled_token
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1):
embedding = self.decoder.net.input_embedder(input_seq) + self.decoder.net.pos_enc(input_seq)
embedding = self.decoder.net.emb_dropout(embedding)
hidden_vec, intermediates = self.decoder.net.main_decoder(embedding, cache) # B x T x d_model
hidden_vec = self.decoder.net.main_norm(hidden_vec)
hidden_vec = hidden_vec[:, -1:] # B x 1 x d_model
input_dict = {'hidden_vec':hidden_vec, 'input_seq': input_seq, 'target': None}
if self.decoder.net.vocab.encoding_scheme == 'remi':
feature_class_idx = (input_seq.shape[1] - 1) % 4
feature_type = self.decoder.net.vocab.feature_list[feature_class_idx]
logits, sampled_token = self.decoder.net.sub_decoder.run_one_step(input_dict, sampling_method, threshold, temperature, feature_type)
else:
logits, sampled_token = self.decoder.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
return logits, sampled_token, intermediates
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, sampling_method=None, threshold=None, temperature=1):
total_out = self._prepare_inference(self.decoder.net.start_token, manual_seed, condition)
if condition is not None:
_, _, cache = self._run_one_step(total_out[:, -self.decoder.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature)
else:
cache = LayerIntermediates()
while total_out.shape[1] < max_seq_len:
input_tensor = total_out[:, -self.decoder.net.input_length:]
_, sampled_token, cache = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
for inter in cache.attn_intermediates:
inter.cached_kv = [t[..., -(self.decoder.net.input_length - 1):, :] for t in inter.cached_kv] # B x num_heads x T x d_head
total_out, sampled_token = self._update_total_out(total_out, sampled_token)
if sampled_token.tolist() == self.decoder.net.end_token[0]:
break
return total_out
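The sliding-window cache handling used in the generate loops above (trimming cached keys/values to the last input_length - 1 positions after every step) can be shown in isolation; the shapes below are made up and do not come from the model.
import torch

input_length = 8
# cached keys/values grow as B x heads x T x d_head; after each step they are cut back so that,
# together with the single new token fed at the next step, the context never exceeds input_length
cached_kv = [torch.randn(1, 4, 11, 64), torch.randn(1, 4, 11, 64)]
cached_kv = [t[..., -(input_length - 1):, :] for t in cached_kv]
print([tuple(t.shape) for t in cached_kv])   # [(1, 4, 7, 64), (1, 4, 7, 64)]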

168
Amadeus/sampling_utils.py Normal file
View File

@@ -0,0 +1,168 @@
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Create an empty tensor to hold the new logits
new_logits = logits.clone()
# Use the sorted indices to place the '-inf' in the original places
indices_to_remove = sorted_indices[sorted_indices_to_remove]
new_logits[..., indices_to_remove] = float('-inf')
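# NOTE: this advanced indexing unions the removed indices across all batch/time positions, so it is exact only when the logits describe a single distribution; see top_p_sampling_fast below for a fully batched variant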
return new_logits
# reference: https://github.com/cimeister/typical-sampling
def typical_sampling(logits, thres=0.99):
# calculate entropy
normalized = torch.nn.functional.log_softmax(logits, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = logits.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < thres).sum(dim=-1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(-1, last_ind.view(-1, 1, 1))
# if self.min_tokens_to_keep > 1:
# # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
# sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
scores = logits.masked_fill(indices_to_remove, float("-inf"))
return scores
def add_gumbel_noise(logits, temperature):
'''
The Gumbel max is a method for sampling categorical distributions.
According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
Thus, we use float64.
'''
if temperature == 0:
return logits
logits = logits.to(torch.float64)
noise = torch.rand_like(logits, dtype=torch.float64)
gumbel_noise = (- torch.log(noise)) ** temperature
return logits.exp() / gumbel_noise
# reference: https://github.com/john-hewitt/truncation-sampling
def eta_sampling(logits, epsilon) -> torch.FloatTensor:
probabilities = logits.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probabilities).entropy()
new_epsilon = min(epsilon, torch.sqrt(torch.tensor(epsilon))*torch.exp(-entropy))
indices_to_remove = probabilities < new_epsilon
max_word = torch.argmax(logits, dim=-1)
indices_to_remove[..., max_word.squeeze()] = 0
new_scores = logits.masked_fill(indices_to_remove, float("-inf"))
return new_scores
def sample(logits, sampling_method, threshold, temperature):
"""Sample from the logits with a specific sampling strategy."""
if sampling_method == "top_p":
probs = F.softmax(top_p_sampling(logits, thres=threshold) / temperature, dim=-1)
elif sampling_method == "typical":
probs = F.softmax(typical_sampling(logits, thres=threshold) / temperature, dim=-1)
elif sampling_method == "eta":
probs = F.softmax(eta_sampling(logits, epsilon=threshold) / temperature, dim=-1)
else:
probs = F.softmax(logits / temperature, dim=-1)
return torch.multinomial(probs[-1,-1,:], 1)
def sample_with_prob(logits, sampling_method, threshold, temperature):
"""Sample from the logits with a specific sampling strategy and return the token and its probability."""
# temporarily apply the sampling method to logits
logits = logits / temperature
# logits = add_gumbel_noise(logits, temperature)
if sampling_method == "top_p":
modified_logits = top_p_sampling(logits, thres=threshold)
elif sampling_method == "typical":
modified_logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
modified_logits = eta_sampling(logits, epsilon=threshold)
else:
modified_logits = logits # otherwise use the raw logits directly
# print(modified_logits.shape)
# temperature was already applied above; just compute the probabilities
# probs = F.softmax(modified_logits / temperature, dim=-1)
probs = F.softmax(modified_logits, dim=-1)
# take the distribution at the last position
# probs_last = probs[-1, -1, :]
# print(probs.shape)
probs_last = probs[-1, -1, :]
# sample one token
sampled_token = torch.multinomial(probs_last, num_samples=1)
# look up the probability of the sampled token
prob_value = probs_last[sampled_token]
return sampled_token, prob_value.squeeze()
def top_p_sampling_fast(logits, thres=0.9):
"""
logits: Tensor of shape [B, L, V]
Returns: logits with low-prob tokens masked as -inf, shape [B, L, V]
"""
# Step 1: sort logits and get indices
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) # [B, L, V]
# Step 2: compute cumulative probs
probs = F.softmax(sorted_logits, dim=-1) # [B, L, V]
cum_probs = torch.cumsum(probs, dim=-1) # [B, L, V]
# Step 3: mask tokens beyond cumulative threshold
sorted_mask = cum_probs > thres
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
sorted_mask[..., 0] = False # always keep at least one token
# Step 4: scatter back to original order
# Create mask of same shape as logits, default False
mask = torch.zeros_like(logits, dtype=torch.bool) # [B, L, V]
mask = mask.scatter(-1, sorted_indices, sorted_mask)
# Step 5: mask logits
logits = logits.masked_fill(mask, float('-inf')) # final masked logits
return logits
def sample_with_prob_fast(logits, sampling_method="top_p", threshold=0.9, temperature=1.0, mask_indices=None):
"""
logits: [B*T, num_sub_tokens, vocab_size]
mask_indices: mask indicating which tokens to sample, shape = [B*T, num_sub_tokens]
"""
if temperature != 1.0:
logits = logits / temperature
if sampling_method == "top_p":
logits = top_p_sampling_fast(logits, thres=threshold) # should support batch
elif sampling_method == "typical":
logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
logits = eta_sampling(logits, epsilon=threshold)
# else: keep logits as-is
probs = torch.softmax(logits, dim=-1) # [B*T, num_sub_tokens, vocab_size]
B, L, V = probs.shape
probs_flat = probs.view(-1, V) # [(B*T * num_sub_tokens), V]
# torch.multinomial cannot sample from a 3-D tensor at once, so flatten, sample, then reshape
sampled = torch.multinomial(probs_flat, num_samples=1) # [(B*T * num_sub_tokens), 1]
sampled = sampled.view(B, L) # [B*T, num_sub_tokens]
sampled_probs = torch.gather(probs, 2, sampled.unsqueeze(-1)).squeeze(-1) # [B*T, num_sub_tokens]
return sampled, sampled_probs
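A minimal shape check for sample_with_prob_fast, not part of the file above. The sizes (6 compound positions, 4 sub-tokens, vocabulary of 300) and the import path are illustrative assumptions.
import torch
from Amadeus.sampling_utils import sample_with_prob_fast  # assumed package path

logits = torch.randn(6, 4, 300)        # [B*T, num_sub_tokens, vocab_size]
tokens, probs = sample_with_prob_fast(logits, sampling_method="top_p",
                                      threshold=0.9, temperature=1.0)
print(tokens.shape, probs.shape)       # torch.Size([6, 4]) torch.Size([6, 4])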

View File

@@ -0,0 +1,228 @@
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout):
super().__init__()
self.out_size = out_size
self.layer = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Dropout(dropout),
nn.ReLU(),
nn.Linear(hidden_size, out_size)
)
def forward(self, x):
return self.layer(x)
class extendedMLP(nn.Module):
def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
super().__init__()
self.input_size = in_size
self.layers = nn.ModuleList()
if num_layers == 1:
# Only one layer
self.layers.append(nn.Linear(in_size, out_size))
return
elif num_layers > 1:
# First layer
self.layers.append(nn.Linear(in_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Intermediate layers
if num_layers > 2:
for _ in range(num_layers - 2): # -2 because we're manually adding the first and last layers
self.layers.append(nn.Linear(hidden_size, hidden_size))
self.layers.append(nn.Dropout(dropout))
self.layers.append(nn.ReLU())
# Last layer
self.layers.append(nn.Linear(hidden_size, out_size))
else:
raise ValueError("num_layers should be a positive integer")
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class multiMLP(nn.Module):
def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
super().__init__()
self.out_size = out_size
self.pred_order = list(pred_order)
self.layer = nn.ModuleList([MLP(in_size, out_size, hidden_size, dropout) for _ in pred_order])
def forward(self, x, choice):
'''
x: B x T x d_model
choice: token type from self.pred_order (str or list of str)
'''
if isinstance(choice, str):
idx = self.pred_order.index(choice)
return self.layer[idx](x)
elif len(choice) > 1 and not isinstance(choice, str):
raise ValueError("multiMLP doesn't support parallel prediction")
class ResidualLayerNormModule(nn.Module):
def __init__(self, submodule: nn.Module):
super().__init__()
self.submodule = submodule
if submodule.__class__.__name__ == 'MultiheadAttention':
self.layer_norm = nn.LayerNorm(self.submodule.embed_dim)
else:
self.layer_norm = nn.LayerNorm(self.submodule.input_size)
def forward_attention(self, q, k, v, attn_mask, type):
attn_output, _ = self.submodule(q, k, v, attn_mask=attn_mask, need_weights=False, average_attn_weights=False)
return self.layer_norm(attn_output + q)
def forward_mlp(self, x):
return self.layer_norm(self.submodule(x) + x)
class MultiProj_hidden2logit(nn.Module):
def __init__(self, dim, vocab_sizes):
super().__init__()
self.layers = nn.ModuleDict({
f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
})
def forward(self, hidden_vec, feature):
logit = self.layers[f"layer_{feature}"](hidden_vec)
return logit
class MultiProj_catvec2hidden(nn.Module):
def __init__(self, config, par_pred_keys, seq_pred_keys):
super().__init__()
'''
This class is used in SQstyleEachEmbStrategy
par_pred_keys: list of independent features(These tokens are predicted in parallel)
seq_pred_keys: list of sequential features(These tokens are predicted sequentially)
'''
net_param = config.nn_params
self.d_model = net_param.model.d_model
independent_emb_size = 0
for key in par_pred_keys:
independent_emb_size += net_param.emb[key]
self.layers = nn.ModuleDict({
'layer_independent': nn.Linear(self.d_model + independent_emb_size, self.d_model),
**{f"layer_{key}": nn.Linear(self.d_model + net_param.emb[key], self.d_model) for key in seq_pred_keys}
})
self.par_pred_keys = par_pred_keys
self.seq_pred_keys = seq_pred_keys
self.dropout = nn.Dropout(0.1)
self.relu = nn.ReLU()
def forward(self, x, choice):
'''
x: B x T x (d_model + emb_size)
choice: key type (str or list of str)
'''
if isinstance(choice, str): # single key
assert choice in self.seq_pred_keys
output = self.layers[f"layer_{choice}"](x)
return self.relu(self.dropout(output))
elif len(choice) > 1 and not isinstance(choice, str): # multiple keys, parallel
assert choice == self.par_pred_keys # the order of choice should be the same as the order of self.par_pred_keys
output = self.layers['layer_independent'](x)
return self.relu(self.dropout(output))
def mask_tensor(tensor, mask_rate=0.15):
# Get the size of the tensor
batch_size, seq_len, dim = tensor.size()
# Calculate the total number of elements and the number to mask
total_elements = batch_size * seq_len
num_to_mask = int(total_elements * mask_rate)
# Create a 1D binary mask where 1 indicates that element will be masked.
# Start by creating a tensor of zeros with length equal to the total number of elements.
mask = torch.zeros(total_elements).to(tensor.device)
# Set `num_to_mask` random indices to 1 (masking)
indices_to_mask = torch.randperm(total_elements)[:num_to_mask]
mask[indices_to_mask] = 1
# Reshape the mask to match the original tensor's shape
mask = mask.reshape(batch_size, seq_len)
mask = mask.unsqueeze(2) # B x T x 1
masked_tensor = tensor * (mask == 0).float() # B x T x d_model
return masked_tensor
def generate_causality_mask_on_window(size, window_size):
mask = torch.zeros((size, size))
for i in range(size):
mask[i, i+window_size:] = 1
return mask.bool()
# generate boolean mask, if the value is 1 or true, it means the value is masked
# considers BOS token and mask margin
def generate_CA_mask(tgt_len, memory_len, mask_margin=0):
mask = torch.triu(torch.ones((tgt_len, memory_len)), diagonal=mask_margin+1)
return mask.bool()
# generate boolean mask, if the value is 1 or true, it means the value is masked
def generate_SA_mask(tgt_len):
mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1)
return mask.bool()
def generate_none_causality_mask(tgt_len, memory_len):
mask = torch.zeros((tgt_len, memory_len))
return mask.bool()
class DecoderLayer(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
class TransformerLayer(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.self_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
'''
# self attention
attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')
input_dict['input_seq'] = attn_output
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
return output_dict
class FeatureEnricher(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
self.dropout = nn.Dropout(dropout)
def forward(self, input_dict):
'''
input_dict = {'input_seq': input_seq, 'memory': memory}
'''
# cross attention
attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], None, type='feature_enrichment')
attn_output = self.residual_FF.forward_mlp(attn_output)
attn_output = self.dropout(attn_output)
output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
return output_dict
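# Because DecoderLayer, TransformerLayer, and FeatureEnricher all consume and return a single
# input_dict, they can be chained directly; a minimal sketch with toy sizes and a hypothetical
# two-layer stack (dimensions are illustrative, not the trained configuration).
example_stack = nn.Sequential(*[TransformerLayer(dim=512, num_heads=8, dropout=0.1) for _ in range(2)])
example_x = torch.randn(2, 16, 512)       # decoder input: B x T x d_model
example_memory = torch.randn(2, 16, 512)  # memory from the main decoder
example_mask = generate_SA_mask(16)       # causal mask, reused for self- and cross-attention here
example_out = example_stack({'input_seq': example_x, 'memory': example_memory, 'memory_mask': example_mask})
print(example_out['input_seq'].shape)     # torch.Size([2, 16, 512])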

1280
Amadeus/sub_decoder_zoo.py Normal file

File diff suppressed because it is too large

View File

@ -0,0 +1,46 @@
from sf2utils.sf2parse import Sf2File
def print_sorted_presets(sf2_path):
presets_info = []
with open(sf2_path, 'rb') as f:
sf2 = Sf2File(f)
for preset in sf2.presets:
try:
# Try reading the fields directly
name = getattr(preset, 'name', '???').strip('\x00')
bank = getattr(preset, 'bank', None)
program = getattr(preset, 'preset', None)
# If they are not available, try to pull them from a sub-attribute
if bank is None or program is None:
for attr in dir(preset):
attr_value = getattr(preset, attr)
if hasattr(attr_value, 'bank') and hasattr(attr_value, 'preset'):
bank = attr_value.bank
program = attr_value.preset
name = getattr(attr_value, 'name', name).strip('\x00')
break
# Collect valid results
if bank is not None and program is not None:
presets_info.append((program, bank, name))
except Exception as e:
print(f"Error reading preset: {e}")
# Sort by program in ascending order (to sort by bank, then program, use sorted(..., key=lambda x: (x[1], x[0])))
presets_info.sort(key=lambda x: x[0])
# Print the results
print(f"{'Program':<8} {'Bank':<6} {'Preset Name'}")
print("-" * 40)
for program, bank, name in presets_info:
print(f"{program:<8} {bank:<6} {name}")
# DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# Replace with the path to your .sf2 file
sf2_path = "/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2"
print_sorted_presets(sf2_path)

View File

@ -0,0 +1,94 @@
import random
from typing import Union
import torch
class Augmentor:
def __init__(
self,
vocab,
aug_type:Union[str, None],
input_length:int
):
self.vocab = vocab
self.aug_type = aug_type
self.input_length = input_length
self.feature_list = vocab.feature_list
self.num_features = len(self.feature_list)
self.encoding_scheme = vocab.encoding_scheme
self.pitch_idx = self.feature_list.index('pitch')
if 'chord' in self.feature_list:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
# the pitch vocab has ignore token in 0 index
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0
pitch_segment = segment[pitch_mask[:,self.pitch_idx], self.pitch_idx]
# check if tensor is empty
if pitch_segment.numel() == 0:
shift = 0
else:
lowest_pitch = max(12, torch.min(pitch_segment))
highest_pitch = min(119, torch.max(pitch_segment))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) > 11)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) < 120)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
else: # remi
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
segment_pitch = segment * segment_pitch_mask
segment_pitch = segment_pitch[segment_pitch != 0]
# check if tensor is empty
if segment_pitch.numel() == 0:
shift = 0
else:
lower_bound = torch.argwhere(mask_for_pitch == 1)[0].item()
upper_bound = torch.argwhere(mask_for_pitch == 1)[-1].item()
lowest_pitch = max(lower_bound, torch.min(segment_pitch))
highest_pitch = min(upper_bound, torch.max(segment_pitch))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) >= lower_bound)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) <= upper_bound)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
return shift
# TODO: arrange hard coded part
def __call__(self, segment):
'''
input_tensor is segments of x, y
for transformer_xl, the shape of x, y is [max_num_segments, input_length, num_features]
so we need to change the shape of x, y to [max_num_segments*input_length, num_features]
'''
if self.aug_type == 'random':
shift = self._get_shift(segment)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
# pitch augmentation
segment_pitch_mask = segment != 0
new_segment = segment.clone()
new_segment[segment_pitch_mask[:,self.pitch_idx], self.pitch_idx] += shift
if 'chord' in self.feature_list:
# chord augmentation
segment_chord_mask = (segment[:,self.chord_idx] != 0) & (segment[:,self.chord_idx] != 1)
new_segment[segment_chord_mask, self.chord_idx] = (((new_segment[segment_chord_mask, self.chord_idx]-2) % 12) + shift ) % 12 + ((new_segment[segment_chord_mask, self.chord_idx]-2) // 12) * 12 + 2
segment = new_segment
else: # remi
# choose a random integer between -5 and 6
# the augmented results from shifts of -6 and 6 are the same, so we use the range -5 to 6
# pitch augmentation
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
new_segment = segment.clone()
new_segment_valid = (new_segment + shift) * segment_pitch_mask
new_segment = new_segment * (1 - segment_pitch_mask) + new_segment_valid
if 'chord' in self.feature_list:
# chord augmentation
mask_for_chord = self.vocab.total_mask['chord'].clone().to(segment.device)
chord_n_n_idx = torch.argwhere(mask_for_chord == 1)[-1].item()
mask_for_chord[chord_n_n_idx] = 0
start_idx_chord = self.vocab.remi_vocab_boundaries_by_key['chord'][0]
segment_chord_mask = mask_for_chord[segment]
new_segment_valid = ((((new_segment - start_idx_chord) % 12 + shift) % 12) + ((new_segment - start_idx_chord) // 12) * 12 + start_idx_chord) * segment_chord_mask
new_segment = new_segment * (1 - segment_chord_mask) + new_segment_valid
segment = new_segment
return segment
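# Minimal sketch of applying the Augmentor to a compound-token segment. The vocab object here
# is a stand-in exposing only the attributes Augmentor reads; real runs pass the dataset's
# vocabulary object, and the token values below are illustrative.
from types import SimpleNamespace
example_vocab = SimpleNamespace(feature_list=['type', 'beat', 'pitch', 'duration'], encoding_scheme='nb')
example_augmentor = Augmentor(vocab=example_vocab, aug_type='random', input_length=3072)
example_segment = torch.zeros(8, 4, dtype=torch.long)
example_segment[:, 2] = torch.tensor([0, 60, 62, 64, 65, 67, 69, 0])  # pitch column, 0 = ignore
augmented_segment = example_augmentor(example_segment)
print(augmented_segment[:, 2])  # all non-zero pitches shifted by one common random offset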

View File

@ -0,0 +1,207 @@
import random
from collections import defaultdict
import torch
import numpy as np
import random
def reverse_shift_and_pad(tune_in_idx, slice_boundary=4):
new_lst = [curr_elems[:slice_boundary] + next_elems[slice_boundary:] for curr_elems, next_elems in zip(tune_in_idx, tune_in_idx[1:])]
return new_lst
def reverse_shift_and_pad_for_tensor(tensor, first_pred_feature):
'''
tensor: [batch_size x seq_len x feature_size]
'''
if first_pred_feature == 'type':
return tensor
if tensor.shape[-1] == 8:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
elif tensor.shape[-1] == 7:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'pitch':4, 'duration':5, 'velocity':6}
elif tensor.shape[-1] == 5:
slice_boundary_dict = {'type':0, 'beat':1, 'instrument':2, 'pitch':3, 'duration':4}
elif tensor.shape[-1] == 4:
slice_boundary_dict = {'type':0, 'beat':1, 'pitch':2, 'duration':3}
slice_boundary = slice_boundary_dict[first_pred_feature]
new_tensor = torch.zeros_like(tensor)
new_tensor[..., :, :slice_boundary] = tensor[..., :, :slice_boundary]
new_tensor[..., :-1, slice_boundary:] = tensor[..., 1:, slice_boundary:]
return new_tensor
def shift_and_pad(tune_in_idx, first_pred_feature):
if first_pred_feature == 'type':
return tune_in_idx
if len(tune_in_idx[0]) == 8:
slice_boundary_dict = {'type':0, 'beat':-7, 'chord':-6, 'tempo':-5, 'instrument':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 7:
slice_boundary_dict = {'type':0, 'beat':-6, 'chord':-5, 'tempo':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 5:
slice_boundary_dict = {'type':0, 'beat':-4, 'instrument':-3, 'pitch':-2, 'duration':-1}
elif len(tune_in_idx[0]) == 4:
slice_boundary_dict = {'type':0, 'beat':-3, 'pitch':-2, 'duration':-1}
slice_boundary = slice_boundary_dict[first_pred_feature]
# Add an empty list padded with zeros at the beginning, and sos and eos tokens are not shifted
padded_tune_in_idx = torch.cat([torch.zeros(1, len(tune_in_idx[0]), dtype=torch.long), tune_in_idx], dim=0)
new_tensor = torch.zeros_like(padded_tune_in_idx)
new_tensor[:, slice_boundary:] = padded_tune_in_idx[:, slice_boundary:]
new_tensor[:-1, :slice_boundary] = padded_tune_in_idx[1:, :slice_boundary]
return new_tensor
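# Worked toy example of shift_and_pad on a 4-feature (type, beat, pitch, duration) sequence
# with first_pred_feature='pitch': the pitch/duration columns end up one row behind their
# type/beat partners, with zero padding at the boundaries.
example_tune = torch.tensor([[1, 1, 60, 4],
[1, 2, 64, 4],
[1, 3, 67, 4]])
example_shifted = shift_and_pad(example_tune, first_pred_feature='pitch')
print(example_shifted)
# tensor([[ 1,  1,  0,  0],
#         [ 1,  2, 60,  4],
#         [ 1,  3, 64,  4],
#         [ 0,  0, 67,  4]])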
class VanillaTransformer_compiler():
def __init__(
self,
data_list,
augmentor,
eos_token,
input_length,
first_pred_feature,
encoding_scheme
):
self.data_list = data_list
self.augmentor = augmentor
self.eos_token = eos_token
self.input_length = input_length
self.first_pred_feature = first_pred_feature
self.encoding_scheme = encoding_scheme
def make_segments(self, data_type):
segments = []
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segments.append([segment, mask])
segment2tune_name.append(tune_name)
assert len(segment) == self.input_length+1
# Randomly choose the start point for the next segment, which is in the range of half of the current segment to the end of the current segment
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# if text-controlled, we only use the first segment
# add the last segment
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
segments.append([segment, mask])
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
assert len(segment) == self.input_length+1
return segments, tune_name2segment, segment2tune_name
def make_segments_iters(self, data_type):
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
# shuffle the data_list
if data_type == 'train':
random.shuffle(self.data_list)
print("length of data_list:", len(self.data_list))
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# break
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1
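# Minimal sketch of driving the compiler on toy data; the token rows and the EOS row are
# placeholders for whatever the real 4-feature nb encoding produces.
example_compiler = VanillaTransformer_compiler(
data_list=[([[1, 1, 60, 4], [1, 2, 64, 4], [1, 3, 67, 4]], 'toy_tune')],
augmentor=None,            # augmentation is skipped in this sketch
eos_token=[[2, 0, 0, 0]],  # assumed EOS row for a 4-feature encoding
input_length=8,
first_pred_feature='pitch',
encoding_scheme='nb',
)
example_segments, _, example_seg2tune = example_compiler.make_segments(data_type='train')
example_segment, example_mask = example_segments[0]
print(example_segment.shape, int(example_mask.sum()))  # torch.Size([9, 4]) 4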

File diff suppressed because it is too large

View File

@ -0,0 +1,404 @@
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
class MuteWarn:
def __enter__(self):
self._init_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._init_stdout
def save_score_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
with MuteWarn():
convert = converter.parse(midi_fn)
convert.write('musicxml.png', fp=file_name)
def save_pianoroll_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
midi_obj_muspy = muspy.read_midi(midi_fn)
midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
plt.gcf().set_size_inches(20, 10)
plt.savefig(file_name)
plt.close()
def save_wav_from_midi(midi_fn, file_name, qpm=80):
assert isinstance(midi_fn, str)
with MuteWarn():
music = muspy.read_midi(midi_fn)
music.tempos[0].qpm = qpm
music.write_audio(file_name, rate=44100, gain=3)
def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
assert isinstance(midi_fn, str)
fs = FluidSynth(gain=gain)
fs.midi_to_audio(midi_fn, file_name)
class MidiDecoder4REMI:
def __init__(
self,
vocab,
in_beat_resolution,
dataset_name
):
self.vocab = vocab
self.in_beat_resolution = in_beat_resolution
self.dataset_name = dataset_name
if dataset_name == 'SymphonyMIDI':
self.gain = 0.7
elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
self.gain = 1.1
elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
self.gain = 2.5
else:
self.gain = 1.5
def __call__(self, generated_output, output_path=None):
'''
generated_output: a 1-D or 1 x T tensor of REMI token indices
'''
idx2event = self.vocab.idx2event
if generated_output.dim() == 2:
generated_output = generated_output.squeeze(0)
events = [idx2event[token.item()] for token in generated_output]
midi_obj = miditoolkit.midi.parser.MidiFile()
if 'tempo' not in idx2event.keys():
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
instr_notes_dict = defaultdict(list)
for i in range(len(events)-2):
cur_event = events[i]
# print(cur_event)
name = cur_event.split('_')[0]
attr = cur_event.split('_')
if name == 'Bar':
bar_pos += cur_bar_resol
if 'time' in cur_event:
cur_num, cur_denom = attr[-1].split('/')
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif name == 'Beat':
beat_pos = int(attr[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks
elif name == 'Chord':
chord_text = attr[1] + '_' + attr[2]
midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
elif name == 'Tempo':
midi_obj.tempo_changes.append(
TempoChange(tempo=int(attr[1]), time=cur_pos))
elif name == 'Instrument':
cur_instr = int(attr[1])
else:
if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1] and \
'Note_Velocity' in events[i+2]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = int(events[i+2].split('_')[-1])
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = 90
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
# make subdir
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if not contain 'prompt' in output_path, save prompt music
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
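# Hedged usage sketch for MidiDecoder4REMI with a stand-in vocabulary; the event names and
# indices below are illustrative, and output_path=None skips MIDI/audio file rendering.
import torch
from types import SimpleNamespace
example_idx2event = {0: 'Bar_None', 1: 'Beat_0', 2: 'Note_Pitch_60', 3: 'Note_Duration_4', 4: 'EOS_None', 5: 'Bar_time_signature_4/4'}
example_vocab = SimpleNamespace(idx2event=example_idx2event, feature_list=['type', 'beat', 'pitch', 'duration'])
example_decoder = MidiDecoder4REMI(vocab=example_vocab, in_beat_resolution=4, dataset_name='SOD')
example_tokens = torch.tensor([5, 1, 2, 3, 4, 4])  # bar, beat, one note, then EOS padding
example_midi = example_decoder(example_tokens, output_path=None)
print(example_midi.instruments[0].notes)  # one note: pitch 60, start 0, length 4 * 120 ticks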
class MidiDecoder4CP(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, tempo, chord, beat, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if 'time_signature' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['type']] == 'Metrical':
if 'time_signature' in token_with_7infos[feature2idx['beat']]:
cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['beat']] == 'Bar':
bar_pos += cur_bar_resol
elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# chord and tempo
midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
elif token_with_7infos[feature2idx['type']] == 'Note':
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 80
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
if not os.path.exists(output_music_dir):
os.makedirs(output_music_dir)
midi_obj.dump(output_path)
save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
wav_path = os.path.join(output_music_dir, os.path.basename(output_path).replace('.mid', '.wav'))
save_wav_from_midi_fluidsynth(output_path, wav_path, gain=self.gain)
return midi_obj
class MidiDecoder4NB(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
bar_pos += cur_bar_resol
elif 'NNN' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
if 'Beat' in str(token_with_7infos[feature2idx['beat']]): # when beat is not CONTI beat is updated
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# update chord and tempo
midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
# note
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1] # duration between 1~192
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 90
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if not contain 'prompt' in output_path, save prompt music
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj

View File

@ -0,0 +1,208 @@
import torch
import numpy as np
from collections import Counter
# TODO: refactor hard coded values
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
type_beat_list = []
for token in generated_output:
type_beat_list.append((token[type_idx].item(), token[beat_idx].item())) # type, beat
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for type_beat in type_beat_list:
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
if type_beat[1] <= last_note:
num_backwards_errors += 1
beat_backwards_error_list.append([last_note, type_beat])
else:
last_note = type_beat[1] # update last note
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
last_note = 1
unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
backwards_error_rate = num_backwards_errors / len(type_beat_list)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
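# Toy example for the nb syntax checker: rows are (type, beat, pitch, duration) token indices,
# with type 5 = new bar and type 4 = same bar / new beat as in the checks above; the values
# themselves are synthetic. The third event moves its beat backwards, giving one error.
example_feature_list = ['type', 'beat', 'pitch', 'duration']
example_tokens = torch.tensor([[[5, 2, 60, 4],
[4, 4, 62, 4],
[4, 3, 64, 4]]])
print(check_syntax_errors_in_inference_for_nb(example_tokens, example_feature_list))
# {'beat_type_unmatched_error': 0.0, 'beat_backwards_error': 0.3333333333333333}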
def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
pitch_idx = feature_list.index('pitch')
duration_idx = feature_list.index('duration')
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for token in generated_output:
if token[type_idx].item() == 2: # Metrical
if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
if token[beat_idx].item() == 1: # new bar
last_note = 1 # last note will be updated in the next token
elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
num_backwards_errors += 1
last_note = token[beat_idx].item() # update last note
beat_backwards_error_list.append([last_note, token])
else:
last_note = token[beat_idx].item() # update last note
if token[type_idx].item() == 3: # Note
if token[beat_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
unmatched_error_rate = num_unmatched_errors / len(generated_output)
backwards_error_rate = num_backwards_errors / len(generated_output)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
generated_output = generated_output.squeeze(0)
# to check duration errors
beat_mask = vocab.total_mask['beat'].to(generated_output.device)
beat_mask_for_target = beat_mask[generated_output]
beat_target = generated_output * beat_mask_for_target
bar_mask = vocab.total_mask['type'].to(generated_output.device)
bar_mask_for_target = bar_mask[generated_output]
bar_target = (generated_output+1) * bar_mask_for_target # as the bar token is index 0 in the remi vocab, we add 1 to bar tokens
target = beat_target + bar_target
target = target[target!=0]
# collect beats in between bars(idx=1)
num_backwards_errors = 0
collected_beats = []
total_beats = 0
for token in target:
if token == 1 or 3 <= token <= 26: # Bar_None, or Bar_time_signature
collected_beats_tensor = torch.tensor(collected_beats)
diff = torch.diff(collected_beats_tensor)
num_error_beats = torch.where(diff<=0)[0].shape[0]
num_backwards_errors += num_error_beats
collected_beats = []
else:
collected_beats.append(token.item())
total_beats += 1
if total_beats != 0:
backwards_error_rate = num_backwards_errors / total_beats
else:
backwards_error_rate = 0
# print(f"error rate in beat backwards: {backwards_error_rate}")
return {'beat_backwards_error': backwards_error_rate}
def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_pred_beat_idx = pred_beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_pred_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() >= 1: # beat
last_note = input_beat_idx.item()
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_beat_idx = beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() == 1: # bar
last_note = 1
elif input_beat_idx.item() >= 2: # new beat
last_note = input_beat_idx.item()
# check errors
if type_beat[0] == 2: # Metrical
if type_beat[1] == 0: # ignore
num_unmatched_errors += 1
elif type_beat[1] >= 2: # new beat
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] == 3: # Note
if type_beat[1] != 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
orign_beat_prob = prob_dict['beat'] # b x t x vocab_size
arranged_beat_prob = arranged_prob_dict['beat'] # b x t x vocab_size
# calculate similarity between original beat prob and arranged beat prob
origin_beat_token = torch.argmax(orign_beat_prob, dim=-1) * mask # b x t
arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask # b x t
num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask==0)
num_beat = torch.sum(mask==1)
beat_sim = (num_same_beat / num_beat).item() # scalar
# apply mask, shape of mask: b x t
orign_beat_prob = orign_beat_prob * mask.unsqueeze(-1) # b x t x vocab_size
arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)
# calculate cosine similarity between original beat prob and arranged beat prob
orign_beat_prob = orign_beat_prob.flatten(0,1) # (b*t) x vocab_size
arranged_beat_prob = arranged_beat_prob.flatten(0,1) # (b*t) x vocab_size
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
beat_cos_sim = cos(orign_beat_prob, arranged_beat_prob) # (b*t)
# exclude invalid tokens, zero padding tokens
beat_cos_sim = beat_cos_sim[mask.flatten().bool()] # num_valid_tokens
beat_cos_sim = torch.mean(beat_cos_sim).item() # scalar
return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}
def get_gini_coefficient(generated_output):
if len(generated_output.shape) == 3:
generated_output = generated_output.squeeze(0).tolist()
gen_list = [tuple(x) for x in generated_output]
else:
gen_list = generated_output.squeeze(0).tolist()
counts = Counter(gen_list).values()
sorted_counts = sorted(counts)
n = len(sorted_counts)
cumulative_counts = np.cumsum(sorted_counts)
cumulative_proportion = cumulative_counts / cumulative_counts[-1]
lorenz_area = sum(cumulative_proportion[:-1]) / n # Exclude the last element
equality_area = 0.5 # The area under line of perfect equality
gini = (equality_area - lorenz_area) / equality_area
return gini
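# Worked toy example: with n equally frequent tokens the approximation above returns 1/n,
# while a distribution dominated by one token scores much higher.
example_diverse = torch.tensor([[0, 1, 2, 3]])  # four distinct tokens, equal counts
example_skewed = torch.tensor([[7, 7, 7, 3]])   # one token dominates
print(get_gini_coefficient(example_diverse))    # 0.25
print(get_gini_coefficient(example_skewed))     # 0.75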

View File

@ -0,0 +1,78 @@
import argparse
import os
import subprocess
from pydub import AudioSegment
'''
This file is a modified version of midi2audio.py from https://github.com/bzamecnik/midi2audio
Author: Bohumír Zámečník (@bzamecnik)
License: MIT, see the LICENSE file
'''
__all__ = ['FluidSynth']
DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# DEFAULT_SAMPLE_RATE = 16000
# DEFAULT_GAIN = 0.20
class FluidSynth():
def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
self.sample_rate = sample_rate
self.sound_font = os.path.expanduser(sound_font)
self.gain = gain
def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
if verbose:
stdout = None
else:
stdout = subprocess.DEVNULL
# Convert MIDI to WAV
subprocess.call(
['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
stdout=stdout
)
# Convert WAV to MP3
# mp3_path = audio_file.replace('.wav', '.mp3')
# AudioSegment.from_wav(audio_file).export(mp3_path, format="mp3")
# # Delete the temporary WAV file
# os.remove(audio_file)
def play_midi(self, midi_file):
subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])
def parse_args(allow_synth=True):
parser = argparse.ArgumentParser(description='Convert MIDI to audio via FluidSynth')
parser.add_argument('midi_file', metavar='MIDI', type=str)
if allow_synth:
parser.add_argument('audio_file', metavar='AUDIO', type=str, nargs='?')
parser.add_argument('-s', '--sound-font', type=str,
default=DEFAULT_SOUND_FONT,
help='path to a SF2 sound font (default: %s)' % DEFAULT_SOUND_FONT)
parser.add_argument('-r', '--sample-rate', type=int, nargs='?',
default=DEFAULT_SAMPLE_RATE,
help='sample rate in Hz (default: %s)' % DEFAULT_SAMPLE_RATE)
return parser.parse_args()
def main(allow_synth=True):
args = parse_args(allow_synth)
fs = FluidSynth(args.sound_font, args.sample_rate)
if allow_synth and args.audio_file:
fs.midi_to_audio(args.midi_file, args.audio_file)
else:
fs.play_midi(args.midi_file)
def main_play():
"""
A method for the `midiplay` entry point. It omits the audio file from args.
"""
main(allow_synth=False)
if __name__ == '__main__':
main()
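# Minimal usage sketch (paths are placeholders; actually running this requires the
# `fluidsynth` binary and a valid SoundFont on the machine).
example_fs = FluidSynth(sound_font='/path/to/your_soundfont.sf2', sample_rate=44100, gain=0.5)
example_fs.midi_to_audio('generated_sample.mid', 'generated_sample.wav', verbose=False)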

View File

@ -0,0 +1,65 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
- nn_params: nb8_embSum_diff_t2m_150M_finetunning
# - nn_params: nb8_embSum_diff_t2m_150M_pretraining
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
# - nn_params: nb8_embSum_subFeedForward
# - nn_params: nb8_embSum_diff
# nn_params: nb8_SA_diff
# - nn_params: nb8_embSum_diff_main12head16dim512_ave
# - nn_params: nb8_embSum_NMT_main12_head_16_dim512
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
# captions_path: dataset/symphonyNet/syd-caption.json
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True, False | use diffusion in the sub-decoder
diff_steps: 8 # number of diffusion steps
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5
train_params:
device: cuda
batch_size: 3
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 5000 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.00005
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 # number of warmup steps
max_lr: 0.00015
gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
# Distributed Data Parallel
world_size: 5 # 0 means no distributed training
gradient_accumulation_steps: 4 # 1 means no gradient accumulation
inference_params:
num_uncond_generation: 1 # number of unconditional generation
num_cond_generation: 3 # number of conditional generation
data_params:
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
split_ratio: 0.998 # train-validation-test split ratio
aug_type: pitch # random, null | pitch and chord augmentation type
general:
debug: False
make_log: True # True, False | update the log file in wandb online to your designated project and entity
infer_and_log: True # True, False | inference and log the results
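# The defaults list above appears to be resolved by a Hydra-style entry point; a minimal
# sketch of inspecting this file directly with OmegaConf (the path below is an assumption):
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load('yamls/train_finetune.yaml')   # assumed location of this YAML
#   print(cfg.train_params.batch_size, cfg.train_params.initial_lr, cfg.diff_steps)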

View File

@ -0,0 +1,54 @@
defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff
- nn_params: nb8_embSum_subFeedForward
# - nn_params: nb8_SA_diff
# - nn_params: nb8_embSum_diff_main12head16dim512_ave
# - nn_params: nb8_embSum_NMT_main12_head_16_dim512
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean
use_ddp: True # True, False | distributed data parallel
use_fp16: True # True, False | mixed precision training
use_diff: True # True, False | use diffusion in the sub-decoder
use_dispLoss: True
lambda_weight: 0.5
tau: 0.5
diff_steps: 8 # number of diffusion steps
train_params:
device: cuda
batch_size: 8
grad_clip: 1.0
num_iter: 25000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 10 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 500 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.0001
decay_step_rate: 0.4 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 # number of warmup steps
max_lr: 0.00015
gamma: 0.6 # the decay rate for 'cosineannealingwarmuprestarts'
# Distributed Data Parallel
world_size: 5 # 0 means no distributed training
gradient_accumulation_steps: 1 # 1 means no gradient accumulation
inference_params:
num_uncond_generation: 1 # number of unconditional generation
num_cond_generation: 3 # number of conditional generation
data_params:
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
split_ratio: 0.99 # train-validation-test split ratio
aug_type: null # random, null | pitch and chord augmentation type
general:
debug: False
make_log: True # True, False | update the log file in wandb online to your designated project and entity
infer_and_log: True # True, False | inference and log the results

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 5
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,20 @@
encoding_scheme: cp
num_features: 7
vocab_name: MusicTokenVocabCP
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
partial_sequential_prediction: True
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
input_length: 1024
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 5
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: RNN
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 7
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: SelfAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SelfAttentionEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: CrossAttention
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 6
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,20 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: AverageEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 3
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 2
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 6
feature_enricher_use: True

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerFinetuningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPrefixDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerCrossAttendDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: FeedForward
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,18 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: NestedMusicTransformer
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerDecoder
sub_decoder_name: Parallel
model_dropout: 0.1
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 512
num_layer: 6
num_head: 8
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 5
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 7
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 8
num_head: 8

View File

@ -0,0 +1,12 @@
encoding_scheme: remi
num_features: 8
vocab_name: LangTokenVocab
model_name: NestedMusicTransformer
input_embedder_name: SingleEmbedding
main_decoder_name: XtransformerDecoder
sub_decoder_name: SingleProjection
model_dropout: 0.1
main_decoder:
dim_model: 512
num_layer: 12
num_head: 16

View File

@ -0,0 +1,17 @@
program: train.py
method: grid
metric:
name: valid.total
goal: minimize
parameters:
train_params.batch_size:
values: [8]
train_params.focal_gamma:
values: [0, 1]
nn_params.main_decoder.input_length:
values: [8192]
command:
- python3
- ${program}
- ${args_no_hyphens}
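
A hedged sketch of driving this grid sweep from Python, assuming the wandb sweep/agent API and PyYAML are available; the file name and project name are placeholders, not values from this repository:

import wandb, yaml
sweep_config = yaml.safe_load(open('sweep.yaml'))        # placeholder path for the grid file above
sweep_id = wandb.sweep(sweep_config, project='Amadeus')  # project name is an assumption
wandb.agent(sweep_id)                                    # each run executes: python3 train.py <args_no_hyphens>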

428
Amadeus/train_utils.py Normal file
View File

@ -0,0 +1,428 @@
import math
from numpy import mask_indices
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
from collections import defaultdict
import torch.nn.functional as F
def add_conti_for_single_feature(tensor):
new_target = tensor.clone()
# Assuming tensor shape is [batch, sequence, features]
# Create a shifted version of the tensor
shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
# The first element of each sequence cannot be a duplicate by definition
shifted_tensor[:, 0] = new_target[:, 0] + 1
# Identify where the original and shifted tensors are the same (duplicates)
duplicates = new_target == shifted_tensor
# Replace duplicates with 9999
new_target[duplicates] = 9999
return new_target
def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_params):
feature_prediction_order_dict = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
}
if encoding_scheme == 'remi':
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'cp':
if nn_params.get("partial_sequential_prediction", False):
default_prediction_order = feature_prediction_order_dict[num_features]
prediction_order = [default_prediction_order[0], default_prediction_order[1:]]
else:
prediction_order = feature_prediction_order_dict[num_features]
elif encoding_scheme == 'nb':
assert target_feature in feature_prediction_order_dict[num_features], f"Target feature {target_feature} not in the selected sub-token set. Please check target feature in the config and num_features in nn_params."
default_prediction_order = feature_prediction_order_dict[num_features]
# Reorganize the prediction order based on the target_feature
target_index = default_prediction_order.index(target_feature)
prediction_order = default_prediction_order[target_index:] + default_prediction_order[:target_index]
return prediction_order
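As a quick illustration of the rotation performed for the nb scheme above (the values follow directly from the dictionary in the function; the call itself is hypothetical):

order = adjust_prediction_order(encoding_scheme='nb', num_features=5, target_feature='pitch', nn_params={})
# the 5-feature default ["type", "beat", "instrument", "pitch", "duration"] is rotated to start at the target:
# order == ["pitch", "duration", "type", "beat", "instrument"]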
########################### Loss function ################################
class NLLLoss4REMI():
def __init__(
self,
focal_alpha:float,
focal_gamma:float,
):
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss_seq = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss, loss_seq
def __call__(self, logits, shifted_tgt, mask, vocab):
if vocab is not None:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
loss_by_class_normal = defaultdict(float)
shifted_tgt_with_mask = shifted_tgt * mask # [b, t]
answers_idx = shifted_tgt_with_mask.flatten(0,1) # [b*t]
for feature in vocab.feature_list:
feature_mask = vocab.total_mask[feature].to(answers_idx.device) # [327,]
mask_for_target = feature_mask[answers_idx] # [b*t]
normal_loss_seq_by_class = loss_seq * mask_for_target
if mask_for_target.sum().item() != 0:
loss_by_class_normal[feature+'_normal'] += (normal_loss_seq_by_class.sum().item() / mask_for_target.sum().item())
return loss, loss_by_class_normal
else:
loss, loss_seq = self.get_nll_loss(logits, shifted_tgt, mask)
return loss, None
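A small numeric sketch of the focal weighting used in get_nll_loss above (illustrative values only):

import torch
alpha, gamma = 1.0, 2.0
pt = torch.tensor(0.9)                              # probability assigned to the correct token
focal = -alpha * (1 - pt) ** gamma * torch.log(pt)  # ~0.0011: confident predictions are down-weighted
plain = -torch.log(pt)                              # ~0.1054: the ordinary NLL term
# with gamma == 0 and alpha == 1 the focal term reduces to plain NLL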
class NLLLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
# clamp min value to 1e-7 to avoid log(0)
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss.sum() / mask.sum() # calculating mean loss considering mask
return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token):
probs = logits.softmax(dim=-1)
if ignore_token is not None and conti_token is not None:
target_conti = add_conti_for_single_feature(target) # [batch_size*seq_len]
valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size*seq_len]
elif ignore_token is not None and conti_token is None:
valid_mask = (target != ignore_token)
elif ignore_token is None and conti_token is None:
valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target] # [batch_size*seq_len]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = -self.alpha * (1-pt)**self.gamma * torch.log(pt) # [batch_size*seq_len]
loss = loss * total_mask # [batch_size*seq_len]
loss = loss.sum() / total_mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits_dict, shifted_tgt, mask, valid):
train_loss_list = []
log_loss_dict_normal = {}
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None)
elif key == 'beat':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
elif key == 'chord' or key == 'tempo' or key == 'instrument':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999)
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None)
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
def dispersive_loss(z, tau=0.5, eps=1e-8):
"""Dispersive loss computed on cosine distances."""
B = z.size(0)
# cosine similarity matrix [B, B]
z_norm = torch.nn.functional.normalize(z, p=2, dim=1) # L2-normalize each vector
sim_matrix = torch.matmul(z_norm, z_norm.transpose(0, 1)) # cosine similarity
# convert to cosine distance (1 - similarity) and zero out the diagonal
mask = 1 - torch.eye(B, device=z.device)
cos_dist = (1 - sim_matrix) * mask
# aggregate the dispersion term in the same way as the L2 variant
exp_term = torch.exp(-cos_dist / tau)
mean_exp = exp_term.sum() / (B * (B - 1) + eps)
loss = -torch.log(mean_exp + eps) # note: the leading minus is the opposite sign convention to DispersiveLoss.forward later in this file
return loss
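A minimal shape check for the function above (random features, purely illustrative):

import torch
z = torch.randn(16, 512)                   # e.g. mean-pooled hidden states, [batch, dim]
print(dispersive_loss(z, tau=0.5).item())  # a single scalar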
class DiffusionLoss4CompoundToken():
def __init__(self, feature_list, focal_alpha:float, focal_gamma:float):
self.feature_list = feature_list
self.alpha = focal_alpha
self.gamma = focal_gamma
def get_nll_loss(self, logits, target, mask,mask_indices, p_mask):
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
if mask.ndim == 2:
mask = mask.flatten(0, 1)
# after flattening, logits, target, mask_indices and p_mask all index the same [batch*seq_len] positions (mask_indices is a boolean mask)
token_loss = F.cross_entropy(
logits[mask_indices], # index only the masked positions
target[mask_indices],
reduction='none'
) / p_mask[mask_indices]
loss = (token_loss * mask[mask_indices]).sum() / mask[mask_indices].sum()
return loss
def get_nll_loss_for_logging(self, logits, target, mask, ignore_token, conti_token, mask_indices, p_mask):
if ignore_token is not None and conti_token is not None:
target_conti = add_conti_for_single_feature(target) # [batch_size*seq_len]
valid_mask = (target_conti != ignore_token) & (target_conti != conti_token) # [batch_size*seq_len]
elif ignore_token is not None and conti_token is None:
valid_mask = (target != ignore_token)
elif ignore_token is None and conti_token is None:
valid_mask = torch.ones_like(target).bool()
valid_mask = valid_mask.flatten(0, 1)
if logits.ndim == 3:
logits = logits.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
if mask_indices.ndim == 2:
mask_indices = mask_indices.flatten(0, 1)
if p_mask.ndim == 2:
p_mask = p_mask.flatten(0, 1)
token_loss = F.cross_entropy(
logits[mask_indices], # index only the masked positions
target[mask_indices],
reduction='none'
) / p_mask[mask_indices]
total_mask = mask.flatten(0, 1) & valid_mask # [batch_size*seq_len]
loss = (token_loss * total_mask[mask_indices]).sum() / total_mask[mask_indices].sum()
return loss
def __call__(self, logits_dict, shifted_tgt, mask, mask_indices, p_mask, valid, input_dict=None,lambda_weight=0.5, tau=0.5):
train_loss_list = []
log_loss_dict_normal = {}
mask_indices = mask_indices.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
p_mask = p_mask.reshape(shifted_tgt.shape[0], shifted_tgt.shape[1], -1)
disp_loss = None
if input_dict is not None:
hidden_vec =input_dict['hidden_vec'] #bs,seq_len,dim
feat = hidden_vec.mean(dim=1) #bs,dim
disp_loss = dispersive_loss(feat, tau=tau) # scalar
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
train_loss_list.append(training_loss)
if valid:
if key == 'type':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'beat':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
elif key == 'chord' or key == 'tempo' or key == 'instrument':
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
else:
log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
k_normal = key + '_normal'
log_loss_dict_normal[k_normal] = log_normal_loss
total_loss = sum(train_loss_list) / len(train_loss_list)
if disp_loss is not None:
total_loss = total_loss + lambda_weight * disp_loss
log_loss_dict_normal['dispersion'] = disp_loss.item()
if valid:
return total_loss, log_loss_dict_normal
else:
return total_loss, None
class EncodecFlattenLoss():
def __init__(self, feature_list):
self.feature_list = feature_list
def get_nll_loss(self, logits, target, mask):
probs = logits.softmax(dim=-1)
if probs.ndim == 3:
probs = probs.flatten(0, 1) # [batch_size*seq_len x vocab_size]
if target.ndim == 2:
target = target.flatten(0, 1) # [batch_size*seq_len]
pt = probs[torch.arange(len(target)), target].clamp(1e-7, 1-1e-7) # [batch_size*seq_len]
loss_seq = -torch.log(pt) # [batch_size*seq_len]
loss_seq = loss_seq * mask.flatten(0, 1) # [batch_size*seq_len]
loss = loss_seq.sum() / mask.sum() # calculating mean loss considering mask
return loss
def __call__(self, logits, shifted_tgt, mask):
loss = self.get_nll_loss(logits, shifted_tgt, mask)
return loss
class EncodecMultiClassLoss(EncodecFlattenLoss):
def __init__(self, feature_list):
super().__init__(feature_list)
def __call__(self, logits_dict, shifted_tgt, mask):
train_loss_list = []
for idx, key in enumerate(self.feature_list):
training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask)
train_loss_list.append(training_loss)
total_loss = sum(train_loss_list) / len(train_loss_list)
return total_loss
########################### Learning rate Scheduler ################################
'''
This scheduler is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/#custom-cosineannealingwarmrestarts-1
It is essentially a cosine annealing scheduler with warm restarts, extended with two features: a warm-up start and a decaying maximum learning rate.
'''
class CosineAnnealingWarmUpRestarts(_LRScheduler):
def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1, eta_min=0):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
if T_up < 0 or not isinstance(T_up, int):
raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
self.T_0 = T_0
self.T_mult = T_mult
self.base_eta_max = eta_max
self.eta_max = eta_max
self.T_up = T_up
self.T_i = T_0
self.gamma = gamma
self.cycle = 0
self.T_cur = last_epoch
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.T_cur == -1:
return self.base_lrs
elif self.T_cur < self.T_up:
return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.T_cur = self.T_cur + 1
if self.T_cur >= self.T_i:
self.cycle += 1
self.T_cur = self.T_cur - self.T_i
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
else:
if epoch >= self.T_0:
if self.T_mult == 1:
self.T_cur = epoch % self.T_0
self.cycle = epoch // self.T_0
else:
n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
self.cycle = n
self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
self.T_i = self.T_0 * self.T_mult ** (n)
else:
self.T_i = self.T_0
self.T_cur = epoch
self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
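A hedged usage sketch for the scheduler above (all hyperparameter values are placeholders):

import torch
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)  # base lr is the warm-up starting point
scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=1000, T_mult=1, eta_max=1e-4, T_up=100, gamma=0.5)
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # ramps linearly toward eta_max during the first T_up steps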
class CosineLRScheduler(_LRScheduler):
"""Cosine LR scheduler.
Args:
optimizer (Optimizer): Torch optimizer.
warmup_steps (int): Number of warmup steps.
total_steps (int): Total number of steps.
lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
cycle_length (float): Cycle length.
"""
def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
self.warmup_steps = warmup_steps
assert self.warmup_steps >= 0
self.total_steps = total_steps
assert self.total_steps >= 0
self.lr_min_ratio = lr_min_ratio
self.cycle_length = cycle_length
super().__init__(optimizer)
def _get_sched_lr(self, lr: float, step: int):
if step < self.warmup_steps:
lr_ratio = step / self.warmup_steps
lr = lr_ratio * lr
elif step <= self.total_steps:
s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
(1. + math.cos(math.pi * s / self.cycle_length))
lr = lr_ratio * lr
else:
lr_ratio = self.lr_min_ratio
lr = lr_ratio * lr
return lr
def get_lr(self):
return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
class DispersiveLoss(nn.Module):
def __init__(self, loss_type='infonce_l2', tau=0.5, lambda_weight=0.5):
super().__init__()
self.loss_type = loss_type
self.tau = tau
self.lambda_weight = lambda_weight
def forward(self, features, diffusion_loss):
"""
features: 批次特征矩阵,形状为 [batch_size, feature_dim]
diffusion_loss: 原扩散损失
"""
batch_size = features.size(0)
# 计算距离矩阵
if self.loss_type == 'infonce_l2':
# 计算平方L2距离
dist_matrix = torch.cdist(features, features, p=2) ** 2
# 计算分散损失
exp_dist = torch.exp(-dist_matrix / self.tau)
disp_loss = torch.log(exp_dist.mean())
elif self.loss_type == 'hinge':
# Hinge损失假设阈值epsilon=1.0
dist_matrix = torch.cdist(features, features, p=2)
disp_loss = torch.max(torch.zeros_like(dist_matrix), 1.0 - dist_matrix).mean()
elif self.loss_type == 'covariance':
# 协方差损失
normalized_features = (features - features.mean(dim=0)) / features.std(dim=0)
cov_matrix = torch.matmul(normalized_features.T, normalized_features) / batch_size
# 非对角线元素平方和
mask = ~torch.eye(cov_matrix.size(0), dtype=torch.bool)
disp_loss = (cov_matrix[mask] ** 2).mean()
else:
raise ValueError("Unsupported loss type")
# 总损失 = 扩散损失 + lambda * 分散损失
total_loss = diffusion_loss + self.lambda_weight * disp_loss
return total_loss, disp_loss
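A short, self-contained check of the module above (the scalar diffusion loss is a stand-in value):

import torch
disp = DispersiveLoss(loss_type='infonce_l2', tau=0.5, lambda_weight=0.5)
features = torch.randn(16, 512)                       # pooled hidden vectors, [batch, dim]
total, disp_term = disp(features, torch.tensor(2.0))  # 2.0 stands in for the diffusion loss
print(total.item(), disp_term.item())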

File diff suppressed because it is too large

View File

@ -0,0 +1,949 @@
import torch
import torch.nn as nn
from x_transformers import Decoder, Encoder, PrefixDecoder, CrossAttender
from transformers import T5EncoderModel
from data_representation.vocab_utils import LangTokenVocab
class PosEncoding(nn.Module):
def __init__(self, emb_size, max_t):
super().__init__()
self.emb_size =emb_size
self.max_t = max_t
self.register_buffer('encoding', self._prepare_emb())
def _prepare_emb(self):
dim_axis = 10000**(torch.arange(self.emb_size//2) * 2 / self.emb_size) # per-dimension scale 10000 ** (2i / emb_size)
timesteps = torch.arange(self.max_t)
pos_enc_in = timesteps.unsqueeze(1) / dim_axis.unsqueeze(0)
pos_enc_sin = torch.sin(pos_enc_in) # sin/cos of position over per-dimension frequencies, so each timestep gets a distinct encoding
pos_enc_cos = torch.cos(pos_enc_in)
pos_enc = torch.stack([pos_enc_sin, pos_enc_cos], dim=-1).reshape([self.max_t, self.emb_size])
return pos_enc
def forward(self, x):
return self.encoding[x]
class ResidualLayerNormModule(nn.Module):
def __init__(self, submodule):
super().__init__()
self.submodule = submodule
self.layer_norm = nn.LayerNorm(self.submodule.input_size)
def forward(self, x, mask=None, y=None):
if y is not None:
res_x = self.submodule(x, y, mask)
elif mask is not None:
res_x = self.submodule(x, mask)
else:
res_x = self.submodule(x)
x = x + res_x
return self.layer_norm(x)
class SingleEmbedding(nn.Module):
def __init__(
self,
vocab,
dim_model,
):
'''
Embedding layer for REMI
'''
super().__init__()
vocab_size = vocab.get_vocab_size()
self.embedding = nn.Embedding(vocab_size, dim_model)
def forward(self, x):
return self.embedding(x)
class MultiEmbedding(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int,
):
super().__init__()
'''
Embedding layer for compound tokens
'''
self.vocab_size = vocab.get_vocab_size()
self.feature_list = vocab.feature_list
self.dim_model = dim_model
self.layers = []
self._make_emb_layers()
self._init_params()
self._make_emb_boundaries_by_key()
def _init_params(self):
# apply kaiming init
for layer in self.layers:
if isinstance(layer, nn.Embedding):
nn.init.kaiming_normal_(layer.weight)
def _make_emb_layers(self):
vocab_sizes = [self.vocab_size[key] for key in self.feature_list]
self.embedding_sizes = [self.dim_model for _ in self.feature_list]
for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
if embedding_size != 0:
self.layers.append(nn.Embedding(vocab_size, embedding_size))
self.layers = nn.ModuleList(self.layers)
def _make_emb_boundaries_by_key(self):
'''
This function returns a dict of boundaries for each embedding layer
'''
self.emb_boundary_by_key = {}
start_idx = 0
for key, emb_size in zip(self.feature_list, self.embedding_sizes):
if emb_size != 0:
self.emb_boundary_by_key[key] = (start_idx, start_idx + emb_size)
start_idx += emb_size
def forward(self, x):
emb = torch.cat([module(x[..., i]) for i, module in enumerate(self.layers)], dim=-1)
return emb
def __len__(self):
return len(self.layers)
def get_emb_by_key(self, key, token):
layer_idx = self.feature_list.index(key)
return self.layers[layer_idx](token)
class SummationEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
def forward(self, seq):
emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
output = torch.sum(stacked_emb, dim=2) # B x T x emb_size
return output
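A shape sketch for the compound-token embedders above; the stand-in vocab below only mimics the two attributes MultiEmbedding reads (feature_list and get_vocab_size), whereas real runs use LangTokenVocab, and the vocab sizes are placeholders:

import torch
from types import SimpleNamespace
stub_vocab = SimpleNamespace(
    feature_list=["type", "beat", "pitch", "duration"],
    get_vocab_size=lambda: {"type": 4, "beat": 64, "pitch": 128, "duration": 64},
)
embedder = SummationEmbedder(stub_vocab, dim_model=512)
tokens = torch.randint(0, 4, (2, 16, 4))   # B x T x num_features, ids kept below the smallest vocab
print(embedder(tokens).shape)              # torch.Size([2, 16, 512])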
class AverageEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
def forward(self, seq):
emb_list = [module(seq[..., i]) for i, module in enumerate(self.layers)]
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
output = torch.mean(stacked_emb, dim=2) # B x T x emb_size
return output
class SelfAttentionEmbedder(MultiEmbedding):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__(vocab, dim_model)
self.dropout = 0.1
self.transformer_encoder = Encoder(
dim = dim_model,
depth = 1,
heads = 8,
attn_dropout = self.dropout,
ff_dropout = self.dropout,
attn_flash = True)
self.cls_embedding = nn.Parameter(torch.zeros(1, 1, self.dim_model), requires_grad=True)
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff()
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn()
def _add_dropout_after_attn(self):
for layer in self.transformer_encoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(self.dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(self.dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self):
for layer in self.transformer_encoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(self.dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_encoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def _apply_window_on_input_vec(self, embeddings):
window_size = 1
zero_vec = torch.zeros(embeddings.shape[0], window_size-1, embeddings.shape[2], embeddings.shape[3]).to(embeddings.device) # B x (window_size-1) x num_features x emb_size
window_applied_input_vec = torch.cat([zero_vec, embeddings], dim=1) # B x (T+window_size-1) x num_features x emb_size
window_applied_input_vec = window_applied_input_vec.unfold(1, window_size, 1) # B x T x window_size x emb_size x num_features
window_applied_input_vec = window_applied_input_vec.transpose(3, 4) # B x T x window_size x num_features x emb_size
window_applied_input_vec = window_applied_input_vec.reshape(embeddings.shape[0]*embeddings.shape[1], -1, embeddings.shape[3]) # (B*T) x (num_features*window_size) x emb_size
return window_applied_input_vec
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # (num_features*window_size+1)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x (num_features*window_size+1)
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x (num_features*window_size+1) x emb_size
return tgt_pos
def forward(self, input_tokens):
'''
input_tokens: B x T x num_features
'''
# prepare input vector
emb_list = [module(input_tokens[..., i]) for i, module in enumerate(self.layers)] # B x T x 1 x emb_size
stacked_emb = torch.stack(emb_list, dim=2) # B x T x num_features x emb_size
# apply window
stacked_emb = self._apply_window_on_input_vec(stacked_emb)
# add CLS
cls = self.cls_embedding.repeat(stacked_emb.shape[0], 1, 1) # (B*T) x 1 x emb_size
input_emb = torch.cat([stacked_emb, cls], dim=1) # (B*T) x (num_features*window_size+1) x emb_size
output = self.transformer_encoder(input_emb) # (B*T) x (num_features*window_size+1) x emb_size
# extract CLS
output = output[:, -1, :].reshape((input_tokens.shape[0], input_tokens.shape[1], -1)) # B x T x emb_size
return output
class RVQMultiEmbedding(nn.Module):
def __init__(
self,
vocab:LangTokenVocab,
dim_model:int
):
super().__init__()
self.vocab_size = vocab.get_vocab_size()
self.dim_model = dim_model
self.features = vocab.feature_list
self.layers = []
self._make_emb_layers()
def _make_emb_layers(self):
vocab_sizes = [self.vocab_size[key] for key in self.features]
self.embedding_sizes = [self.dim_model for _ in self.features]
for vocab_size, embedding_size in zip(vocab_sizes, self.embedding_sizes):
if embedding_size != 0:
self.layers.append(nn.Embedding(vocab_size, embedding_size))
self.layers = nn.ModuleList(self.layers)
def forward(self, x):
embeddings = torch.zeros(x.shape[0], x.shape[1], self.dim_model).to(x.device)
emb_list = [module(x[:, (idx+1)%4::4]) for idx, module in enumerate(self.layers)]
for idx, emb in enumerate(emb_list):
embeddings[:, (idx+1)%4::4] = emb
return embeddings
def get_emb_by_key(self, key:str, token:torch.Tensor):
layer_idx = self.features.index(key)
return self.layers[layer_idx](token)
class XtransformerDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
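A minimal shape check for the decoder wrapper above (assumes the x-transformers package is installed, as imported at the top of this file):

import torch
decoder = XtransformerDecoder(dim=512, depth=6, heads=8, dropout=0.1)
hidden = torch.randn(2, 16, 512)                   # already-embedded sequence, B x T x dim
out = decoder(hidden)                              # inference path: B x T x dim hidden states
out, intermediates = decoder(hidden, train=True)   # training path also returns layer intermediates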
class XtransformerCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for this decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for this decoder'
assert attention_mask is not None, 'attention_mask should be provided for this decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
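For the text-conditioned decoders above, the context dict is expected to hold T5 tokenizer outputs; a hedged sketch of building it (the prompt string is illustrative):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')  # same checkpoint as the frozen encoder
context = tokenizer(["a gentle piano waltz in 3/4"], return_tensors='pt', padding=True)
# context['input_ids'] and context['attention_mask'] are what forward(..., context=context) consumes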
class XtransformerLargeCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for this decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for this decoder'
assert attention_mask is not None, 'attention_mask should be provided for this decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class NewCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False,
use_rmsnorm=True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for this decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for this decoder'
assert attention_mask is not None, 'attention_mask should be provided for this decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class NewCrossAttendwithRoPEDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for this decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for this decoder'
assert attention_mask is not None, 'attention_mask should be provided for this decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class XtransformerPrefixDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = PrefixDecoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None):
assert context is not None, 'context should be provided for prefix decoder'
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerPretrainingDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
# context and context_embedding are accepted but not used in the pretraining forward pass
if cache is not None: # cached single-step decoding used by run_one_step at inference time
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
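# Usage sketch for XtransformerPretrainingDecoder (illustrative only): the batch size,
# sequence length and hyperparameters below are arbitrary assumptions, not values taken
# from any Amadeus config.
#
#   decoder = XtransformerPretrainingDecoder(dim=512, depth=8, heads=8, dropout=0.1)
#   seq = torch.randn(2, 16, 512)                        # B x T x dim token embeddings
#   hidden_vec, intermediates = decoder(seq, train=True)
#   assert hidden_vec.shape == (2, 16, 512)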
class XtransformerFinetuningDecoder(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dropout: float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# apply custom weight init, then insert extra dropout after each attention and feed-forward block
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
# prepend the encoded text context as a prefix to the token sequence
# (torch.cat along dim=1 requires the T5 hidden size, 768 for flan-t5-base, to equal dim)
seq = torch.cat([context, seq], dim=1) # B x (context_len + T) x emb_size
if cache is not None: # cached single-step decoding used by run_one_step at inference time
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# note: the context-prefix positions are not sliced off in this cached path
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
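# Usage sketch for XtransformerFinetuningDecoder (illustrative only): the prompt text,
# batch size and hyperparameters are assumptions; the one hard constraint implied by the
# prefix concatenation is that dim matches the flan-t5-base hidden size of 768.
#
#   from transformers import T5Tokenizer
#   tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
#   prompt = tokenizer(['a calm solo piano piece'], return_tensors='pt')
#   decoder = XtransformerFinetuningDecoder(dim=768, depth=8, heads=8, dropout=0.1)
#   seq = torch.randn(1, 16, 768)                        # B x T x dim token embeddings
#   hidden_vec, intermediates = decoder(seq, train=True, context=prompt)
#   assert hidden_vec.shape == (1, 16, 768)              # context prefix already stripped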
class XtransformerLargeFinetuningDecoder(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dropout: float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# apply custom weight init, then insert extra dropout after each attention and feed-forward block
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
# prepend the encoded text context as a prefix to the token sequence
# (torch.cat along dim=1 requires the T5 hidden size, 1024 for flan-t5-large, to equal dim)
seq = torch.cat([context, seq], dim=1) # B x (context_len + T) x emb_size
if cache is not None: # cached single-step decoding used by run_one_step at inference time
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# note: the context-prefix positions are not sliced off in this cached path
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
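# Usage sketch for XtransformerLargeFinetuningDecoder (illustrative only): identical to the
# example above except that the prefix concatenation requires dim to equal the flan-t5-large
# hidden size of 1024; prompt, shapes and hyperparameters are again assumptions.
#
#   from transformers import T5Tokenizer
#   tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
#   prompt = tokenizer(['an energetic string quartet'], return_tensors='pt')
#   decoder = XtransformerLargeFinetuningDecoder(dim=1024, depth=12, heads=16, dropout=0.1)
#   seq = torch.randn(1, 16, 1024)
#   hidden_vec, intermediates = decoder(seq, train=True, context=prompt)
#   assert hidden_vec.shape == (1, 16, 1024)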