from collections import defaultdict
from typing import Union
from math import log
from omegaconf import DictConfig
from pathlib import Path
import pickle
import json

import torch
from tqdm.auto import tqdm
from transformers import T5Tokenizer, T5EncoderModel

from . import model_zoo
from .symbolic_encoding import data_utils
from .model_zoo import AmadeusModel
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.compile_utils import shift_and_pad, reverse_shift_and_pad_for_tensor
from .symbolic_encoding import decoding_utils
from .train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab


def wandb_style_config_to_omega_config(wandb_conf):
    """Strip wandb bookkeeping fields and unwrap {'desc': ..., 'value': ...} entries."""
    # remove wandb-related keys; they should not override the model config
    for wandb_key in ["wandb_version", "_wandb"]:
        if wandb_key in wandb_conf:
            del wandb_conf[wandb_key]
    # wandb stores each entry as {'desc': ..., 'value': ...}; unwrap the value
    # and leave plain entries untouched
    for key in wandb_conf:
        try:
            if 'value' in wandb_conf[key]:
                wandb_conf[key] = wandb_conf[key]['value']
        except TypeError:
            pass  # entry is a plain value, not a mapping
    return wandb_conf


def get_dir_from_wandb_by_code(wandb_dir: Path, code: str) -> Path:
    for run_dir in wandb_dir.iterdir():
        if run_dir.name.endswith(code):
            return run_dir
    print(f'No such code in wandb_dir: {code}')
    return None


def get_best_ckpt_path_and_config(wandb_dir, code):
    run_dir = get_dir_from_wandb_by_code(wandb_dir, code)
    if run_dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = run_dir / 'files' / 'checkpoints'
    config_path = run_dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))
    metadata_path = next(ckpt_dir.glob('*metadata.json'))
    # prefer a checkpoint ending with 'last'; otherwise take the highest iteration
    if len(list(ckpt_dir.glob('*last.pt'))) > 0:
        last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
    else:
        pt_fns = sorted(ckpt_dir.glob('*.pt'),
                        key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
        last_ckpt_fn = pt_fns[-1]
    return last_ckpt_fn, config_path, metadata_path, vocab_path
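
# Hedged usage sketch: restore a run's config, model, and test split from a
# local wandb directory. The directory name 'wandb', the run code 'abc123',
# and the checkpoint key layout are assumptions for illustration only.
def _example_restore_run(code: str = 'abc123'):
    from omegaconf import OmegaConf
    ckpt_fn, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(
        Path('wandb'), code)
    config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
    model, test_set, vocab = prepare_model_and_dataset_from_config(
        config, metadata_path, vocab_path)
    state = torch.load(ckpt_fn, map_location='cpu')
    # assumed checkpoint layout: weights either at the top level or under 'model'
    model.load_state_dict(state['model'] if 'model' in state else state)
    return config, model, test_set, vocab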

def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path: str, vocab_path: str):
    nn_params = config.nn_params
    dataset_name = config.dataset
    vocab_path = Path(vocab_path)
    if 'Encodec' in dataset_name:
        encodec_tokens_path = Path("dataset/maestro-v3.0.0-encodec_tokens")
        # NOTE: EncodecDataset is assumed to be importable in this module's scope
        encodec_dataset = EncodecDataset(config, encodec_tokens_path, None, None)
        vocab_sizes = encodec_dataset.vocab.get_vocab_size()
        train_set, valid_set, test_set = encodec_dataset.split_train_valid_test_set()
        lm_model: model_zoo.LanguageModelTransformer = getattr(model_zoo, nn_params.model_name)(config, vocab_sizes)
        return lm_model, test_set, encodec_dataset.vocab
    encoding_scheme = config.nn_params.encoding_scheme
    num_features = config.nn_params.num_features
    # pick the vocab class that matches the encoding scheme
    vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
    vocab = getattr(vocab_utils, vocab_name[encoding_scheme])(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features)
    # initialize the symbolic dataset selected by dataset name
    symbolic_dataset = getattr(data_utils, dataset_name)(
        vocab=vocab,
        encoding_scheme=encoding_scheme,
        num_features=num_features,
        debug=config.general.debug,
        aug_type=config.data_params.aug_type,
        input_length=config.train_params.input_length,
        first_pred_feature=config.data_params.first_pred_feature,
        caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
        for_evaluation=True,
    )
    vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
    print(f"---{nn_params.main_decoder}--- is used")
    print(f"---{dataset_name}--- is used")
    print(f"---{encoding_scheme}--- is used")
    split_ratio = config.data_params.split_ratio
    train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(
        dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
    # derive the sub-token prediction order from the encoding scheme and the
    # target feature named in the config
    prediction_order = adjust_prediction_order(
        encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
    # build the Transformer model from configuration parameters
    model = getattr(model_zoo, nn_params.model_name)(
        vocab=symbolic_dataset.vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
        sub_decoder_enricher_use=(nn_params.sub_decoder.feature_enricher_use
                                  if hasattr(nn_params, 'sub_decoder')
                                  and hasattr(nn_params.sub_decoder, 'feature_enricher_use')
                                  else False),
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )
    return model, test_set, symbolic_dataset.vocab


def add_conti_in_valid(tensor, encoding_scheme):
    new_target = tensor.clone()
    # assuming tensor shape is [batch, sequence, features]:
    # compare each step with its predecessor to find repeated feature values
    shifted_tensor = torch.roll(new_target, shifts=1, dims=1)
    # the first element of each sequence cannot be a duplicate by definition
    shifted_tensor[:, 0, :] = new_target[:, 0, :] + 1
    # identify where the original and shifted tensors agree (duplicates)
    duplicates = new_target == shifted_tensor
    # TODO: remove hard-coded feature indices
    # keep duplicate flags only for the features that use 'Conti' tokens
    if encoding_scheme == 'nb':
        if tensor.shape[2] == 5:  # only beat and instrument may become 'Conti'
            duplicates[:, :, 0] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 4:  # only beat may become 'Conti'
            duplicates[:, :, 0] = False
            duplicates[:, :, 2] = False
            duplicates[:, :, 3] = False
        elif tensor.shape[2] == 7:  # only beat, chord, and tempo may become 'Conti'
            duplicates[:, :, 0] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False
    elif encoding_scheme == 'cp':
        if tensor.shape[2] == 5:  # only instrument may become 'Conti'
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 3] = False
            duplicates[:, :, 4] = False
        elif tensor.shape[2] == 7:  # only chord and tempo may become 'Conti'
            duplicates[:, :, 0] = False
            duplicates[:, :, 1] = False
            duplicates[:, :, 4] = False
            duplicates[:, :, 5] = False
            duplicates[:, :, 6] = False
    # replace duplicates with the sentinel value 9999
    new_target[duplicates] = 9999
    return new_target
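
# Hedged sketch: a toy check of add_conti_in_valid. The token ids are made up;
# the point is that a value repeating the previous timestep in a
# 'Conti'-eligible column (beat or instrument for a 5-feature nb tensor) is
# replaced by the sentinel 9999, while exempt columns are left alone.
def _example_conti_masking():
    seq = torch.LongTensor([[[1, 7, 3, 5, 5],
                             [1, 7, 4, 6, 5],    # beat (col 1) repeats -> 9999
                             [1, 8, 4, 6, 5]]])  # instrument (col 2) repeats -> 9999
    masked = add_conti_in_valid(seq, 'nb')
    assert masked[0, 1, 1].item() == 9999 and masked[0, 2, 2].item() == 9999
    return masked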

# TODO: remove hard-coded feature positions
def add_conti(list_of_lists, encoding_scheme):
    if encoding_scheme == 'nb':
        if len(list_of_lists[0]) == 4:  # type, beat, pitch, duration
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
        elif len(list_of_lists[0]) == 5:  # type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_instrument and previous_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                else:
                    previous_instrument = list_of_lists[i][2]
        elif len(list_of_lists[0]) == 7:  # type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(len(list_of_lists)):
                if list_of_lists[i][0] == 'SSS':
                    list_of_lists[i][1] = 'Conti'
                if list_of_lists[i][2] == previous_chord and previous_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif list_of_lists[i][2] != previous_chord and list_of_lists[i][2] != 0:
                    previous_chord = list_of_lists[i][2]
                if list_of_lists[i][3] == previous_tempo and previous_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif list_of_lists[i][3] != previous_tempo and list_of_lists[i][3] != 0:
                    previous_tempo = list_of_lists[i][3]
    elif encoding_scheme == 'cp':
        if len(list_of_lists[0]) == 7:  # type, beat, chord, tempo, pitch, duration, velocity
            previous_chord = None
            previous_tempo = None
            for i in range(len(list_of_lists)):
                current_chord = list_of_lists[i][2]
                current_tempo = list_of_lists[i][3]
                if current_chord == previous_chord and current_chord != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_chord != previous_chord and current_chord != 0:
                    previous_chord = current_chord
                if current_tempo == previous_tempo and current_tempo != 0:
                    list_of_lists[i][3] = 'Conti'
                elif current_tempo != previous_tempo and current_tempo != 0:
                    previous_tempo = current_tempo
        if len(list_of_lists[0]) == 5:  # type, beat, instrument, pitch, duration
            previous_instrument = None
            for i in range(len(list_of_lists)):
                current_instrument = list_of_lists[i][2]
                if current_instrument == previous_instrument and current_instrument != 0:
                    list_of_lists[i][2] = 'Conti'
                elif current_instrument != previous_instrument and current_instrument != 0:
                    previous_instrument = current_instrument
    return list_of_lists
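
# Hedged sketch of add_conti on decoded nb tokens with 4 features
# (type, beat, pitch, duration): a row whose type is 'SSS' gets its beat field
# rewritten to 'Conti'. The token strings are illustrative only.
def _example_add_conti():
    tokens = [['NNN', 'Beat_0', 'Pitch_60', 'Dur_4'],
              ['SSS', 'Beat_0', 'Pitch_64', 'Dur_4']]
    # second row becomes ['SSS', 'Conti', 'Pitch_64', 'Dur_4']
    return add_conti(tokens, 'nb')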

class Evaluator:
    def __init__(self, config: DictConfig, model: AmadeusModel, test_set: TuneCompiler,
                 vocab: LangTokenVocab, device: str = 'cuda', batch_size: int = 16):
        self.config = config
        self.device = device
        self.vocab = vocab
        self.model = model
        self.model.eval()
        self.model.to(device)
        self.test_set = test_set
        self.input_len = config.train_params.input_length
        self.loss_by_class = {key: [] for key in self.vocab.feature_list}
        self.count_by_class = {key: 0 for key in self.vocab.feature_list}
        self.batch_size = batch_size
        self.is_multiclass = config.nn_params.encoding_scheme in ('nb', 'cp')
        self.first_pred_feature = self.config.data_params.first_pred_feature
        self.neglect_keywords = ['SSS', 'SSN', 'Conti', 'Metrical', 'Note']
        self.valid_item_prob = []
        # focal loss is disabled during evaluation; alpha=1, gamma=0 reduces it to plain NLL
        self.focal_alpha = 1
        self.focal_gamma = 0

    def save_results(self, save_fn):
        # move accumulated per-class losses and counts to CPU tensors before saving
        for key in self.loss_by_class.keys():
            self.loss_by_class[key] = torch.tensor(self.loss_by_class[key]).cpu()
            self.count_by_class[key] = torch.tensor(self.count_by_class[key]).cpu()
        torch.save({'loss_by_class': self.loss_by_class, 'count_by_class': self.count_by_class}, save_fn)

    @torch.inference_mode()
    def get_perplexity(self, less_than=256):
        for data in tqdm(self.test_set.data_list, desc='Cal over dataset', position=0):
            data_tensor = torch.LongTensor(data[0])
            if self.config.nn_params.encoding_scheme == 'nb':
                data_tensor = shift_and_pad(data_tensor, self.first_pred_feature)
                data_tensor = data_tensor[:-1]
            x_seg = data_tensor[:-1].unsqueeze(0)
            y_seg = data_tensor[1:].unsqueeze(0)
            # teacher-forced pass over the first window
            self._cal_initial_seg(x_seg, y_seg)
            # for longer pieces, slide a window one step at a time and score
            # only the newest position of each window
            if x_seg.shape[1] > self.input_len:
                cat_logits = []
                cat_y = []
                cat_mask_indices = []
                batch_x = x_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                batch_y = y_seg[0, 1:].unfold(dimension=0, size=self.input_len, step=1)
                if self.is_multiclass:
                    batch_x = batch_x.transpose(1, 2)
                    batch_y = batch_y.transpose(1, 2)
                for batch_start_idx in tqdm(range(0, min(batch_x.shape[0], less_than), self.batch_size),
                                            desc='In piece iter', position=1, leave=False):
                    x = batch_x[batch_start_idx:batch_start_idx + self.batch_size]
                    y = batch_y[batch_start_idx:batch_start_idx + self.batch_size]
                    logits, y, mask_indices = self._cal_following_seg(x, y)
                    cat_logits.append(logits)
                    cat_y.append(y)
                    cat_mask_indices.append(mask_indices)
                if self.is_multiclass:
                    cat_dict = {}
                    for key in self.vocab.feature_list:
                        cat_dict[key] = torch.cat([logits_dict[key] for logits_dict in cat_logits], dim=0)
                    cat_logits = cat_dict
                else:
                    cat_logits = torch.cat(cat_logits, dim=0)
                cat_y = torch.cat(cat_y, dim=0)
                mask_indices = torch.cat(cat_mask_indices, dim=0)
                if self.is_multiclass:
                    self._update_loss_for_multi_class(cat_logits, cat_y, mask_indices)
                else:
                    cat_prob = torch.nn.functional.softmax(cat_logits, dim=-1)
                    pt = cat_prob[torch.arange(cat_prob.shape[0]), cat_y]
                    # focal loss with alpha=1, gamma=0 reduces to plain NLL
                    loss = -torch.log(pt)
                    self._update_loss_for_single_class(loss, cat_y)

    @torch.inference_mode()
    def _update_loss_for_single_class(self, neg_log_prob: torch.Tensor, y: torch.Tensor):
        for key in self.vocab.feature_list:
            feature_mask = self.vocab.total_mask[key].to(y.device)  # [vocab_size]
            mask_for_target = feature_mask[y]  # [b*t]
            normal_loss_seq_by_class = neg_log_prob[mask_for_target == 1]
            if mask_for_target.sum().item() != 0:
                self.loss_by_class[key] += normal_loss_seq_by_class.tolist()
                self.count_by_class[key] += mask_for_target.sum().item()
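
    # Hedged helper (not part of the original class): fold the accumulated
    # per-feature negative log-likelihoods into perplexities via the standard
    # identity ppl = exp(mean NLL). Call after get_perplexity() and before
    # save_results(), while loss_by_class still holds Python lists.
    @torch.inference_mode()
    def perplexity_by_class(self) -> dict:
        ppl = {}
        for key, losses in self.loss_by_class.items():
            if len(losses) == 0:
                continue
            ppl[key] = torch.exp(torch.tensor(losses).float().mean()).item()
        return ppl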

    @torch.inference_mode()
    def _update_loss_for_multi_class(self, logits_dict: dict, tgt: torch.Tensor, mask_indices: torch.Tensor = None):
        # gather the predicted probability of each target sub-token
        correct_token_prob = []
        for index, key in enumerate(self.vocab.feature_list):
            feat_tgt = tgt[:, index]
            prob_values = torch.nn.functional.softmax(logits_dict[key], dim=-1)
            correct_token_prob.append(prob_values[torch.arange(prob_values.shape[0]), feat_tgt])
        correct_token_prob = torch.stack(correct_token_prob, dim=1)
        y_decoded = self.vocab.decode(tgt)
        y_decoded = add_conti(y_decoded, self.config.nn_params.encoding_scheme)
        num_notes = logits_dict['pitch'].shape[0]
        max_num = mask_indices.size(0)
        if max_num != num_notes:  # sanity check: one mask row per scored note
            print("not equal", max_num, num_notes)
        # structural tokens ('SSS', 'Conti', ...) fold their probability into
        # cum_prob, which is attributed to the next content feature
        cum_prob = 1
        for idx in range(max_num):
            token = y_decoded[idx]
            valid_mask = mask_indices[idx, :]
            token_prob = correct_token_prob[idx].tolist()
            for j, key in enumerate(self.vocab.feature_list):
                cur_feature = token[j]
                whether_predicted = valid_mask[j]
                # clamp the probability to avoid log(0)
                cur_prob = max(token_prob[j], 1e-10)
                if cur_feature == 0:  # ignore token
                    continue
                if not whether_predicted:  # skip sub-tokens given as input, not predicted
                    continue
                if cur_feature in self.neglect_keywords:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'time_signature' in cur_feature:
                    cum_prob *= cur_prob
                    continue
                if self.config.nn_params.encoding_scheme == 'cp' and 'Bar' in cur_feature:
                    cum_prob = 1
                    continue
                self.valid_item_prob.append([cur_feature, cur_prob, cur_prob * cum_prob])
                pt = cur_prob * cum_prob
                loss = -log(pt)
                self.loss_by_class[key].append(loss)
                self.count_by_class[key] += 1
                cum_prob = 1

    @torch.inference_mode()
    def _cal_initial_seg(self, x_seg, y_seg):
        x, y = x_seg[:, :self.input_len].to(self.device), y_seg[:, :self.input_len].to(self.device)
        # default: every position counts as predicted
        mask_indices = torch.ones_like(y).bool().to(self.device).flatten(0, 1)
        if self.config.use_diff:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        y = y.flatten(0, 1)
        if self.is_multiclass:
            for key in logits.keys():
                logits[key] = logits[key].flatten(0, 1)
            self._update_loss_for_multi_class(logits, y, mask_indices)
        else:
            prob = torch.nn.functional.softmax(logits, dim=-1)
            prob = prob.flatten(0, 1)
            pt = prob[torch.arange(len(y)), y]
            loss = -torch.log(pt)
            self._update_loss_for_single_class(loss, y)

    @torch.inference_mode()
    def _cal_following_seg(self, x: torch.Tensor, y: torch.Tensor):
        x, y = x.to(self.device), y.to(self.device)
        mask_indices = torch.ones_like(y).bool().to(self.device)
        if self.config.use_diff:
            logits, (mask_indices, _) = self.model(x, y)
        else:
            logits = self.model(x, y)
        # keep only the newest position of each sliding window
        y = y[:, -1:].flatten(0, 1).cpu()
        mask_indices = mask_indices.reshape(x.shape)[:, -1:].flatten(0, 1).cpu()
        if self.is_multiclass:
            logits_dict = {}
            for key in self.vocab.feature_list:
                logits_dict[key] = logits[key][:, -1:].flatten(0, 1).cpu()
            return logits_dict, y, mask_indices
        else:
            logits = logits[:, -1:].flatten(0, 1).cpu()
            return logits, y, mask_indices

    def prepare_prompt_and_ground_truth(self, save_dir, num_target_samples, num_target_measures):
        encoding_scheme = self.config.nn_params.encoding_scheme
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        # default to a resolution of 4 when the dataset is unknown
        in_beat_resolution = in_beat_resolution_dict.get(self.config.dataset, 4)
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(
            vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
        for i, (tuneidx, tune_name) in enumerate(self.test_set):
            ground_truth_sample = tuneidx
            try:
                decoder(ground_truth_sample, output_path=str(save_dir / f"{i}_{tune_name}_gt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}_gt.mid")
            prompt = self.model.decoder._prepare_inference(
                start_token=self.model.decoder.net.start_token, manual_seed=0,
                condition=tuneidx, num_target_measures=num_target_measures)
            try:
                decoder(prompt, output_path=str(save_dir / f"{i}_{tune_name}_prompt.mid"))
            except Exception:
                print(f"Error in generating {i}_{tune_name}_prompt.mid")
            if i == num_target_samples:
                break
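
    # Hedged usage sketch (not part of the original class): write ground-truth
    # and prompt MIDI files for the first few test tunes. The output directory
    # name is hypothetical.
    def example_dump_prompts(self, out_dir: str = 'eval_out'):
        save_dir = Path(out_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        self.prepare_prompt_and_ground_truth(save_dir, num_target_samples=4, num_target_measures=8)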

    def generate_samples_with_prompt(self, save_dir, num_target_measures, tuneidx, tune_name,
                                     first_pred_feature, sampling_method=None, threshold=None,
                                     temperature=1.0, generation_length=3072):
        encoding_scheme = self.config.nn_params.encoding_scheme
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        # default to a resolution of 4 when the dataset is unknown
        in_beat_resolution = in_beat_resolution_dict.get(self.config.dataset, 4)
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(
            vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
        tuneidx = tuneidx.cuda()
        generated_sample = self.model.generate(
            0, generation_length, condition=tuneidx, num_target_measures=num_target_measures,
            sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
        prompt = self.model.decoder._prepare_inference(
            self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

    def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature,
                                       sampling_method, threshold, temperature,
                                       generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        # default to a resolution of 4 when the dataset is unknown
        in_beat_resolution = in_beat_resolution_dict.get(self.config.dataset, 4)
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(
            vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
        for i in range(num_samples):
            generated_sample = self.model.generate(
                0, generation_length, condition=None, num_target_measures=None,
                sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            if encoding_scheme == 'nb':
                generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
            decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
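
    # Hedged usage sketch (not part of the original class): draw a few
    # unconditioned samples. sampling_method='top_p' and threshold=0.99 are
    # assumed settings, not values prescribed by this module.
    def example_sample_unconditioned(self, out_dir: str = 'samples', n: int = 2):
        save_dir = Path(out_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        self.generate_samples_unconditioned(
            save_dir, num_samples=n, first_pred_feature=self.first_pred_feature,
            sampling_method='top_p', threshold=0.99, temperature=1.0)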

    def generate_samples_with_text_prompt(self, save_dir, prompt, first_pred_feature,
                                          sampling_method, threshold, temperature,
                                          generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme
        # encode the text prompt with a frozen T5 encoder
        tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
        encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').to(self.device)
        print(f"Using T5EncoderModel for text prompt: {prompt}")
        context = tokenizer(prompt, return_tensors='pt', padding='max_length',
                            truncation=True, max_length=128).to(self.device)
        context = encoder(**context).last_hidden_state
        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        # default to a resolution of 4 when the dataset is unknown
        in_beat_resolution = in_beat_resolution_dict.get(self.config.dataset, 4)
        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(
            vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
        generated_sample = self.model.generate(
            0, generation_length, condition=None, num_target_measures=None,
            sampling_method=sampling_method, threshold=threshold,
            temperature=temperature, context=context)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        # count existing lines in the jsonl to get the next sample index
        jsonl_path = save_dir / "name2prompt.jsonl"
        if jsonl_path.exists():
            with open(jsonl_path, 'r') as f:
                current_idx = sum(1 for _ in f)
        else:
            current_idx = 0
        name = f"prompt_{current_idx}"
        # record the prompt used for this sample
        name2prompt_dict = defaultdict(list)
        name2prompt_dict[name].append(prompt)
        with open(jsonl_path, 'a') as f:
            f.write(json.dumps(name2prompt_dict) + '\n')
        decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
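
# Hedged end-to-end sketch: restore a run, score the test set, and save the
# per-feature results. The run code and output filename are hypothetical; it
# relies only on the functions and helpers defined above.
def _example_evaluate(code: str = 'abc123'):
    config, model, test_set, vocab = _example_restore_run(code)
    evaluator = Evaluator(config, model, test_set, vocab, device='cuda', batch_size=16)
    evaluator.get_perplexity()
    print(evaluator.perplexity_by_class())  # hedged helper defined above
    evaluator.save_results('eval_results.pt')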