import time
import pickle
import os
from pathlib import Path
from typing import Union
from datetime import datetime
from omegaconf import OmegaConf
import random
import itertools

import torch
import torchaudio
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler, Sampler

# import accelerate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
#======================================================================

import wandb
from collections import defaultdict
from tqdm.auto import tqdm

from .model_zoo import AmadeusModel
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.decoding_utils import MidiDecoder4REMI
from .evaluation_utils import add_conti_in_valid
from .train_utils import NLLLoss4REMI

os.environ['WANDB_INIT_TIMEOUT'] = '600'
os.environ["WANDB_BASE_URL"] = "https://api.bandw.top"

from data_representation.vocab_utils import LangTokenVocab

class InfiniteSampler(Sampler):
    def __init__(self, data_source, shuffle=True):
        self.data_source = data_source
        self.shuffle = shuffle
        self.indices = list(range(len(data_source)))
        if self.shuffle:
            random.shuffle(self.indices)
        self.infinite_iterator = itertools.cycle(self.indices)

    def __iter__(self):
        return self.infinite_iterator

    def __len__(self):
        return None  # indicates infinite length
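
# Usage sketch (illustrative; the sampler is not wired into the loaders below):
#   sampler = InfiniteSampler(train_set, shuffle=True)
#   loader = DataLoader(train_set, batch_size=8, sampler=sampler)
#   it = iter(loader)
#   batch = next(it)  # can be called indefinitely; no StopIteration at epoch end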


class LanguageModelTrainer:
    def __init__(
        self,
        model: AmadeusModel,  # The language model for music generation
        optimizer: torch.optim.Optimizer,  # Optimizer for updating model weights
        scheduler: torch.optim.lr_scheduler._LRScheduler,  # Learning rate scheduler
        loss_fn: NLLLoss4REMI,  # Loss function to compute the error
        midi_decoder: MidiDecoder4REMI,  # Decoder to convert model output into MIDI format
        train_set: TuneCompiler,  # Training dataset
        valid_set: TuneCompiler,  # Validation dataset
        save_dir: str,  # Directory to save models and logs
        vocab: LangTokenVocab,  # Vocabulary for tokenizing sequences
        use_ddp: bool,  # Whether to use Distributed Data Parallel (DDP)
        use_fp16: bool,  # Whether to use mixed-precision training (FP16)
        world_size: int,  # Total number of devices for distributed training
        batch_size: int,  # Batch size for training
        infer_target_len: int,  # Target length for inference generation
        gpu_id: int,  # GPU device ID for computation
        sampling_method: str,  # Sampling method for sequence generation
        sampling_threshold: float,  # Threshold for sampling decisions
        sampling_temperature: float,  # Temperature for controlling sampling randomness
        config,  # Configuration parameters (general, training, and inference settings)
        model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt",  # Path to a pre-trained model checkpoint (optional)
        # model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
    ):
        # Save model, optimizer, and other configurations
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn

        self.valid_set = valid_set
        self.vocab = vocab
        self.use_ddp = use_ddp
        self.world_size = world_size
        self.batch_size = batch_size
        self.gpu_id = gpu_id
        self.sampling_method = sampling_method
        self.sampling_threshold = sampling_threshold
        self.sampling_temperature = sampling_temperature
        self.config = config
        self.last_iter = 0

        # Load pre-trained model if provided
        if model_checkpoint:
            # Parse the iteration count from the checkpoint filename,
            # e.g. "iter42612_loss-8.9870.pt" -> 42612
            if isinstance(model_checkpoint, str):
                if model_checkpoint.endswith('.pt'):
                    self.last_iter = int(model_checkpoint.split('/')[-1].split('_')[0][4:])
            checkpoint = torch.load(model_checkpoint, map_location='cpu')
            print("Loading model checkpoint from", model_checkpoint)
            if isinstance(self.model, DDP):
                self.model.module.load_state_dict(checkpoint['model'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['model'], strict=False)
        # Training hyperparameters from config
        self.grad_clip = config.train_params.grad_clip
        self.num_cycles_for_inference = config.train_params.num_cycles_for_inference
        self.num_cycles_for_model_checkpoint = config.train_params.num_cycles_for_model_checkpoint
        self.iterations_per_training_cycle = config.train_params.iterations_per_training_cycle
        self.iterations_per_validation_cycle = config.train_params.iterations_per_validation_cycle
        self.make_log = config.general.make_log
        self.num_uncond_generation = config.inference_params.num_uncond_generation
        self.num_cond_generation = config.inference_params.num_cond_generation
        self.num_max_seq_len = infer_target_len
        self.infer_and_log = config.general.infer_and_log
        self.valid_loader = self.generate_data_loader(self.valid_set, shuffle=False, drop_last=True)

        # Gradient accumulation
        self.gradient_accumulation_steps = config.train_params.gradient_accumulation_steps
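        # With gradient accumulation under DDP, the effective batch size per
        # optimizer step is batch_size * world_size * gradient_accumulation_steps
        # (e.g., batch_size=4 on 2 GPUs with 8 accumulation steps -> 64 sequences;
        # numbers are illustrative).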

        # Set up mixed-precision training (FP16/BF16) if enabled
        self.use_fp16 = bool(use_fp16)

        # Set up Distributed Data Parallel (DDP) if required
        if use_ddp:
            # Prepare with the accelerator
            if self.use_fp16:
                self.accelerator = Accelerator(
                    mixed_precision='bf16',
                    step_scheduler_with_optimizer=False,
                    gradient_accumulation_steps=self.gradient_accumulation_steps,
                    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
            else:
                self.accelerator = Accelerator(
                    gradient_accumulation_steps=self.gradient_accumulation_steps,
                    step_scheduler_with_optimizer=False,
                    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])

            with self.accelerator.main_process_first():
                self.train_set = train_set
                self.train_loader = self.generate_data_loader(self.train_set, shuffle=False, drop_last=False)
            self.accelerator.wait_for_everyone()
            self.accelerator.print(f"Using {self.world_size} GPUs for training")

            self.model, self.optimizer, self.scheduler, self.train_loader = self.accelerator.prepare(
                self.model, self.optimizer, self.scheduler, self.train_loader
            )
            self.accelerator.wait_for_everyone()
            # self.accelerator.init_trackers("nested_music_transformer", config)
            set_seed(42)
            self.device = self.accelerator.device
            self.model.to(self.device)
            # Set up logging
            if self.accelerator.is_main_process:
                save_dir = self.setup_log(config)
                print("save_dir:", save_dir)
                # Create directory for saving models and logs
                self.save_dir = Path(save_dir)
                self.save_dir.mkdir(exist_ok=True, parents=True)
                self.set_save_out()
        else:
            self.train_set = train_set
            # Create data loaders for training and validation sets
            self.train_loader = self.generate_data_loader(train_set, shuffle=False, drop_last=True)
            self.valid_loader = self.generate_data_loader(valid_set, shuffle=True, drop_last=True)
            save_dir = self.setup_log(config)
            # Create directory for saving models and logs
            self.save_dir = Path(save_dir)
            self.save_dir.mkdir(exist_ok=True, parents=True)
            self.set_save_out()

            self.device = config.train_params.device
            self.model.to(self.device)

        # Initialize tracking metrics
        self.best_valid_accuracy = 0
        self.best_valid_loss = 100
        self.training_loss = []
        self.validation_loss = []
        self.validation_acc = []

        self.midi_decoder = midi_decoder

    def generate_experiment_name(self, config):
        # Base hyperparameters included in the experiment name
        dataset_name = config.dataset
        encoding_name = config.nn_params.encoding_scheme
        num_features = config.nn_params.num_features
        input_embedder_name = config.nn_params.input_embedder_name
        sub_decoder_name = config.nn_params.sub_decoder_name
        batch_size = config.train_params.batch_size
        num_layers = config.nn_params.main_decoder.num_layer
        input_length = config.train_params.input_length
        first_pred_feature = config.data_params.first_pred_feature

        # Target hyperparameters (only used in the commented-out name below)
        # dropout
        main_dropout = config.nn_params.model_dropout
        # learning rate
        lr_decay_rate = config.train_params.decay_step_rate

        time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # Combine the information into a single string for the experiment name
        # experiment_name = f"{time}_{dataset_name}_{encoding_name}{num_features}_{input_embedder_name}_{sub_decoder_name}_firstpred:{first_pred_feature}_inputlen{input_length}_nlayer{num_layers}_batch{batch_size}\
        # _dropout{main_dropout}_lrdecay{lr_decay_rate}"
        experiment_name = f"{time}_{dataset_name}_{encoding_name}{num_features}_{sub_decoder_name}_firstpred:{first_pred_feature}_inputlen{input_length}_nlayer{num_layers}_batch{batch_size}"
        return experiment_name
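
    # Example of a resulting name (values are illustrative, not from a real run):
    #   "2025-10-25_10-42-02_SOD_remi7_flatten_firstpred:type_inputlen1024_nlayer12_batch16"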

    def collate_fn(self, batch):
        """
        Custom collate function. The active path stacks fixed-length segments and
        masks into batch tensors; the commented-out path pads variable-length
        sequences to the maximum length in the batch.
        """
        # Unzip the batch into segments, masks, captions, and encoded captions
        segments, masks, captions, encoded_captions = zip(*batch)
        # # Pad the segments and masks to the maximum length in the batch
        # padded_segments = torch.nn.utils.rnn.pad_sequence(segments, batch_first=True)
        # padded_masks = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True)
        # # Return padded segments and masks along with captions and encoded captions
        segments = torch.stack(segments, dim=0)
        masks = torch.stack(masks, dim=0)
        # captions = torch.stack(captions, dim=0)
        # encoded_captions = torch.stack(encoded_captions, dim=0)
        return segments, masks, captions, encoded_captions
        # return padded_segments, padded_masks, captions, encoded_captions

    def setup_log(self, config):
        # In the DDP path this is only called on the main process; the guard
        # also covers the non-DDP path, where no accelerator exists.
        if getattr(self, 'accelerator', None) is None or self.accelerator.is_main_process:
            if config.general.make_log:
                experiment_name = self.generate_experiment_name(config)
                wandb.init(
                    project="Acce_Music_Transformer",
                    name=experiment_name,
                    config=OmegaConf.to_container(config)
                )
                # Save the config to the wandb run directory
                config_path = Path(wandb.run.dir) / "config.yaml"
                OmegaConf.save(config, config_path)  # key step

                save_dir = Path(wandb.run.dir) / "checkpoints"
                save_dir.mkdir(exist_ok=True, parents=True)
            else:
                now = datetime.now()
                save_dir = Path('wandb/debug/checkpoints') / now.strftime('%y-%m-%d')
                save_dir.mkdir(exist_ok=True, parents=True)
                # Save the config to the debug directory
                config_path = save_dir / "config.yaml"
                OmegaConf.save(config, config_path)  # key step

        return str(save_dir)

    # Set up the output directories for saving MIDI results during inference
    def set_save_out(self):
        # Same guard as in setup_log: main process under DDP, always otherwise
        if getattr(self, 'accelerator', None) is None or self.accelerator.is_main_process:
            # Copy files from the latest folder in wandb/debug/checkpoints
            target_folder = 'wandb/debug/checkpoints'
            latest_folder = sorted(Path(target_folder).iterdir(), key=os.path.getmtime)[-1]
            # Get files in the latest folder
            files = [f for f in latest_folder.iterdir() if f.is_file()]
            # Copy each file to save_dir unless it already exists there
            for file in files:
                target_file = self.save_dir / file.name
                if not target_file.exists():
                    os.system(f'cp {file} {target_file}')
            if self.infer_and_log:
                self.valid_out_dir = self.save_dir / 'valid_out'
                os.makedirs(self.valid_out_dir, exist_ok=True)

    # Save the current model and optimizer state
    def save_model(self, path):
        if isinstance(self.model, DDP):
            torch.save({'model': self.model.module.state_dict(), 'optim': self.optimizer.state_dict()}, path)
        else:
            torch.save({'model': self.model.state_dict(), 'optim': self.optimizer.state_dict()}, path)
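
    # Reload sketch (mirrors the checkpoint format written above; the path is
    # illustrative):
    #   ckpt = torch.load('checkpoints/iter1000_loss1.2345.pt', map_location='cpu')
    #   model.load_state_dict(ckpt['model'])
    #   optimizer.load_state_dict(ckpt['optim'])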

    # Generate the data loader for either training or validation datasets
    def generate_data_loader(self, dataset, shuffle=False, drop_last=False) -> DataLoader:
        return DataLoader(
            dataset, shuffle=shuffle, batch_size=self.batch_size, drop_last=drop_last,
            collate_fn=None, pin_memory=True, num_workers=4,
            persistent_workers=True, prefetch_factor=2, worker_init_fn=None)
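
    # Sketch (assumption: the collate_fn / InfiniteSampler above are opt-in
    # alternatives; neither is passed by default here):
    #   DataLoader(dataset, batch_size=self.batch_size,
    #              collate_fn=self.collate_fn,
    #              sampler=InfiniteSampler(dataset, shuffle=True))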

    # Training function based on a given number of iterations
    def accelerate_train_by_num_iter(self, num_iters):
        # generator = iter(self.train_loader)
        pbar = tqdm(total=num_iters, desc='Training', unit='iteration', leave=False)
        completed_steps = self.last_iter
        # save init model
        while completed_steps < num_iters:
            total_loss = 0
            current_loss = 0
            for i, batch in enumerate(self.train_loader):
                # Gradient accumulation: backward on every micro-batch, but the
                # optimizer only steps when accelerate syncs gradients
                with self.accelerator.accumulate(self.model):
                    # Start time for the training step (only meaningful on the main process)
                    start_time = time.time()

                    # Train the model on a single batch
                    # loss_value, loss_dict = self._accelerate_train_by_single_batch(batch)
                    loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
                    total_loss += loss.detach().float()
                    current_loss = loss.detach().float()
                    # loss.backward()
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.unscale_gradients(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                        # self.accelerator.clip_grad_norm_(self.model.parameters(), self.grad_clip)

                    self.optimizer.step()
                    if self.scheduler is not None and not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.accelerator.sync_gradients:
                    # Update the progress bar
                    loss_value = loss.item()
                    completed_steps += 1

                    loss_dict['time'] = time.time() - start_time
                    loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
                    loss_dict = self._rename_dict(loss_dict, 'train')
                    self.training_loss.append(loss_value)
                    if self.accelerator.is_main_process:
                        pbar.update(1)
                        pbar.set_postfix(loss=loss_value, lr=self.optimizer.param_groups[0]['lr'])
                    # Save an initial checkpoint after the first step
                    if completed_steps == 1 and self.accelerator.is_main_process:
                        self.save_model(self.save_dir / f'iter{completed_steps}_loss{current_loss:.4f}.pt')

                    # Log training loss at the specified training cycle
                    if (completed_steps + 1) % self.iterations_per_training_cycle == 0 and self.make_log and self.accelerator.is_main_process:
                        wandb.log(loss_dict, step=completed_steps)

                    # Log training accuracy periodically
                    if (completed_steps + 1) % (self.iterations_per_training_cycle * 3) == 0 and self.make_log:
                        validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch, train=True)
                        train_metric_dict = self._get_train_accuracy(num_nonmask_tokens, num_tokens_by_feature, correct_guess_by_feature)
                        train_metric_dict.update(loss_dict)
                        train_metric_dict = self._rename_dict(train_metric_dict, 'train')
                        if self.accelerator.is_main_process:
                            wandb.log(train_metric_dict, step=completed_steps)
                        # Delete variables to avoid memory leaks
                        del validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature, train_metric_dict

                    # Perform validation at the specified interval
                    if (completed_steps + 1) % self.iterations_per_validation_cycle == 0:
                        self.model.eval()
                        validation_loss, validation_acc, validation_metric_dict = self.validate()
                        validation_metric_dict['acc'] = validation_acc
                        validation_metric_dict = self._rename_dict(validation_metric_dict, 'valid')
                        if self.make_log and self.accelerator.is_main_process:
                            wandb.log(validation_metric_dict, step=completed_steps)
                        self.validation_loss.append(validation_loss)
                        self.validation_acc.append(validation_acc)
                        self.best_valid_loss = min(validation_loss, self.best_valid_loss)

                        # Perform inference and logging after a certain number of cycles
                        if (completed_steps + 1) % (self.num_cycles_for_inference * self.iterations_per_validation_cycle) == 0 and self.infer_and_log and self.accelerator.is_main_process:
                            self.inference_and_log(i, self.num_uncond_generation, self.num_cond_generation, self.num_max_seq_len)

                        # Save a model checkpoint periodically
                        if (completed_steps + 1) % (self.iterations_per_validation_cycle * self.num_cycles_for_model_checkpoint) == 0 and self.accelerator.is_main_process:
                            self.accelerator.print(f"Saving model checkpoint at iter {completed_steps}")
                            self.save_model(self.save_dir / f'iter{completed_steps}_loss{validation_loss:.4f}.pt')
                        self.model.train()

                        # Delete variables to avoid memory leaks
                        del validation_acc, validation_metric_dict
                # else:
                #     self.accelerator.wait_for_everyone()
            # Save a model checkpoint at the end of each epoch
            if self.accelerator.is_main_process:
                print(f"Saving model checkpoint at iter {completed_steps}")
                self.save_model(self.save_dir / f"iter{completed_steps}_loss{current_loss:.4f}.pt")
        # Save the final model after training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            print("saving last checkpoint")
            self.save_model(self.save_dir / 'checkpoint_last.pt')
    # Same as `_get_loss_pred_from_single_batch` below, but using torch.cuda.amp directly
    def _accelerate_get_loss_pred_from_single_batch(self, batch):
        """
        Computes the loss and predictions for a single batch of data.

        Args:
            batch: A batch of data, typically containing input sequences, targets, and masks.

        Returns:
            loss: The computed loss for the batch.
            logits: The raw model predictions (logits).
            loss_dict: A dictionary containing the total loss.

        The method:
        - Separates the input sequences and target sequences from the batch.
        - Moves the data to the appropriate device.
        - Applies mixed precision (FP16) if applicable.
        - Computes the logits using the model and calculates the loss using the specified loss function.
        """
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with torch.cuda.amp.autocast():
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, None)
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict

    def _train_by_single_batch(self, batch):
        """
        Trains the model on a single batch of data.

        Args:
            batch: A batch of data, typically consisting of input sequences and corresponding targets.

        Returns:
            loss.item(): The total loss for this batch.
            loss_dict: A dictionary containing information about the loss and other relevant metrics.

        The method:
        - Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
        - Resets the optimizer's gradients.
        - Depending on whether mixed precision (FP16) is used, scales the loss and applies gradient clipping before stepping the optimizer.
        - Updates the learning rate scheduler if applicable.
        - Records the time taken for the training step and the current learning rate in `loss_dict`.
        """
        start_time = time.time()
        loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
        self.optimizer.zero_grad()
        if self.use_fp16:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.optimizer.step()
        if self.scheduler is not None and not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step()
        loss_dict['time'] = time.time() - start_time
        loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
        return loss.item(), loss_dict

    def _get_loss_pred_from_single_batch(self, batch):
        """
        Computes the loss and predictions for a single batch of data.

        Args:
            batch: A batch of data, typically containing input sequences, targets, and masks.

        Returns:
            loss: The computed loss for the batch.
            logits: The raw model predictions (logits).
            loss_dict: A dictionary containing the total loss.

        The method:
        - Separates the input sequences and target sequences from the batch.
        - Moves the data to the appropriate device.
        - Applies mixed precision (FP16) if applicable.
        - Computes the logits using the model and calculates the loss using the specified loss function.
        """
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]

        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with self.accelerator.autocast():
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, None)
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict

    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        """
        Computes validation loss and accuracy from a single batch.

        Args:
            batch: A batch of data, typically containing input sequences, targets, and masks.
            train (bool): Indicator whether the function is being used in training mode.

        Returns:
            validation_loss: Total validation loss for the batch.
            num_tokens: The number of valid tokens in the batch.
            loss_dict: A dictionary containing the loss and relevant metrics.
            None: Placeholder kept for interface parity with subclasses.
            num_correct_guess: Number of correctly predicted tokens.

        The method:
        - Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
        - Computes token-level accuracy by comparing predicted tokens with the targets.
        """
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch)
        prob = torch.softmax(logits, dim=-1)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        num_tokens = torch.sum(mask)

        # Masked positions are zero in both tensors and compare equal, so the
        # number of masked positions is subtracted back out of the correct count.
        selected_tokens = torch.argmax(prob, dim=-1) * mask
        shifted_tgt_with_mask = target * mask
        num_correct_guess = torch.sum(selected_tokens == shifted_tgt_with_mask) - torch.sum(mask == 0)

        validation_loss = loss.item() * num_tokens
        num_correct_guess = num_correct_guess.item()
        return validation_loss, num_tokens, loss_dict, None, num_correct_guess

    def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess):
        """
        Computes training accuracy.

        Args:
            num_tokens: Total number of tokens processed.
            num_tokens_by_feature: Number of tokens for each feature (not used here).
            num_correct_guess: Number of correctly predicted tokens.

        Returns:
            Training accuracy, the ratio of correct predictions to the total number of tokens.
        """
        return num_correct_guess / num_tokens

    def validate(self, external_loader=None):
        """
        Validates the model on a dataset.

        Args:
            external_loader (DataLoader): If provided, an external DataLoader is used for validation.

        Returns:
            total_validation_loss: Average validation loss over all batches.
            total_num_correct_guess: Total number of correct predictions divided by the number of tokens (accuracy).
            validation_metric_dict: Dictionary of validation metrics averaged over all batches.

        The method:
        - Iterates through the validation data loader, calculating the loss and accuracy for each batch.
        - Aggregates the results over all batches and returns the overall validation metrics.
        """
        if external_loader and isinstance(external_loader, DataLoader):
            loader = external_loader
            print('An arbitrary loader is used instead of the validation loader')
        else:
            loader = self.valid_loader

        self.model.eval()
        total_validation_loss = 0
        total_num_correct_guess = 0
        total_num_tokens = 0
        validation_metric_dict = defaultdict(float)
        with torch.inference_mode():
            for batch in tqdm(loader, leave=False):
                validation_loss, num_tokens, loss_dict, _, num_correct_guess = self._get_valid_loss_and_acc_from_batch(batch)
                total_validation_loss += validation_loss
                total_num_tokens += num_tokens
                total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    validation_metric_dict[key] += value * num_tokens
            # Token-weighted average over all batches
            for key in validation_metric_dict.keys():
                validation_metric_dict[key] /= total_num_tokens

        return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_tokens, validation_metric_dict

    def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None):
        """
        Generates a MIDI file and logs output from the generated sequence.

        Args:
            generated_output: The sequence of notes generated by the model.
            iter: The current iteration of the training process.
            seed: The seed used for generating the sequence.
            condition: Optional condition input for generating conditional output.

        The method:
        - Converts the generated output into a MIDI file and logs it.
        - Optionally logs additional error metrics and figures for analysis.
        """
        path_addition = "cond_" if condition is not None else ""
        with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f:
            pickle.dump(generated_output, f)
        self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid")
        if self.make_log:
            log_dict = {}
            log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png'))
            log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3'))
            wandb.log(log_dict, step=(iter + seed))
            print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}")
        return generated_output

    @torch.inference_mode()
    def inference_and_log(self, iter, num_uncond_generation=5, num_cond_generation=5, max_seq_len=10000):
        """
        Generates and logs both unconditional and conditional output sequences.

        Args:
            iter: The current iteration.
            num_uncond_generation: Number of unconditional sequences to generate.
            num_cond_generation: Number of conditional sequences to generate.
            max_seq_len: Maximum sequence length to generate.

        The method:
        - Generates unconditional and conditional sequences using the model's generation function.
        - Converts the sequences into MIDI files and logs the generated results.
        """
        self.model.eval()
        for i in range(num_uncond_generation):
            try:
                start_time = time.time()
                uncond_generated_output = self.model.module.generate(
                    manual_seed=i, max_seq_len=max_seq_len, condition=None,
                    sampling_method=self.sampling_method, threshold=self.sampling_threshold,
                    temperature=self.sampling_temperature)
                if len(uncond_generated_output) == 0:
                    continue
                print(f"unconditional generation time_{iter}: {time.time() - start_time:.4f}")
                print(f"unconditional length of generated_output: {uncond_generated_output.shape[1]}")
                self._make_midi_from_generated_output(uncond_generated_output, iter, i, None)
            except Exception as e:
                print(e)
        condition_list = [x[1] for x in self.valid_set.data_list[:num_cond_generation]]
        for i in range(num_cond_generation):
            condition = self.valid_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
            try:
                start_time = time.time()
                generated_output = self.model.module.generate(
                    manual_seed=i, max_seq_len=max_seq_len, condition=condition,
                    sampling_method=self.sampling_method, threshold=self.sampling_threshold,
                    temperature=self.sampling_temperature)
                if len(generated_output) == 0:
                    continue
                print(f"conditional generation time_{iter}: {time.time() - start_time:.4f}")
                print(f"conditional length of generated_output: {generated_output.shape[1]}")
                self._make_midi_from_generated_output(generated_output, iter + num_uncond_generation, i, condition)
            except Exception as e:
                print(e)

    def _rename_dict(self, adict, prefix='train'):
        '''
        Renames the keys in a dictionary by adding a prefix.
        '''
        keys = list(adict.keys())
        for key in keys:
            adict[f'{prefix}.{key}'] = adict.pop(key)
        return dict(adict)
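
    # Example (illustrative):
    #   self._rename_dict({'total': 1.23, 'time': 0.05}, 'train')
    #   -> {'train.total': 1.23, 'train.time': 0.05}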


class LanguageModelTrainer4REMI(LanguageModelTrainer):
    def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
        super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)

    def _get_loss_pred_from_single_batch(self, batch, valid=False):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with self.accelerator.autocast():
                logits = self.model(input_seq, target)
                if not valid:
                    # During training, only the total loss is needed
                    total_loss, loss_dict = self.loss_fn(logits, target, mask, None)
                    return total_loss, logits, {'total': total_loss.item()}
                else:
                    # During validation, per-feature losses are computed via the vocab
                    total_loss, loss_dict = self.loss_fn(logits, target, mask, self.vocab)
                    loss_dict['total'] = total_loss.item()
                    return total_loss, logits, loss_dict
        else:
            logits = self.model(input_seq, target)
            if not valid:
                total_loss, loss_dict = self.loss_fn(logits, target, mask, None)
                return total_loss, logits, {'total': total_loss.item()}
            else:
                total_loss, loss_dict = self.loss_fn(logits, target, mask, self.vocab)
                loss_dict['total'] = total_loss.item()
                return total_loss, logits, loss_dict

    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        segment, mask, caption, encoded_caption = batch
        mask = mask[:, :-1]
        _, target = segment[:, :-1], segment[:, 1:]
        loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        prob = torch.softmax(logits, dim=-1)
        num_nonmask_tokens = torch.sum(mask)  # [b, t]
        target = target.to(self.device)  # [b, t]
        mask = mask.to(self.device)

        prob_with_mask = torch.argmax(prob, dim=-1) * mask  # [b, t]
        shifted_tgt_with_mask = target * mask  # [b, t]

        correct_guess_by_feature = defaultdict(int)
        num_tokens_by_feature = defaultdict(int)
        tokens_idx = prob_with_mask.flatten(0, 1)  # [b*t]
        answers_idx = shifted_tgt_with_mask.flatten(0, 1)  # [b*t]
        if self.vocab.encoding_scheme == 'remi':
            eos_idx = 2
        for feature in self.vocab.feature_list:
            feature_mask = self.vocab.total_mask[feature].to(self.device)  # [327,]
            mask_for_target = feature_mask[answers_idx]  # [b*t]
            if feature == 'type':  # because the Bar token is 0, add 1 before comparing to compute accuracy
                valid_pred = (tokens_idx + 1) * mask_for_target
                valid_answers = (answers_idx + 1) * mask_for_target
                eos_mask = valid_answers != eos_idx  # EOS also works as padding
                correct_guess_by_feature[feature] += torch.sum(valid_pred[eos_mask] == valid_answers[eos_mask]).item() - torch.sum(mask_for_target[eos_mask] == 0).item()
                num_tokens_by_feature[feature] += torch.sum(mask_for_target[eos_mask]).item()
            else:
                valid_pred = tokens_idx * mask_for_target  # [b*t]
                valid_answers = answers_idx * mask_for_target  # [b*t]
                correct_guess_by_feature[feature] += torch.sum(valid_pred == valid_answers).item() - torch.sum(mask_for_target == 0).item()
                num_tokens_by_feature[feature] += torch.sum(mask_for_target).item()
        validation_loss = loss.item() * num_nonmask_tokens.item()
        return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature

    def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
        total_num_correct_guess = 0
        total_num_tokens = 0
        acc_dict = {}
        for feature, num_correct_guess in num_correct_guess_by_feature.items():
            if feature == 'type':
                continue
            total_num_correct_guess += num_correct_guess
            total_num_tokens += num_tokens_by_feature[feature]
            if num_tokens_by_feature[feature] == 0:
                continue
            acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / total_num_tokens
        acc_dict['total_acc'] = total_accuracy
        return acc_dict

    def validate(self, external_loader=None):
        '''
        total_num_tokens: used for the loss average (non-mask tokens)
        total_num_valid_tokens: used for the accuracy average (valid tokens)
        '''
        if external_loader and isinstance(external_loader, DataLoader):
            loader = external_loader
            print('An arbitrary loader is used instead of the validation loader')
        else:
            loader = self.valid_loader

        self.model.eval()
        total_validation_loss = 0
        total_num_tokens = 0
        total_num_valid_tokens = 0
        total_num_correct_guess = 0
        validation_metric_dict = defaultdict(float)
        total_num_tokens_by_feature = defaultdict(int)
        total_num_correct_guess_dict = defaultdict(int)
        with torch.inference_mode():
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                if num_iter == len(self.valid_loader):
                    if loader is not self.valid_loader:  # when validating with the train_loader
                        break
                validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
                total_validation_loss += validation_loss
                total_num_tokens += num_nonmask_tokens.item()
                for key, num_tokens in num_tokens_by_feature.items():
                    total_num_tokens_by_feature[key] += num_tokens
                    if key == 'type':
                        continue
                    total_num_valid_tokens += num_tokens  # token counts are the same for each musical type, torch.sum(mask)
                for key, num_correct_guess in num_correct_guess_by_feature.items():
                    total_num_correct_guess_dict[key] += num_correct_guess
                    if key == 'type':
                        continue
                    total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    if key == 'total':
                        validation_metric_dict[key] += value * num_nonmask_tokens
                    else:
                        feature_name = key.split('_')[0]
                        validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]

            for key in validation_metric_dict.keys():
                if key == 'total':
                    validation_metric_dict[key] /= total_num_tokens
                else:
                    feature_name = key.split('_')[0]
                    if total_num_tokens_by_feature[feature_name] == 0:
                        continue
                    validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]

            for key in total_num_tokens_by_feature.keys():
                num_tokens = total_num_tokens_by_feature[key]
                num_correct = total_num_correct_guess_dict[key]
                if num_tokens == 0:
                    continue
                validation_metric_dict[f'{key}_acc'] = num_correct / num_tokens
        return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_valid_tokens, validation_metric_dict


class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
    def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
        super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)

    '''
    About ignore_token and conti_token:
    During validation, tokens with the "conti" value are ignored when calculating accuracy or other metrics,
    ensuring that repeated values don't unfairly skew the results.
    This is especially relevant for features like beat, chord, tempo, and instrument, where repeated tokens may have a specific musical meaning.

    We use ignore_token and conti_token to fairly compare compound-token-based encoding with REMI encoding.
    '''
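
    # Worked example (illustrative values): for a "chord" target row
    #   [0, 17, 9999, 23]  with ignore_token=0 and conti_token=9999,
    # positions 0 and 2 are excluded, so num_valid_tokens == 2 and only the
    # predictions at positions 1 and 3 can count as correct.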

    def _get_num_valid_and_correct_tokens(self, prob, ground_truth, mask, ignore_token=None, conti_token=None):
        valid_prob = torch.argmax(prob, dim=-1) * mask
        valid_ground_truth = ground_truth * mask

        if ignore_token is None and conti_token is None:
            # Masked positions are zero in both tensors and compare equal,
            # so their count is subtracted back out.
            num_valid_tokens = torch.sum(mask)
            num_correct_tokens = torch.sum(valid_prob == valid_ground_truth) - torch.sum(mask == 0)
        elif ignore_token is not None and conti_token is None:
            ignore_mask = valid_ground_truth != ignore_token  # batch x seq_len
            num_valid_tokens = torch.sum(ignore_mask)
            num_correct_tokens = torch.sum(valid_prob[ignore_mask] == valid_ground_truth[ignore_mask])  # boolean indexing flattens the tensor to 1d
        elif ignore_token is not None and conti_token is not None:
            ignore_conti_mask = (valid_ground_truth != ignore_token) & (valid_ground_truth != conti_token)
            num_valid_tokens = torch.sum(ignore_conti_mask)
            num_correct_tokens = torch.sum(valid_prob[ignore_conti_mask] == valid_ground_truth[ignore_conti_mask])
        return num_correct_tokens.item(), num_valid_tokens.item()
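
    # Call sketch (shapes are illustrative): prob is [batch, seq_len, vocab],
    # ground_truth and mask are [batch, seq_len]:
    #   n_correct, n_valid = self._get_num_valid_and_correct_tokens(
    #       probs_dict['pitch'], target[..., idx], mask, ignore_token=0)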

    def _get_loss_pred_from_single_batch(self, batch, valid=False):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        encoded_caption = encoded_caption.to(self.device)
        if self.use_fp16:
            if self.config.use_diff:
                with self.accelerator.autocast():
                    (logits_dict, (masked_indices, p_mask)), input_dict = self.model(input_seq, target, context=encoded_caption)
                    if self.config.use_dispLoss:
                        total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid, input_dict=input_dict, lambda_weight=self.config.lambda_weight, tau=self.config.tau)
                    else:
                        total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid)
            else:
                with self.accelerator.autocast():
                    logits_dict, _ = self.model(input_seq, target, context=encoded_caption)
                    total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, valid)
        else:
            if self.config.use_diff:
                # Mirror the fp16 path: run the forward pass before computing the loss
                (logits_dict, (masked_indices, p_mask)), input_dict = self.model(input_seq, target, context=encoded_caption)
                if self.config.use_dispLoss:
                    total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid, input_dict=input_dict, lambda_weight=self.config.lambda_weight)
                else:
                    total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid)
            else:
                logits_dict, input_dict = self.model(input_seq, target, context=encoded_caption)
                total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, valid)
        if valid:
            loss_dict['total'] = total_loss.item()
        else:
            loss_dict = {'total': total_loss.item()}
        return total_loss, logits_dict, loss_dict

    def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
        '''
        In this method, "valid" means handled with both the ignore token and the mask;
        tokens that are valid with only the mask are counted as num_nonmask_tokens.

        input_seq, target: batch x seq_len x num_features
        mask: batch x seq_len, 0 for padding
        prob: batch x seq_len x total_vocab_size
        '''
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
        num_nonmask_tokens = torch.sum(mask)
        input_seq = input_seq.to(self.device)
        target = add_conti_in_valid(target, self.config.nn_params.encoding_scheme).to(self.device)
        mask = mask[:, :-1].to(self.device)

        correct_guess_by_feature = defaultdict(int)
        num_tokens_by_feature = defaultdict(int)
        for idx, key in enumerate(self.vocab.feature_list):
            if key == 'type' or key == 'timesig':
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
            elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
            elif key == 'beat':
                # NB's beat vocab has Ignore and CONTI tokens.
                # CP's beat vocab has Ignore and BAR tokens; the BAR token is excluded
                # from the accuracy calculation for parity with NB.
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
            else:
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=None)
            correct_guess_by_feature[key] = num_correct_tokens
            num_tokens_by_feature[key] = num_valid_tokens
        validation_loss = total_loss.item() * num_nonmask_tokens.item()
        return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature

    def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
        total_num_correct_guess = 0
        total_num_tokens = 0
        acc_dict = {}
        for feature, num_correct_guess in num_correct_guess_by_feature.items():
            if feature == 'type':
                continue
            total_num_correct_guess += num_correct_guess
            total_num_tokens += num_tokens_by_feature[feature]
            # Guard against empty feature counts, as in the REMI trainer above
            if num_tokens_by_feature[feature] == 0:
                continue
            acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / total_num_tokens
        acc_dict['total_acc'] = total_accuracy
        return acc_dict

    def validate(self, external_loader=None):
        if external_loader and isinstance(external_loader, DataLoader):
            loader = external_loader
            print('An arbitrary loader is used instead of the validation loader')
        else:
            loader = self.valid_loader

        self.model.eval()
        total_validation_loss = 0
        total_num_correct_guess = 0
        total_num_tokens = 0
        total_num_valid_tokens = 0
        validation_metric_dict = defaultdict(float)
        total_num_tokens_by_feature = defaultdict(int)
        total_num_correct_guess_dict = defaultdict(int)

        with torch.inference_mode():
            '''
            The mask is used to calculate loss and accuracy:
            validation_loss: sum of loss for valid tokens conditioned on the mask
            num_nonmask_tokens: sum of tokens conditioned on the mask
            num_tokens_by_feature: sum of valid tokens (ignore handled) for each musical feature
            num_correct_guess_by_feature: sum of correct tokens (ignore handled) for each musical feature
            '''
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                if num_iter == len(self.valid_loader):
                    if loader is not self.valid_loader:  # when validating with the train_loader
                        break
                validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
                total_validation_loss += validation_loss
                total_num_tokens += num_nonmask_tokens
                for key, num_tokens in num_tokens_by_feature.items():
                    total_num_tokens_by_feature[key] += num_tokens
                    if key == 'type':  # cp and nb have different numbers of type tokens, so accuracy is not calculated for the type token
                        continue
                    total_num_valid_tokens += num_tokens  # token counts are the same for each musical type, torch.sum(mask)
                for key, num_correct_guess in num_correct_guess_by_feature.items():
                    total_num_correct_guess_dict[key] += num_correct_guess
                    if key == 'type':
                        continue
                    total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    if key == 'total':
                        validation_metric_dict[key] += value * num_nonmask_tokens
                    else:
                        # if torch.isnan(value):  # in case num valid tokens is 0 because of the mask
                        #     continue
                        feature_name = key.split('_')[0]
                        validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]

            for key in validation_metric_dict.keys():
                if key == 'total':
                    validation_metric_dict[key] /= total_num_tokens
                else:
                    feature_name = key.split('_')[0]
                    if total_num_tokens_by_feature[feature_name] == 0:
                        continue
                    validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]
            for (key_t, num_tokens), (key_c, num_correct) in zip(total_num_tokens_by_feature.items(), total_num_correct_guess_dict.items()):
                validation_metric_dict[f'{key_c}_acc'] = num_correct / num_tokens

        # The +1 terms guard against division by zero on empty loaders
        return total_validation_loss / (total_num_tokens + 1), total_num_correct_guess / (1 + total_num_valid_tokens), validation_metric_dict

    def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None):
        if self.config.data_params.first_pred_feature != 'type' and self.config.nn_params.encoding_scheme == 'nb':
            generated_output = reverse_shift_and_pad_for_tensor(generated_output, self.config.data_params.first_pred_feature)
        path_addition = "cond_" if condition is not None else ""

        # Save generated_output as a pickle
        with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f:
            pickle.dump(generated_output, f)
        self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid")
        if self.make_log and self.infer_and_log:
            log_dict = {}
            log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png'))
            log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3'))
            wandb.log(log_dict, step=(iter + seed))
            print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}")