first commit
This commit is contained in:
533 Amadeus/evaluation_utils.py Normal file
@@ -0,0 +1,533 @@
from collections import defaultdict
from typing import Union
from math import log
from omegaconf import DictConfig
from pathlib import Path
import pickle
import json
import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel

from . import model_zoo
from .symbolic_encoding import data_utils
from .model_zoo import AmadeusModel
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.compile_utils import shift_and_pad
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding import decoding_utils
from .train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab

def wandb_style_config_to_omega_config(wandb_conf):
    # Remove wandb-internal keys; they must not override the model config.
    for wandb_key in ["wandb_version", "_wandb"]:
        if wandb_key in wandb_conf:
            del wandb_conf[wandb_key]
    # Unwrap wandb's {'desc': ..., 'value': ...} wrappers so that each key
    # maps directly to its value.
    for key in wandb_conf:
        if isinstance(wandb_conf[key], dict) and 'value' in wandb_conf[key]:
            wandb_conf[key] = wandb_conf[key]['value']
        else:
            # Handle DictConfig-like entries that expose 'value' without
            # being plain dicts; skip entries that do not support 'in'.
            try:
                if 'value' in wandb_conf[key]:
                    wandb_conf[key] = wandb_conf[key]['value']
            except TypeError:
                pass
    return wandb_conf

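# A minimal sketch of the flattener above on a hypothetical wandb config dump
# (the dict literal is illustrative, not taken from the repository):
#
#   raw = {"wandb_version": 1,
#          "dataset": {"desc": None, "value": "SOD"},
#          "batch_size": {"desc": None, "value": 16}}
#   wandb_style_config_to_omega_config(raw)
#   # -> {"dataset": "SOD", "batch_size": 16}
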
def get_dir_from_wandb_by_code(wandb_dir: Path, code: str) -> Path:
    # Each wandb run directory name ends with the run's short code.
    for run_dir in wandb_dir.iterdir():
        if run_dir.name.endswith(code):
            return run_dir
    print(f'No such code in wandb_dir: {code}')
    return None

def get_best_ckpt_path_and_config(wandb_dir, code):
    run_dir = get_dir_from_wandb_by_code(wandb_dir, code)
    if run_dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = run_dir / 'files' / 'checkpoints'

    config_path = run_dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))
    metadata_path = next(ckpt_dir.glob('*metadata.json'))

    # Prefer a checkpoint ending with 'last'; otherwise fall back to the
    # checkpoint with the highest iteration number.
    if len(list(ckpt_dir.glob('*last.pt'))) > 0:
        last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
    else:
        pt_fns = sorted(ckpt_dir.glob('*.pt'),
                        key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
        last_ckpt_fn = pt_fns[-1]

    return last_ckpt_fn, config_path, metadata_path, vocab_path

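# Note: the iteration-number fallback above assumes checkpoint names like
# 'iter20000_loss1.23.pt' (an illustrative pattern): the sort key strips
# 'iter' from the first underscore-separated field, so 'iter9000_...' sorts
# before 'iter20000_...'.
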
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path: str, vocab_path: str):
    nn_params = config.nn_params
    dataset_name = config.dataset
    vocab_path = Path(vocab_path)

    if 'Encodec' in dataset_name:
        # NOTE: EncodecDataset is referenced without an import in the original
        # file; it is assumed to be provided elsewhere in the project.
        encodec_tokens_path = Path("dataset/maestro-v3.0.0-encodec_tokens")
        encodec_dataset = EncodecDataset(config, encodec_tokens_path, None, None)
        vocab_sizes = encodec_dataset.vocab.get_vocab_size()
        train_set, valid_set, test_set = encodec_dataset.split_train_valid_test_set()

        model: model_zoo.LanguageModelTransformer = getattr(model_zoo, nn_params.model_name)(config, vocab_sizes)
        # Return early: the symbolic-dataset path below does not apply here.
        return model, test_set, encodec_dataset.vocab

    encoding_scheme = config.nn_params.encoding_scheme
    num_features = config.nn_params.num_features

    # Pick the vocab class that matches the encoding scheme.
    vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
    selected_vocab_name = vocab_name[encoding_scheme]

    vocab = getattr(vocab_utils, selected_vocab_name)(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features)

    # Initialize the symbolic dataset named in the config.
    symbolic_dataset = getattr(data_utils, dataset_name)(
        vocab=vocab,
        encoding_scheme=encoding_scheme,
        num_features=num_features,
        debug=config.general.debug,
        aug_type=config.data_params.aug_type,
        input_length=config.train_params.input_length,
        first_pred_feature=config.data_params.first_pred_feature,
        caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
        for_evaluation=True,
    )

    vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
    print(f"---{nn_params.main_decoder}--- is used")
    print(f"---{dataset_name}--- is used")
    print(f"---{encoding_scheme}--- is used")
    split_ratio = config.data_params.split_ratio
    train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(
        dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)

    # Derive the per-feature prediction order from the encoding scheme and the
    # first feature to predict.
    prediction_order = adjust_prediction_order(
        encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)

    # Build the Transformer model from the configuration. The local name is
    # `model` so it does not shadow the imported AmadeusModel class.
    model = getattr(model_zoo, nn_params.model_name)(
        vocab=symbolic_dataset.vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
        sub_decoder_enricher_use=nn_params.sub_decoder.feature_enricher_use
            if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use') else False,
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )

    return model, test_set, symbolic_dataset.vocab

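# 'Conti' ("continuation") marks a feature value that simply repeats the
# previous event's value (e.g. an unchanged chord or tempo). The two helpers
# below apply this convention so that repeated values are not scored as fresh
# predictions: add_conti_in_valid works on id tensors (sentinel 9999), while
# add_conti works on decoded token lists (string 'Conti').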
def add_conti_in_valid(tensor, encoding_scheme):
    # tensor shape: [batch, sequence, features]
    new_target = tensor.clone()
    # Compare each step with its predecessor to find repeated feature values.
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # The first element of each sequence cannot be a duplicate by definition.
    shifted_tensor[:, 0, :] = new_target[:, 0, :] + 1

    # Identify where the original and shifted tensors agree (duplicates).
    duplicates = new_target == shifted_tensor
    # TODO: replace the hard-coded feature indices below.
    # Reset the duplicate flags to False for every column except the features
    # that participate in Conti marking.
    if encoding_scheme == 'nb':
        if tensor.shape[2] == 5:
            # keep duplicates only for beat and instrument
            duplicates[:, :, 0] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 4:
            # keep duplicates only for beat
            duplicates[:, :, 0] = False
            duplicates[:, :, 2] = False
            duplicates[:, :, 3] = False
        elif tensor.shape[2] == 7:
            # keep duplicates only for beat, chord, tempo
            duplicates[:, :, 0] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False
    elif encoding_scheme == 'cp':
        if tensor.shape[2] == 5:
            # keep duplicates only for instrument
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 7:
            # keep duplicates only for chord and tempo
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False

    # Mark the remaining duplicated values with the sentinel id 9999 ('Conti').
    new_target[duplicates] = 9999
    return new_target

# TODO: feature layouts are hard-coded by tuple length below.
def add_conti(list_of_lists, encoding_scheme):
    if encoding_scheme == 'nb':
        if len(list_of_lists[0]) == 4:
            # features: type, beat, pitch, duration
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
        elif len(list_of_lists[0]) == 5:
            # features: type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_instrument and previous_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                else:
                    previous_instrument = list_of_lists[i][2]
        elif len(list_of_lists[0]) == 7:
            # features: type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_chord and previous_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif list_of_lists[i][2] != previous_chord and list_of_lists[i][2] != 0:
                    previous_chord = list_of_lists[i][2]
                if list_of_lists[i][3] == previous_tempo and previous_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif list_of_lists[i][3] != previous_tempo and list_of_lists[i][3] != 0:
                    previous_tempo = list_of_lists[i][3]
    elif encoding_scheme == 'cp':
        if len(list_of_lists[0]) == 7:
            # features: type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(len(list_of_lists)):
                current_chord = list_of_lists[i][2]
                current_tempo = list_of_lists[i][3]
                if current_chord == previous_chord and current_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_chord != previous_chord and current_chord != 0:
                    previous_chord = current_chord
                if current_tempo == previous_tempo and current_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif current_tempo != previous_tempo and current_tempo != 0:
                    previous_tempo = current_tempo
        elif len(list_of_lists[0]) == 5:
            # features: type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(len(list_of_lists)):
                current_instrument = list_of_lists[i][2]
                if current_instrument == previous_instrument and current_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_instrument != previous_instrument and current_instrument != 0:
                    previous_instrument = current_instrument
    return list_of_lists

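# Evaluator computes teacher-forced negative log-likelihoods per feature class
# over a test set, and can render ground-truth, prompted, unconditioned, and
# text-prompted generations to MIDI.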
class Evaluator:
    def __init__(self,
                 config: DictConfig,
                 model: AmadeusModel,
                 test_set: TuneCompiler,
                 vocab: LangTokenVocab,
                 device: str = 'cuda',
                 batch_size: int = 16):
        self.config = config
        self.device = device
        self.vocab = vocab

        self.model = model
        self.model.eval()
        self.model.to(device)
        self.test_set = test_set

        self.input_len = config.train_params.input_length
        self.loss_by_class = {key: [] for key in self.vocab.feature_list}
        self.count_by_class = {key: 0 for key in self.vocab.feature_list}
        self.batch_size = batch_size

        # 'nb' and 'cp' are compound encodings with one sub-token per feature.
        self.is_multiclass = config.nn_params.encoding_scheme in ('nb', 'cp')
        self.first_pred_feature = self.config.data_params.first_pred_feature

        # Structural tokens whose probability is folded into the next content
        # token rather than scored on their own.
        self.neglect_keywords = ['SSS', 'SSN', 'Conti', 'Metrical', 'Note']
        self.valid_item_prob = []

        # Focal loss is disabled during evaluation: alpha=1, gamma=0 reduces
        # it to plain negative log-likelihood.
        self.focal_alpha = 1
        self.focal_gamma = 0

    def save_results(self, save_fn):
        # Move the accumulated losses and counts to CPU tensors before saving.
        for key in self.loss_by_class.keys():
            self.loss_by_class[key] = torch.tensor(self.loss_by_class[key]).cpu()
            self.count_by_class[key] = torch.tensor(self.count_by_class[key]).cpu()
        torch.save({'loss_by_class': self.loss_by_class, 'count_by_class': self.count_by_class}, save_fn)

    @torch.inference_mode()
    def get_perplexity(self, less_than=256):
        for data in tqdm(self.test_set.data_list, desc='Cal over dataset', position=0):
            data_tensor = torch.LongTensor(data[0])
            if self.config.nn_params.encoding_scheme == 'nb':
                data_tensor = shift_and_pad(data_tensor, self.first_pred_feature)
                data_tensor = data_tensor[:-1]

            x_seg = data_tensor[:-1].unsqueeze(0)
            y_seg = data_tensor[1:].unsqueeze(0)
            # Score the first input_len tokens in a single forward pass.
            self._cal_initial_seg(x_seg, y_seg)

            # For longer pieces, score each following token with a sliding
            # window of length input_len (stride 1), batched for speed.
            if x_seg.shape[1] > self.input_len:
                cat_logits = []
                cat_y = []
                cat_mask_indices = []
                batch_x = x_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                batch_y = y_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                if self.is_multiclass:
                    batch_x = batch_x.transpose(1, 2)
                    batch_y = batch_y.transpose(1, 2)
                for batch_start_idx in tqdm(range(0, min(batch_x.shape[0], less_than), self.batch_size), desc='In piece iter', position=1, leave=False):
                    x = batch_x[batch_start_idx:batch_start_idx + self.batch_size]
                    y = batch_y[batch_start_idx:batch_start_idx + self.batch_size]
                    logits, y, mask_indices = self._cal_following_seg(x, y)
                    cat_logits.append(logits)
                    cat_y.append(y)
                    cat_mask_indices.append(mask_indices)
                if self.is_multiclass:
                    # Each element of cat_logits is a dict of per-feature logits.
                    cat_dict = {}
                    for key in self.vocab.feature_list:
                        cat_dict[key] = torch.cat([logits_dict[key] for logits_dict in cat_logits], dim=0)
                    cat_logits = cat_dict
                else:
                    cat_logits = torch.cat(cat_logits, dim=0)
                cat_y = torch.cat(cat_y, dim=0)
                mask_indices = torch.cat(cat_mask_indices, dim=0)
                if self.is_multiclass:
                    self._update_loss_for_multi_class(cat_logits, cat_y, mask_indices)
                else:
                    cat_prob = torch.nn.functional.softmax(cat_logits, dim=-1)
                    pt = cat_prob[torch.arange(cat_prob.shape[0]), cat_y]
                    # Focal loss is intentionally not applied here; see __init__.
                    loss = -torch.log(pt)
                    self._update_loss_for_single_class(loss, cat_y)

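    # Reporting sketch: with the per-class NLLs collected above, perplexity
    # for a feature class can be computed as exp(mean NLL), e.g.
    #     ppl = torch.exp(torch.tensor(evaluator.loss_by_class['pitch']).mean())
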
    @torch.inference_mode()
    def _update_loss_for_single_class(self, neg_log_prob: torch.Tensor, y: torch.Tensor):
        # Route each token's NLL to the feature class its id belongs to.
        for key in self.vocab.feature_list:
            feature_mask = self.vocab.total_mask[key].to(y.device)  # [vocab_size]
            mask_for_target = feature_mask[y]  # [b*t]
            normal_loss_seq_by_class = neg_log_prob[mask_for_target == 1]
            if mask_for_target.sum().item() != 0:
                self.loss_by_class[key] += normal_loss_seq_by_class.tolist()
                self.count_by_class[key] += mask_for_target.sum().item()

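    # In _update_loss_for_multi_class below, each note factorizes over its
    # features, p(note) = prod_f p(feature_f | context). Structural sub-tokens
    # (SSS, Conti, ...) are not scored on their own; their probabilities
    # accumulate in cum_prob and are folded into the next content feature,
    # whose logged loss is -log(cum_prob * p(feature)).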
    @torch.inference_mode()
    def _update_loss_for_multi_class(self, logits_dict: dict, tgt: torch.Tensor, mask_indices: torch.Tensor = None):
        # Probability of the correct sub-token for every feature column.
        correct_token_prob = []
        for index, key in enumerate(self.vocab.feature_list):
            feat_tgt = tgt[:, index]
            logit_values = logits_dict[key]
            prob_values = torch.nn.functional.softmax(logit_values, dim=-1)
            correct_token_prob.append(prob_values[torch.arange(prob_values.shape[0]), feat_tgt])
        correct_token_prob = torch.stack(correct_token_prob, dim=1)
        y_decoded = self.vocab.decode(tgt)
        y_decoded = add_conti(y_decoded, self.config.nn_params.encoding_scheme)
        num_notes = logits_dict['pitch'].shape[0]
        cum_prob = 1
        max_num = mask_indices.size(0)
        for idx in range(max_num):
            if max_num != num_notes:
                print("not equal", max_num, num_notes)
            token = y_decoded[idx]
            valid_mask = mask_indices[idx, :]
            token_prob = correct_token_prob[idx].tolist()
            for j, key in enumerate(self.vocab.feature_list):
                cur_feature = token[j]
                whether_predicted = valid_mask[j]
                # Clamp the probability to avoid log(0).
                cur_prob = max(token_prob[j], 1e-10)
                if cur_feature == 0:  # ignore token
                    continue
                if not whether_predicted:  # skip tokens provided as input
                    continue
                if cur_feature in self.neglect_keywords:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'time_signature' in cur_feature:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'Bar' in cur_feature:
                    cum_prob = 1
                    continue
                self.valid_item_prob.append([cur_feature, cur_prob, cur_prob * cum_prob])
                pt = cur_prob * cum_prob
                loss = -log(pt)
                self.loss_by_class[key].append(loss)
                self.count_by_class[key] += 1
                cum_prob = 1

    @torch.inference_mode()
    def _cal_initial_seg(self, x_seg, y_seg):
        x, y = x_seg[:, :self.input_len].to(self.device), y_seg[:, :self.input_len].to(self.device)
        mask_indices = torch.ones_like(y).bool().to(self.device).flatten(0, 1)
        if self.config.use_diff is True:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        y = y.flatten(0, 1)
        if self.is_multiclass:
            for key in logits.keys():
                logits[key] = logits[key].flatten(0, 1)
            self._update_loss_for_multi_class(logits, y, mask_indices)
        else:
            prob = torch.nn.functional.softmax(logits, dim=-1)
            prob = prob.flatten(0, 1)
            pt = prob[torch.arange(len(y)), y]
            loss = -torch.log(pt)
            self._update_loss_for_single_class(loss, y)

    @torch.inference_mode()
    def _cal_following_seg(self, x: torch.Tensor, y: torch.Tensor):
        x, y = x.to(self.device), y.to(self.device)
        mask_indices = torch.ones_like(y).bool().to(self.device)
        if self.config.use_diff is True:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        # Only the last position of each sliding window is a new prediction.
        y = y[:, -1:].flatten(0, 1).cpu()
        mask_indices = mask_indices.reshape(x.shape)[:, -1:].flatten(0, 1).cpu()
        if self.is_multiclass:
            logits_dict = {}
            for key in self.vocab.feature_list:
                logits_dict[key] = logits[key][:, -1:].flatten(0, 1).cpu()
            return logits_dict, y, mask_indices
        else:
            logits = logits[:, -1:].flatten(0, 1).cpu()
            return logits, y, mask_indices

    def prepare_prompt_and_ground_truth(self, save_dir, num_target_samples, num_target_measures):
        encoding_scheme = self.config.nn_params.encoding_scheme

        # Frames per beat for each supported dataset.
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # default resolution if the dataset is not listed

        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        for i, (tuneidx, tune_name) in enumerate(self.test_set):
            ground_truth_sample = tuneidx
            try:
                decoder(ground_truth_sample, output_path=str(save_dir / f"{i}_{tune_name}_gt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}.mid")

            prompt = self.model.decoder._prepare_inference(start_token=self.model.decoder.net.start_token, manual_seed=0, condition=tuneidx, num_target_measures=num_target_measures)
            try:
                decoder(prompt, output_path=str(save_dir / f"{i}_{tune_name}_prompt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}_prompt.mid")

            if i == num_target_samples:
                break

    def generate_samples_with_prompt(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0, generation_length=3072):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # default resolution if the dataset is not listed

        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        tuneidx = tuneidx.to(self.device)
        generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            # Undo the feature shift applied during nb encoding before decoding.
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))

        # Also render the prompt itself for reference.
        prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

    def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # default resolution if the dataset is not listed
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        for i in range(num_samples):
            generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            if encoding_scheme == 'nb':
                generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
            decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))

    def generate_samples_with_text_prompt(self, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme
        # Encode the text prompt with a frozen Flan-T5 encoder; the resulting
        # hidden states are passed to the model as cross-attention context.
        tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
        encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(self.device)
        print(f"Using T5EncoderModel for text prompt: {prompt}")
        context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(self.device)
        context = encoder(**context).last_hidden_state
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # default resolution if the dataset is not listed
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        generated_sample = self.model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        # Count the lines in the jsonl log to determine the next sample index.
        jsonl_path = save_dir / "name2prompt.jsonl"
        if jsonl_path.exists():
            with open(jsonl_path, 'r') as f:
                current_idx = sum(1 for _ in f)
        else:
            current_idx = 0

        name = f"prompt_{current_idx}"
        # Record which prompt produced which output file.
        name2prompt_dict = defaultdict(list)
        name2prompt_dict[name].append(prompt)
        with open(jsonl_path, 'a') as f:
            f.write(json.dumps(name2prompt_dict) + '\n')
        decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
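
# Example driver (a sketch under assumptions: the 'wandb' directory layout,
# the run code, and the 'model' checkpoint key are illustrative, not taken
# from the repository):
if __name__ == '__main__':
    from omegaconf import OmegaConf

    ckpt_fn, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(
        Path('wandb'), 'run_code')
    config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
    model, test_set, vocab = prepare_model_and_dataset_from_config(
        config, str(metadata_path), str(vocab_path))
    model.load_state_dict(torch.load(ckpt_fn, map_location='cpu')['model'])

    evaluator = Evaluator(config, model, test_set, vocab, device='cuda')
    evaluator.get_perplexity()
    evaluator.save_results('eval_results.pt')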