import itertools
import os
import pickle
import random
import shutil
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Union

import torch
import wandb
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import Sampler
from tqdm.auto import tqdm

from .model_zoo import AmadeusModel
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.decoding_utils import MidiDecoder4REMI
from .evaluation_utils import add_conti_in_valid
from .train_utils import NLLLoss4REMI
from data_representation.vocab_utils import LangTokenVocab

os.environ['WANDB_INIT_TIMEOUT'] = '600'
os.environ["WANDB_BASE_URL"] = "https://api.bandw.top"


class InfiniteSampler(Sampler):
    """Sampler that cycles over the dataset indices indefinitely."""

    def __init__(self, data_source, shuffle=True):
        self.data_source = data_source
        self.shuffle = shuffle
        self.indices = list(range(len(data_source)))
        if self.shuffle:
            random.shuffle(self.indices)
        self.infinite_iterator = itertools.cycle(self.indices)

    def __iter__(self):
        return self.infinite_iterator

    def __len__(self):
        # The sampler is conceptually infinite; report the underlying dataset
        # size so that len()-based utilities keep working (returning None
        # would raise a TypeError).
        return len(self.data_source)


class LanguageModelTrainer:
    def __init__(
        self,
        model: AmadeusModel,                    # The language model for music generation
        optimizer: torch.optim.Optimizer,       # Optimizer for updating model weights
        scheduler: torch.optim.lr_scheduler._LRScheduler,  # Learning rate scheduler
        loss_fn: NLLLoss4REMI,                  # Loss function to compute the error
        midi_decoder: MidiDecoder4REMI,         # Decoder to convert model output into MIDI format
        train_set: TuneCompiler,                # Training dataset
        valid_set: TuneCompiler,                # Validation dataset
        save_dir: str,                          # Directory to save models and logs
        vocab: LangTokenVocab,                  # Vocabulary for tokenizing sequences
        use_ddp: bool,                          # Whether to use Distributed Data Parallel (DDP)
        use_fp16: bool,                         # Whether to use mixed-precision training (FP16)
        world_size: int,                        # Total number of devices for distributed training
        batch_size: int,                        # Batch size for training
        infer_target_len: int,                  # Target length for inference generation
        gpu_id: int,                            # GPU device ID for computation
        sampling_method: str,                   # Sampling method for sequence generation
        sampling_threshold: float,              # Threshold for sampling decisions
        sampling_temperature: float,            # Temperature for controlling sampling randomness
        config,                                 # Configuration (general, training, and inference settings)
        model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt",  # Path to a pre-trained model checkpoint (optional)
        # model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
    ):
        # Save model, optimizer, and other configurations
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.valid_set = valid_set
        self.vocab = vocab
        self.use_ddp = use_ddp
        self.world_size = world_size
        self.batch_size = batch_size
        self.gpu_id = gpu_id
        self.sampling_method = sampling_method
        self.sampling_threshold = sampling_threshold
        self.sampling_temperature = sampling_temperature
        self.config = config
        self.last_iter = 0
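        # Note on the checkpoint naming convention assumed below: save_model()
        # writes files such as iter42612_loss-8.9870.pt, so taking
        # split('_')[0][4:] of the basename recovers the iteration count, e.g.
        #   "iter42612_loss-8.9870.pt".split('_')[0][4:]  ->  "42612"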
        # Load a pre-trained checkpoint if one is provided
        if model_checkpoint:
            # Parse the iteration count out of the checkpoint filename
            if isinstance(model_checkpoint, str):
                if model_checkpoint.endswith('.pt'):
                    self.last_iter = int(model_checkpoint.split('/')[-1].split('_')[0][4:])
            checkpoint = torch.load(model_checkpoint, map_location='cpu')
            print("Loading model checkpoint from", model_checkpoint)
            if isinstance(self.model, DDP):
                self.model.module.load_state_dict(checkpoint['model'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['model'], strict=False)

        # Training hyperparameters from config
        self.grad_clip = config.train_params.grad_clip
        self.num_cycles_for_inference = config.train_params.num_cycles_for_inference
        self.num_cycles_for_model_checkpoint = config.train_params.num_cycles_for_model_checkpoint
        self.iterations_per_training_cycle = config.train_params.iterations_per_training_cycle
        self.iterations_per_validation_cycle = config.train_params.iterations_per_validation_cycle
        self.make_log = config.general.make_log
        self.num_uncond_generation = config.inference_params.num_uncond_generation
        self.num_cond_generation = config.inference_params.num_cond_generation
        self.num_max_seq_len = infer_target_len
        self.infer_and_log = config.general.infer_and_log
        self.valid_loader = self.generate_data_loader(self.valid_set, shuffle=False, drop_last=True)

        # Gradient accumulation
        self.gradient_accumulation_steps = config.train_params.gradient_accumulation_steps

        # Mixed-precision training flag
        self.use_fp16 = bool(use_fp16)

        # Set up Distributed Data Parallel (DDP) via accelerate if required
        if use_ddp:
            if self.use_fp16:
                self.accelerator = Accelerator(
                    mixed_precision='bf16',
                    step_scheduler_with_optimizer=False,
                    gradient_accumulation_steps=self.gradient_accumulation_steps,
                    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
            else:
                self.accelerator = Accelerator(
                    gradient_accumulation_steps=self.gradient_accumulation_steps,
                    step_scheduler_with_optimizer=False,
                    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
            with self.accelerator.main_process_first():
                self.train_set = train_set
                self.train_loader = self.generate_data_loader(self.train_set, shuffle=False, drop_last=False)
            self.accelerator.wait_for_everyone()
            self.accelerator.print(f"Using {self.world_size} GPUs for training")
            self.model, self.optimizer, self.scheduler, self.train_loader = self.accelerator.prepare(
                self.model, self.optimizer, self.scheduler, self.train_loader
            )
            self.accelerator.wait_for_everyone()
            # self.accelerator.init_trackers("nested_music_transformer", config)
            set_seed(42)
            self.device = self.accelerator.device
            self.model.to(self.device)
            # Set up logging on the main process only
            if self.accelerator.is_main_process:
                save_dir = self.setup_log(config)
                print("save_dir:", save_dir)
                # Create the directory for saving models and logs
                self.save_dir = Path(save_dir)
                self.save_dir.mkdir(exist_ok=True, parents=True)
                self.set_save_out()
        else:
            self.train_set = train_set
            # Create data loaders for the training and validation sets
            self.train_loader = self.generate_data_loader(train_set, shuffle=False, drop_last=True)
            self.valid_loader = self.generate_data_loader(valid_set, shuffle=True, drop_last=True)
            save_dir = self.setup_log(config)
            # Create the directory for saving models and logs
            self.save_dir = Path(save_dir)
            self.save_dir.mkdir(exist_ok=True, parents=True)
            self.set_save_out()
            self.device = config.train_params.device
            self.model.to(self.device)

        # Initialize tracking metrics
        self.best_valid_accuracy = 0
        self.best_valid_loss = 100
        self.training_loss = []
        self.validation_loss = []
        self.validation_acc = []
        self.midi_decoder = midi_decoder

    @property
    def is_main_process(self):
        # In the non-DDP path no Accelerator is created, so guard the access here
        return self.accelerator.is_main_process if self.use_ddp else True
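    # Illustrative example of the experiment name built below (values are made up):
    #   2025-10-25_10-42-02_sod_nb7_selfattn_firstpred:type_inputlen1024_nlayer12_batch8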
    def generate_experiment_name(self, config):
        # Base hyperparameters that go into the experiment name
        dataset_name = config.dataset
        encoding_name = config.nn_params.encoding_scheme
        num_features = config.nn_params.num_features
        input_embedder_name = config.nn_params.input_embedder_name
        sub_decoder_name = config.nn_params.sub_decoder_name
        batch_size = config.train_params.batch_size
        num_layers = config.nn_params.main_decoder.num_layer
        input_length = config.train_params.input_length
        first_pred_feature = config.data_params.first_pred_feature
        # Additional target hyperparameters
        main_dropout = config.nn_params.model_dropout        # dropout
        lr_decay_rate = config.train_params.decay_step_rate  # learning-rate decay
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # Combine the information into a single experiment-name string
        # experiment_name = f"{timestamp}_{dataset_name}_{encoding_name}{num_features}_{input_embedder_name}_{sub_decoder_name}_firstpred:{first_pred_feature}_inputlen{input_length}_nlayer{num_layers}_batch{batch_size}_dropout{main_dropout}_lrdecay{lr_decay_rate}"
        experiment_name = f"{timestamp}_{dataset_name}_{encoding_name}{num_features}_{sub_decoder_name}_firstpred:{first_pred_feature}_inputlen{input_length}_nlayer{num_layers}_batch{batch_size}"
        return experiment_name

    def collate_fn(self, batch):
        """
        Custom collate function for batches of (segment, mask, caption, encoded_caption)
        tuples. Segments and masks are assumed to share a fixed length, so they are
        stacked into tensors; captions and encoded captions are returned as tuples.
        """
        segments, masks, captions, encoded_captions = zip(*batch)
        # # To support variable-length sequences, pad instead of stacking:
        # padded_segments = torch.nn.utils.rnn.pad_sequence(segments, batch_first=True)
        # padded_masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True)
        segments = torch.stack(segments, dim=0)
        masks = torch.stack(masks, dim=0)
        return segments, masks, captions, encoded_captions

    def setup_log(self, config):
        if self.is_main_process:
            if config.general.make_log:
                experiment_name = self.generate_experiment_name(config)
                wandb.init(
                    project="Acce_Music_Transformer",
                    name=experiment_name,
                    config=OmegaConf.to_container(config)
                )
                # Save the config snapshot to the wandb run directory
                config_path = Path(wandb.run.dir) / "config.yaml"
                OmegaConf.save(config, config_path)
                save_dir = Path(wandb.run.dir) / "checkpoints"
                save_dir.mkdir(exist_ok=True, parents=True)
            else:
                now = datetime.now()
                save_dir = Path('wandb/debug/checkpoints') / now.strftime('%y-%m-%d')
                save_dir.mkdir(exist_ok=True, parents=True)
                # Save the config snapshot to the debug directory
                config_path = save_dir / "config.yaml"
                OmegaConf.save(config, config_path)
            return str(save_dir)
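    # Resulting layout (illustrative, inferred from the paths above): with make_log
    # enabled, checkpoints and the config snapshot live under
    # wandb/run-<timestamp>-<id>/files/checkpoints/; otherwise they fall back to
    # wandb/debug/checkpoints/<yy-mm-dd>/.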
    # Set up the output directories for saving MIDI results during inference
    def set_save_out(self):
        if self.is_main_process:
            # Copy files from the latest folder in wandb/debug/checkpoints
            target_folder = 'wandb/debug/checkpoints'
            latest_folder = sorted(Path(target_folder).iterdir(), key=os.path.getmtime)[-1]
            files = [f for f in latest_folder.iterdir() if f.is_file()]
            for file in files:
                target_file = self.save_dir / file.name
                if not target_file.exists():
                    shutil.copy2(file, target_file)
            if self.infer_and_log:
                self.valid_out_dir = self.save_dir / 'valid_out'
                os.makedirs(self.valid_out_dir, exist_ok=True)

    # Save the current model and optimizer state
    def save_model(self, path):
        if isinstance(self.model, DDP):
            torch.save({'model': self.model.module.state_dict(), 'optim': self.optimizer.state_dict()}, path)
        else:
            torch.save({'model': self.model.state_dict(), 'optim': self.optimizer.state_dict()}, path)

    # Build the data loader for either the training or the validation dataset
    def generate_data_loader(self, dataset, shuffle=False, drop_last=False) -> DataLoader:
        return DataLoader(dataset,
                          shuffle=shuffle,
                          batch_size=self.batch_size,
                          drop_last=drop_last,
                          collate_fn=None,
                          pin_memory=True,
                          num_workers=4,
                          persistent_workers=True,
                          prefetch_factor=2,
                          worker_init_fn=None)

    # Train for a given number of iterations
    def accelerate_train_by_num_iter(self, num_iters):
        pbar = tqdm(total=num_iters, desc='Training', unit='iteration', leave=False)
        completed_steps = self.last_iter
        while completed_steps < num_iters:
            total_loss = 0
            current_loss = 0
            for i, batch in enumerate(self.train_loader):
                # Gradient accumulation
                with self.accelerator.accumulate(self.model):
                    # Start time for the training step
                    start_time = time.time()
                    # Train the model on a single batch
                    loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
                    total_loss += loss.detach().float()
                    current_loss = loss.detach().float()
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.unscale_gradients(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                        # self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                    if self.scheduler is not None and not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.scheduler.step()
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.accelerator.sync_gradients:
                        # Update the progress bar and bookkeeping once per optimizer step
                        loss_value = loss.item()
                        completed_steps += 1
                        loss_dict['time'] = time.time() - start_time
                        loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
                        loss_dict = self._rename_dict(loss_dict, 'train')
                        self.training_loss.append(loss_value)
                        if self.accelerator.is_main_process:
                            pbar.update(1)
                            pbar.set_postfix(loss=loss_value, lr=self.optimizer.param_groups[0]['lr'])
                        # Save an iter-1 checkpoint as a sanity check
                        if completed_steps == 1 and self.accelerator.is_main_process:
                            self.save_model(self.save_dir / f'iter{completed_steps}_loss{current_loss:.4f}.pt')
                        # Log the training loss at the specified training cycle
                        if (completed_steps + 1) % self.iterations_per_training_cycle == 0 and self.make_log and self.accelerator.is_main_process:
                            wandb.log(loss_dict, step=completed_steps)
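                        # Cadence example (illustrative): with iterations_per_training_cycle=100,
                        # the loss is logged whenever completed_steps + 1 is a multiple of 100,
                        # while the heavier train-accuracy pass below runs every 300 steps
                        # (three times the cycle length).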
                        # Log training accuracy periodically
                        if (completed_steps + 1) % (self.iterations_per_training_cycle * 3) == 0 and self.make_log:
                            validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch, train=True)
                            train_metric_dict = self._get_train_accuracy(num_nonmask_tokens, num_tokens_by_feature, correct_guess_by_feature)
                            train_metric_dict.update(loss_dict)
                            train_metric_dict = self._rename_dict(train_metric_dict, 'train')
                            if self.accelerator.is_main_process:
                                wandb.log(train_metric_dict, step=completed_steps)
                            # Delete variables to avoid memory leaks
                            del validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature, train_metric_dict
                        # Perform validation at the specified interval
                        if (completed_steps + 1) % self.iterations_per_validation_cycle == 0:
                            self.model.eval()
                            validation_loss, validation_acc, validation_metric_dict = self.validate()
                            validation_metric_dict['acc'] = validation_acc
                            validation_metric_dict = self._rename_dict(validation_metric_dict, 'valid')
                            if self.make_log and self.accelerator.is_main_process:
                                wandb.log(validation_metric_dict, step=completed_steps)
                            self.validation_loss.append(validation_loss)
                            self.validation_acc.append(validation_acc)
                            self.best_valid_loss = min(validation_loss, self.best_valid_loss)
                            # Perform inference and logging after a certain number of cycles
                            if (completed_steps + 1) % (self.num_cycles_for_inference * self.iterations_per_validation_cycle) == 0 and self.infer_and_log and self.accelerator.is_main_process:
                                self.inference_and_log(i, self.num_uncond_generation, self.num_cond_generation, self.num_max_seq_len)
                            # Save a model checkpoint periodically
                            if (completed_steps + 1) % (self.iterations_per_validation_cycle * self.num_cycles_for_model_checkpoint) == 0 and self.accelerator.is_main_process:
                                self.accelerator.print(f"Saving model checkpoint at iter {completed_steps}")
                                self.save_model(self.save_dir / f'iter{completed_steps}_loss{validation_loss:.4f}.pt')
                            self.model.train()
                            # Delete variables to avoid memory leaks
                            del validation_acc, validation_metric_dict
            # Save a model checkpoint at the end of each epoch
            if self.accelerator.is_main_process:
                print(f"Saving model checkpoint at iter {completed_steps}")
                self.save_model(self.save_dir / f"iter{completed_steps}_loss{current_loss:.4f}.pt")
        # Save the final model after training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            print("Saving last checkpoint")
            self.save_model(self.save_dir / 'checkpoint_last.pt')

    # Same as _get_loss_pred_from_single_batch below, but applying torch.cuda.amp
    # autocast directly instead of the accelerator's context manager
    def _accelerate_get_loss_pred_from_single_batch(self, batch):
        """
        Computes the loss and predictions for a single batch of data.

        Args:
            batch: A batch of data containing input sequences, targets, masks, and captions.

        Returns:
            loss: The computed loss for the batch.
            logits: The raw model predictions (logits).
            loss_dict: A dictionary containing the total loss.

        The method:
        - Separates the input sequences and target sequences from the batch.
        - Moves the data to the appropriate device.
        - Applies mixed precision (FP16) if applicable.
        - Computes the logits using the model and calculates the loss using the specified loss function.
        """
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with torch.cuda.amp.autocast():
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, None)
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict
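    # Shape of the dict returned alongside the loss (illustrative values):
    #   {'total': 3.1415, 'time': 0.042, 'lr': 1e-4}
    # 'time' and 'lr' are appended by the training-step wrappers.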
    def _train_by_single_batch(self, batch):
        """
        Trains the model on a single batch of data.

        Args:
            batch: A batch of data, typically consisting of input sequences and corresponding targets.

        Returns:
            loss.item(): The total loss for this batch.
            loss_dict: A dictionary containing information about the loss and other relevant metrics.

        The method:
        - Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
        - Resets the optimizer's gradients.
        - Depending on whether mixed precision (FP16) is used, it scales the loss and applies
          gradient clipping before stepping the optimizer.
        - Updates the learning rate scheduler if applicable.
        - Records the time taken for the training step and the current learning rate in the `loss_dict`.
        """
        start_time = time.time()
        loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
        self.optimizer.zero_grad()
        if self.use_fp16:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()
        if self.scheduler is not None and not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step()
        loss_dict['time'] = time.time() - start_time
        loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
        return loss.item(), loss_dict

    def _get_loss_pred_from_single_batch(self, batch):
        """
        Computes the loss and predictions for a single batch of data.

        Args:
            batch: A batch of data containing input sequences, targets, masks, and captions.

        Returns:
            loss: The computed loss for the batch.
            logits: The raw model predictions (logits).
            loss_dict: A dictionary containing the total loss.

        The method:
        - Separates the input sequences and target sequences from the batch.
        - Moves the data to the appropriate device.
        - Applies mixed precision (FP16) if applicable.
        - Computes the logits using the model and calculates the loss using the specified loss function.
        """
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with self.accelerator.autocast():
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, None)
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict

    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        """
        Computes validation loss and accuracy from a single batch.

        Args:
            batch: A batch of data containing input sequences, targets, and masks.
            train (bool): Indicates whether the function is being used in training mode.

        Returns:
            validation_loss: Total validation loss for the batch.
            num_tokens: The number of valid tokens in the batch.
            loss_dict: A dictionary containing the loss and relevant metrics.
            None: Placeholder kept for signature parity with the subclasses.
            num_correct_guess: Number of correctly predicted tokens.

        The method:
        - Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
        - Computes token-level accuracy by comparing predicted tokens with the targets.
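
        Example (illustrative): if the mask marks 10 valid tokens and the argmax
        predictions match the targets at 7 of them, num_correct_guess is 7 and
        validation_loss is the batch loss multiplied by 10.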
""" segment, mask, caption,encoded_caption = batch input_seq, target = segment[:, :-1], segment[:, 1:] loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch) prob = torch.softmax(logits, dim=-1) num_tokens = torch.sum(mask) target = target.to(self.device) mask = mask[:, :-1].to(self.device) selected_tokens = torch.argmax(prob, dim=-1) * mask shifted_tgt_with_mask = target * mask num_correct_guess = torch.sum(selected_tokens == shifted_tgt_with_mask) - torch.sum(mask == 0) validation_loss = loss.item() * num_tokens num_correct_guess = num_correct_guess.item() return validation_loss, num_tokens, loss_dict, None, num_correct_guess def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess): """ Computes training accuracy. Args: num_tokens: Total number of tokens processed. num_tokens_by_feature: Number of tokens for each feature (not used here). num_correct_guess: Number of correctly predicted tokens. Returns: Training accuracy, computed as the ratio of correct predictions to the total number of tokens. """ return num_correct_guess / num_tokens def validate(self, external_loader=None): """ Validates the model on a dataset. Args: external_loader (DataLoader): If provided, an external DataLoader can be used for validation. Returns: total_validation_loss: Average validation loss over all batches. total_num_correct_guess: Total number of correct predictions divided by the number of tokens (accuracy). validation_metric_dict: Dictionary of validation metrics averaged over all batches. The method: - Iterates through the validation data loader, calculating the loss and accuracy for each batch. - Aggregates the results over all batches and returns the overall validation metrics. """ if external_loader and isinstance(external_loader, DataLoader): loader = external_loader print('An arbitrary loader is used instead of Validation loader') else: loader = self.valid_loader self.model.eval() total_validation_loss = 0 total_num_correct_guess = 0 total_num_tokens = 0 validation_metric_dict = defaultdict(float) with torch.inference_mode(): for batch in tqdm(loader, leave=False): validation_loss, num_tokens, loss_dict, _, num_correct_guess = self._get_valid_loss_and_acc_from_batch(batch) total_validation_loss += validation_loss total_num_tokens += num_tokens total_num_correct_guess += num_correct_guess for key, value in loss_dict.items(): validation_metric_dict[key] += value * num_tokens for key in validation_metric_dict.keys(): validation_metric_dict[key] /= total_num_tokens return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_tokens, validation_metric_dict def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None): """ Generates a MIDI file and logs output from the generated sequence. Args: generated_output: The sequence of notes generated by the model. iter: The current iteration of the training process. seed: The seed used for generating the sequence. condition: Optional condition input for generating conditional output. The method: - Converts the generated output into a MIDI file and logs it. - Optionally logs additional error metrics and figures for analysis. 
""" if condition is not None: path_addition = "cond_" else: path_addition = "" with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f: pickle.dump(generated_output, f) self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid") if self.make_log: log_dict = {} log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png')) log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3')) wandb.log(log_dict, step=(iter+seed)) print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}") return generated_output @torch.inference_mode() def inference_and_log(self, iter, num_uncond_generation=5, num_cond_generation=5, max_seq_len=10000): """ Generates and logs both unconditional and conditional output sequences. Args: iter: The current iteration. num_uncond_generation: Number of unconditional sequences to generate. num_cond_generation: Number of conditional sequences to generate. max_seq_len: Maximum sequence length to generate. The method: - Generates unconditional and conditional sequences using the model's generation function. - Converts the sequences into MIDI files and logs the generated results. """ self.model.eval() for i in range(num_uncond_generation): try: start_time = time.time() uncond_generated_output = self.model.module.generate(manual_seed=i, max_seq_len=max_seq_len, condition=None, \ sampling_method=self.sampling_method, threshold=self.sampling_threshold, temperature=self.sampling_temperature) if len(uncond_generated_output) == 0: continue print(f"unconditional generation time_{iter}: {time.time() - start_time:.4f}") print(f"unconditional length of generated_output: {uncond_generated_output.shape[1]}") self._make_midi_from_generated_output(uncond_generated_output, iter, i, None) except Exception as e: print(e) condition_list = [x[1] for x in self.valid_set.data_list[:num_cond_generation] ] for i in range(num_cond_generation): condition = self.valid_set.get_segments_with_tune_idx(condition_list[i], 0)[0] try: start_time = time.time() generated_output = self.model.module.generate(manual_seed=i, max_seq_len=max_seq_len, condition=condition, \ sampling_method=self.sampling_method, threshold=self.sampling_threshold, temperature=self.sampling_temperature) if len(generated_output) == 0: continue print(f"conditional generation time_{iter}: {time.time() - start_time:.4f}") print(f"conditional length of generated_output: {generated_output.shape[1]}") self._make_midi_from_generated_output(generated_output, iter+num_uncond_generation, i, condition) except Exception as e: print(e) def _rename_dict(self, adict, prefix='train'): ''' Renames the keys in a dictionary by adding a prefix. 
        '''
        keys = list(adict.keys())
        for key in keys:
            adict[f'{prefix}.{key}'] = adict.pop(key)
        return dict(adict)


class LanguageModelTrainer4REMI(LanguageModelTrainer):
    def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set,
                 save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len,
                 gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
        super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set,
                         save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len,
                         gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)

    def _get_loss_pred_from_single_batch(self, batch, valid=False):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with self.accelerator.autocast():
                logits = self.model(input_seq, target)
                if not valid:
                    total_loss, loss_dict = self.loss_fn(logits, target, mask, None)
                    return total_loss, logits, {'total': total_loss.item()}
                else:
                    total_loss, loss_dict = self.loss_fn(logits, target, mask, self.vocab)
                    loss_dict['total'] = total_loss.item()
                    return total_loss, logits, loss_dict
        else:
            logits = self.model(input_seq, target)
            if not valid:
                total_loss, loss_dict = self.loss_fn(logits, target, mask, None)
                return total_loss, logits, {'total': total_loss.item()}
            else:
                total_loss, loss_dict = self.loss_fn(logits, target, mask, self.vocab)
                loss_dict['total'] = total_loss.item()
                return total_loss, logits, loss_dict

    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        segment, mask, caption, encoded_caption = batch
        mask = mask[:, :-1]
        _, target = segment[:, :-1], segment[:, 1:]
        loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        prob = torch.softmax(logits, dim=-1)
        num_nonmask_tokens = torch.sum(mask)
        target = target.to(self.device)  # [b, t]
        mask = mask.to(self.device)
        prob_with_mask = torch.argmax(prob, dim=-1) * mask  # [b, t]
        shifted_tgt_with_mask = target * mask  # [b, t]
        correct_guess_by_feature = defaultdict(int)
        num_tokens_by_feature = defaultdict(int)
        tokens_idx = prob_with_mask.flatten(0, 1)  # [b*t]
        answers_idx = shifted_tgt_with_mask.flatten(0, 1)  # [b*t]
        if self.vocab.encoding_scheme == 'remi':
            eos_idx = 2
        for feature in self.vocab.feature_list:
            feature_mask = self.vocab.total_mask[feature].to(self.device)  # [vocab_size]
            mask_for_target = feature_mask[answers_idx]  # [b*t]
            if feature == 'type':
                # The Bar token is 0, so add 1 before comparing to keep masked zeros
                # from counting as matches
                valid_pred = (tokens_idx + 1) * mask_for_target
                valid_answers = (answers_idx + 1) * mask_for_target
                eos_mask = valid_answers != eos_idx  # EOS also acts as padding
                correct_guess_by_feature[feature] += torch.sum(valid_pred[eos_mask] == valid_answers[eos_mask]).item() - torch.sum(mask_for_target[eos_mask] == 0).item()
                num_tokens_by_feature[feature] += torch.sum(mask_for_target[eos_mask]).item()
            else:
                valid_pred = tokens_idx * mask_for_target  # [b*t]
                valid_answers = answers_idx * mask_for_target  # [b*t]
                correct_guess_by_feature[feature] += torch.sum(valid_pred == valid_answers).item() - torch.sum(mask_for_target == 0).item()
                num_tokens_by_feature[feature] += torch.sum(mask_for_target).item()
        validation_loss = loss.item() * num_nonmask_tokens.item()
        return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature
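    # Per-feature accuracy dict produced below (illustrative keys for a REMI vocab):
    #   {'pitch_acc': 0.91, 'duration_acc': 0.88, ..., 'total_acc': 0.90}
    # The 'type' feature is deliberately excluded from the total.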
    def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
        total_num_correct_guess = 0
        total_num_tokens = 0
        acc_dict = {}
        for feature, num_correct_guess in num_correct_guess_by_feature.items():
            if feature == 'type':
                continue
            total_num_correct_guess += num_correct_guess
            total_num_tokens += num_tokens_by_feature[feature]
            if num_tokens_by_feature[feature] == 0:
                continue
            acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / total_num_tokens
        acc_dict['total_acc'] = total_accuracy
        return acc_dict

    def validate(self, external_loader=None):
        '''
        total_num_tokens: used for the loss average, counts non-masked tokens
        total_num_valid_tokens: used for accuracy, counts valid tokens
        '''
        if external_loader and isinstance(external_loader, DataLoader):
            loader = external_loader
            print('An arbitrary loader is used instead of the validation loader')
        else:
            loader = self.valid_loader
        self.model.eval()
        total_validation_loss = 0
        total_num_tokens = 0
        total_num_valid_tokens = 0
        total_num_correct_guess = 0
        validation_metric_dict = defaultdict(float)
        total_num_tokens_by_feature = defaultdict(int)
        total_num_correct_guess_dict = defaultdict(int)
        with torch.inference_mode():
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                if num_iter == len(self.valid_loader) and loader is not self.valid_loader:
                    # Stop early when validating with the (longer) train loader
                    break
                validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
                total_validation_loss += validation_loss
                total_num_tokens += num_nonmask_tokens.item()
                for key, num_tokens in num_tokens_by_feature.items():
                    total_num_tokens_by_feature[key] += num_tokens
                    if key == 'type':
                        continue
                    # Token counts are shared across musical features: torch.sum(mask)
                    total_num_valid_tokens += num_tokens
                for key, num_correct_guess in num_correct_guess_by_feature.items():
                    total_num_correct_guess_dict[key] += num_correct_guess
                    if key == 'type':
                        continue
                    total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    if key == 'total':
                        validation_metric_dict[key] += value * num_nonmask_tokens
                    else:
                        feature_name = key.split('_')[0]
                        validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]
            for key in validation_metric_dict.keys():
                if key == 'total':
                    validation_metric_dict[key] /= total_num_tokens
                else:
                    feature_name = key.split('_')[0]
                    if total_num_tokens_by_feature[feature_name] == 0:
                        continue
                    validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]
            for key in total_num_tokens_by_feature.keys():
                num_tokens = total_num_tokens_by_feature[key]
                num_correct = total_num_correct_guess_dict[key]
                if num_tokens == 0:
                    continue
                validation_metric_dict[f'{key}_acc'] = num_correct / num_tokens
        return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_valid_tokens, validation_metric_dict


class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
    '''
    About ignore_token and conti_token:

    During validation, tokens with the "conti" value are ignored when calculating
    accuracy and related metrics, ensuring that repeated values do not unfairly
    skew the results. This is especially relevant for features like beat, chord,
    tempo, and instrument, where repeated tokens carry a specific musical meaning.
    We use ignore_token and conti_token to fairly compare compound-token-based
    encodings with the REMI encoding.
    '''

    def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set,
                 save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len,
                 gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
        super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set,
                         save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len,
                         gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)
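    # Worked example for the helper below (illustrative, all-ones mask assumed):
    # with ignore_token=0 and conti_token=9999, ground truth [5, 0, 9999, 7] and
    # predictions [5, 0, 9999, 3] give 2 valid tokens (positions 0 and 3) and
    # 1 correct guess (position 0 only).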
"conti" value are ignored when calculating accuracy or other metrics, ensuring that repeated values don't unfairly skew the results. This is especially relevant for features like beat, chord, tempo, and instrument where repeated tokens may have a specific musical meaning. We used ignore_token and conti_token to fairly compare compound token based encoding with REMI encoding. ''' def _get_num_valid_and_correct_tokens(self, prob, ground_truth, mask, ignore_token=None, conti_token=None): valid_prob = torch.argmax(prob, dim=-1) * mask valid_ground_truth = ground_truth * mask if ignore_token is None and conti_token is None: num_valid_tokens = torch.sum(mask) num_correct_tokens = torch.sum(valid_prob == valid_ground_truth) - torch.sum(mask == 0) elif ignore_token is not None and conti_token is None: ignore_mask = valid_ground_truth != ignore_token # batch x seq_len num_valid_tokens = torch.sum(ignore_mask) num_correct_tokens = torch.sum(valid_prob[ignore_mask] == valid_ground_truth[ignore_mask]) # by using mask, the tensor becomes 1d elif ignore_token is not None and conti_token is not None: ignore_conti_mask = (valid_ground_truth != ignore_token) & (valid_ground_truth != conti_token) num_valid_tokens = torch.sum(ignore_conti_mask) num_correct_tokens = torch.sum(valid_prob[ignore_conti_mask] == valid_ground_truth[ignore_conti_mask]) return num_correct_tokens.item(), num_valid_tokens.item() def _get_loss_pred_from_single_batch(self, batch, valid=False): # print(batch) segment, mask, caption,encoded_caption = batch input_seq, target = segment[:, :-1], segment[:, 1:] input_seq = input_seq.to(self.device) target = target.to(self.device) mask = mask[:, :-1].to(self.device) encoded_caption = encoded_caption.to(self.device) if self.use_fp16: if self.config.use_diff is True: with self.accelerator.autocast(): # breakpoint() (logits_dict, (masked_indices, p_mask)),input_dict = self.model(input_seq, target,context=encoded_caption) if self.config.use_dispLoss == True: total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid, input_dict=input_dict,lambda_weight=self.config.lambda_weight,tau=self.config.tau) else: total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid) else: with self.accelerator.autocast(): logits_dict,_ = self.model(input_seq, target,context=encoded_caption) total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, valid) else: if self.config.use_diff is True: # breakpoint() if self.config.use_dispLoss == True: total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid, input_dict=input_dict,lambda_weight=self.config.lambda_weight) else: total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid) else: logits_dict, input_Dict = self.model(input_seq, target,context=encoded_caption) total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, valid) if valid: loss_dict['total'] = total_loss.item() else: loss_dict = {'total':total_loss.item()} return total_loss, logits_dict, loss_dict def _get_valid_loss_and_acc_from_batch(self, batch, train=False): ''' in this method, valid means handled with both ignore token and mask when valid tokens with only mask, it is called num_nonmask_tokens input_seq, target: batch x seq_len x num_features mask: batch x seq_len, 0 for padding prob: batch x seq_len x total_vocab_size ''' segment, mask, caption,encoded_caption = batch input_seq, target = segment[:, :-1], segment[:, 1:] total_loss, logits_dict, 
    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        '''
        Here "valid" means tokens that survive both the ignore-token handling and
        the padding mask; tokens filtered by the mask alone are counted as
        num_nonmask_tokens.

        input_seq, target: batch x seq_len x num_features
        mask: batch x seq_len, 0 for padding
        prob: batch x seq_len x vocab_size per feature
        '''
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
        input_seq = input_seq.to(self.device)
        target = add_conti_in_valid(target, self.config.nn_params.encoding_scheme).to(self.device)
        # Slice the mask to match the shifted target before counting tokens
        mask = mask[:, :-1].to(self.device)
        num_nonmask_tokens = torch.sum(mask)
        correct_guess_by_feature = defaultdict(int)
        num_tokens_by_feature = defaultdict(int)
        for idx, key in enumerate(self.vocab.feature_list):
            if key == 'type' or key == 'timesig':
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
            elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
            elif key == 'beat':
                # NB's beat vocab has Ignore and CONTI tokens.
                # CP's beat vocab has Ignore and BAR tokens; we exclude the BAR token
                # from the accuracy calculation for parity with NB.
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
            else:
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=None)
            correct_guess_by_feature[key] = num_correct_tokens
            num_tokens_by_feature[key] = num_valid_tokens
        validation_loss = total_loss.item() * num_nonmask_tokens.item()
        return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature

    def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
        total_num_correct_guess = 0
        total_num_tokens = 0
        acc_dict = {}
        for feature, num_correct_guess in num_correct_guess_by_feature.items():
            if feature == 'type':
                continue
            total_num_correct_guess += num_correct_guess
            total_num_tokens += num_tokens_by_feature[feature]
            if num_tokens_by_feature[feature] == 0:
                continue
            acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / total_num_tokens
        acc_dict['total_acc'] = total_accuracy
        return acc_dict

    def validate(self, external_loader=None):
        if external_loader and isinstance(external_loader, DataLoader):
            loader = external_loader
            print('An arbitrary loader is used instead of the validation loader')
        else:
            loader = self.valid_loader
        self.model.eval()
        total_validation_loss = 0
        total_num_correct_guess = 0
        total_num_tokens = 0
        total_num_valid_tokens = 0
        validation_metric_dict = defaultdict(float)
        total_num_tokens_by_feature = defaultdict(int)
        total_num_correct_guess_dict = defaultdict(int)
        with torch.inference_mode():
            '''
            mask is used to calculate loss and accuracy:
            validation_loss: sum of loss over tokens selected by the mask
            num_nonmask_tokens: number of tokens selected by the mask
            num_tokens_by_feature: number of valid tokens (ignore handled) per musical feature
            num_correct_guess_by_feature: number of correct tokens (ignore handled) per musical feature
            '''
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                if num_iter == len(self.valid_loader) and loader is not self.valid_loader:
                    # Stop early when validating with the (longer) train loader
                    break
                validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
                total_validation_loss += validation_loss
                total_num_tokens += num_nonmask_tokens
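                # Aggregation example (illustrative): a feature loss logged as
                # 'pitch_loss' is weighted by that feature's valid-token count here
                # and divided by the accumulated count afterwards, so the final
                # value is a token-weighted average rather than a plain mean over batches.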
                for key, num_tokens in num_tokens_by_feature.items():
                    total_num_tokens_by_feature[key] += num_tokens
                    if key == 'type':
                        # cp and nb have different numbers of type tokens, so the type
                        # feature is excluded from the accuracy totals
                        continue
                    # Token counts are shared across musical features: torch.sum(mask)
                    total_num_valid_tokens += num_tokens
                for key, num_correct_guess in num_correct_guess_by_feature.items():
                    total_num_correct_guess_dict[key] += num_correct_guess
                    if key == 'type':
                        continue
                    total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    if key == 'total':
                        validation_metric_dict[key] += value * num_nonmask_tokens
                    else:
                        feature_name = key.split('_')[0]
                        validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]
            for key in validation_metric_dict.keys():
                if key == 'total':
                    validation_metric_dict[key] /= total_num_tokens
                else:
                    feature_name = key.split('_')[0]
                    if total_num_tokens_by_feature[feature_name] == 0:
                        continue
                    validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]
            for key in total_num_tokens_by_feature.keys():
                num_tokens = total_num_tokens_by_feature[key]
                num_correct = total_num_correct_guess_dict[key]
                if num_tokens == 0:
                    continue
                validation_metric_dict[f'{key}_acc'] = num_correct / num_tokens
        # The +1 in the denominators guards against empty loaders
        return total_validation_loss / (total_num_tokens + 1), total_num_correct_guess / (1 + total_num_valid_tokens), validation_metric_dict

    def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None):
        if self.config.data_params.first_pred_feature != 'type' and self.config.nn_params.encoding_scheme == 'nb':
            generated_output = reverse_shift_and_pad_for_tensor(generated_output, self.config.data_params.first_pred_feature)
        path_addition = "cond_" if condition is not None else ""
        # Save the generated output as a pickle
        with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f:
            pickle.dump(generated_output, f)
        self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid")
        if self.make_log and self.infer_and_log:
            log_dict = {}
            log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png'))
            log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3'))
            wandb.log(log_dict, step=(iter + seed))
            print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}")
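# Illustrative wiring (a sketch only; the argument values are made up and the real
# entry point builds these objects from the OmegaConf config):
#
#   trainer = LanguageModelTrainer4CompoundToken(
#       model, optimizer, scheduler, loss_fn, midi_decoder,
#       train_set, valid_set, save_dir, vocab,
#       use_ddp=True, use_fp16=True, world_size=4, batch_size=16,
#       infer_target_len=1024, gpu_id=0,
#       sampling_method='top_p', sampling_threshold=0.9,
#       sampling_temperature=1.0, config=config)
#   trainer.accelerate_train_by_num_iter(num_iters=100_000)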