# MIDIFoundationModel/Amadeus/trainer_accelerate.py
import contextlib
import shutil
import time
import pickle
import os
from pathlib import Path
from typing import Union
from datetime import datetime
from omegaconf import OmegaConf
import random
import itertools
import torch
import torchaudio
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler, Sampler
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
#======================================================================
import wandb
from collections import defaultdict
from tqdm.auto import tqdm
from .model_zoo import AmadeusModel
from .symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from .symbolic_encoding.data_utils import TuneCompiler
from .symbolic_encoding.decoding_utils import MidiDecoder4REMI
from .evaluation_utils import add_conti_in_valid
from .train_utils import NLLLoss4REMI
os.environ['WANDB_INIT_TIMEOUT'] = '600'
os.environ["WANDB_BASE_URL"] = "https://api.bandw.top"
from data_representation.vocab_utils import LangTokenVocab
class InfiniteSampler(Sampler):
def __init__(self, data_source, shuffle=True):
self.data_source = data_source
self.shuffle = shuffle
self.indices = list(range(len(data_source)))
if self.shuffle:
random.shuffle(self.indices)
self.infinite_iterator = itertools.cycle(self.indices)
def __iter__(self):
return self.infinite_iterator
    def __len__(self):
        raise TypeError("InfiniteSampler has no finite length")  # the sampler cycles indefinitely
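
# A minimal usage sketch (illustrative only; this sampler is not wired into
# generate_data_loader below, which relies on the DataLoader's own shuffling):
#   sampler = InfiniteSampler(train_set, shuffle=True)
#   loader = DataLoader(train_set, batch_size=8, sampler=sampler)
#   batch = next(iter(loader))  # iteration never raises StopIteration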
class LanguageModelTrainer:
def __init__(
self,
model: AmadeusModel, # The language model for music generation
optimizer: torch.optim.Optimizer, # Optimizer for updating model weights
scheduler: torch.optim.lr_scheduler._LRScheduler, # Learning rate scheduler
loss_fn: NLLLoss4REMI, # Loss function to compute the error
midi_decoder: MidiDecoder4REMI, # Decoder to convert model output into MIDI format
train_set: TuneCompiler, # Training dataset
valid_set: TuneCompiler, # Validation dataset
save_dir: str, # Directory to save models and logs
vocab: LangTokenVocab, # Vocabulary for tokenizing sequences
use_ddp: bool, # Whether to use Distributed Data Parallel (DDP)
use_fp16: bool, # Whether to use mixed-precision training (FP16)
world_size: int, # Total number of devices for distributed training
batch_size: int, # Batch size for training
infer_target_len: int, # Target length for inference generation
gpu_id: int, # GPU device ID for computation
sampling_method: str, # Sampling method for sequence generation
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
        model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
):
# Save model, optimizer, and other configurations
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.loss_fn = loss_fn
self.valid_set = valid_set
self.vocab = vocab
self.use_ddp = use_ddp
self.world_size = world_size
self.batch_size = batch_size
self.gpu_id = gpu_id
self.sampling_method = sampling_method
self.sampling_threshold = sampling_threshold
self.sampling_temperature = sampling_temperature
self.config = config
self.last_iter = 0
# Load pre-trained model if provided
if model_checkpoint:
# parse the model checkpoint iter
if isinstance(model_checkpoint, str):
if model_checkpoint.endswith('.pt'):
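                    # checkpoint filenames follow 'iter{N}_loss{L}.pt';
                    # e.g. 'iter104999_loss0.2490.pt' resumes from iteration 104999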
self.last_iter = int(model_checkpoint.split('/')[-1].split('_')[0][4:])
checkpoint = torch.load(model_checkpoint, map_location='cpu')
# print state dict keys
print("Loading model checkpoint from", model_checkpoint)
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint['model'], strict=False)
else:
self.model.load_state_dict(checkpoint['model'], strict=False)
# Training hyperparameters from config
self.grad_clip = config.train_params.grad_clip
self.num_cycles_for_inference = config.train_params.num_cycles_for_inference
self.num_cycles_for_model_checkpoint = config.train_params.num_cycles_for_model_checkpoint
self.iterations_per_training_cycle = config.train_params.iterations_per_training_cycle
self.iterations_per_validation_cycle = config.train_params.iterations_per_validation_cycle
self.make_log = config.general.make_log
self.num_uncond_generation = config.inference_params.num_uncond_generation
self.num_cond_generation = config.inference_params.num_cond_generation
self.num_max_seq_len = infer_target_len
self.infer_and_log = config.general.infer_and_log
self.valid_loader = self.generate_data_loader(self.valid_set, shuffle=False, drop_last=True)
# gradient accumulation
self.gradient_accumulation_steps = config.train_params.gradient_accumulation_steps
# Set up mixed-precision training (FP16) if enabled
        self.use_fp16 = use_fp16
# Set up Distributed Data Parallel (DDP) if required
if use_ddp:
# prepare using accelerator
            accelerator_kwargs = dict(
                step_scheduler_with_optimizer=False,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
            )
            if self.use_fp16:
                accelerator_kwargs['mixed_precision'] = 'bf16'  # the fp16 flag requests bfloat16 autocast
            self.accelerator = Accelerator(**accelerator_kwargs)
with self.accelerator.main_process_first():
self.train_set = train_set
self.train_loader = self.generate_data_loader(self.train_set, shuffle=False, drop_last=False)
self.accelerator.wait_for_everyone()
self.accelerator.print(f"Using {self.world_size} GPUs for training")
self.model, self.optimizer, self.scheduler, self.train_loader = self.accelerator.prepare(
self.model, self.optimizer, self.scheduler, self.train_loader
)
self.accelerator.wait_for_everyone()
set_seed(42)
self.device = self.accelerator.device
self.model.to(self.device)
# set up for logging
if self.accelerator.is_main_process:
save_dir = self.setup_log(config)
print("savwe",save_dir)
# Create directory for saving models and logs
self.save_dir = Path(save_dir)
self.save_dir.mkdir(exist_ok=True, parents=True)
self.set_save_out()
else:
self.train_set = train_set
# Create data loaders for training and validation sets
self.train_loader = self.generate_data_loader(train_set, shuffle=False, drop_last=True)
self.valid_loader = self.generate_data_loader(valid_set, shuffle=True, drop_last=True)
save_dir = self.setup_log(config)
# Create directory for saving models and logs
self.save_dir = Path(save_dir)
self.save_dir.mkdir(exist_ok=True, parents=True)
self.set_save_out()
self.device = config.train_params.device
self.model.to(self.device)
# Initialize tracking metrics
self.best_valid_accuracy = 0
        self.best_valid_loss = float('inf')
self.training_loss = []
self.validation_loss = []
self.validation_acc = []
self.midi_decoder = midi_decoder
def generate_experiment_name(self, config):
# add base hyperparameters to the experiment name
dataset_name = config.dataset
encoding_name = config.nn_params.encoding_scheme
num_features = config.nn_params.num_features
input_embedder_name = config.nn_params.input_embedder_name
sub_decoder_name = config.nn_params.sub_decoder_name
batch_size = config.train_params.batch_size
num_layers = config.nn_params.main_decoder.num_layer
input_length = config.train_params.input_length
first_pred_feature = config.data_params.first_pred_feature
# Add target hyperparameters to the experiment name
# dropout
main_dropout = config.nn_params.model_dropout
# learning rate
lr_decay_rate = config.train_params.decay_step_rate
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # Combine the information into a single string for the experiment name
        experiment_name = f"{timestamp}_{dataset_name}_{encoding_name}{num_features}_{sub_decoder_name}_firstpred:{first_pred_feature}_inputlen{input_length}_nlayer{num_layers}_batch{batch_size}"
        return experiment_name
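    # Sample result (angle-bracket fields vary by config):
    #   '<YYYY-MM-DD_HH-MM-SS>_<dataset>_<encoding><num_features>_<sub_decoder>_firstpred:<feat>_inputlen<L>_nlayer<N>_batch<B>'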
    def collate_fn(self, batch):
        """
        Custom collate function for (segment, mask, caption, encoded_caption) samples.
        Segments and masks share a fixed length, so they are stacked directly;
        captions and encoded captions are passed through as tuples.
        """
        segments, masks, captions, encoded_captions = zip(*batch)
        segments = torch.stack(segments, dim=0)
        masks = torch.stack(masks, dim=0)
        return segments, masks, captions, encoded_captions
    def setup_log(self, config):
        # In the non-DDP path no Accelerator exists, so treat the process as main.
        is_main = getattr(self, 'accelerator', None) is None or self.accelerator.is_main_process
        if is_main:
            if config.general.make_log:
                experiment_name = self.generate_experiment_name(config)
                wandb.init(
                    project="Acce_Music_Transformer",
                    name=experiment_name,
                    config=OmegaConf.to_container(config)
                )
                # Save the config into the wandb run directory
                config_path = Path(wandb.run.dir) / "config.yaml"
                OmegaConf.save(config, config_path)
                save_dir = Path(wandb.run.dir) / "checkpoints"
                save_dir.mkdir(exist_ok=True, parents=True)
            else:
                now = datetime.now()
                save_dir = Path('wandb/debug/checkpoints') / now.strftime('%y-%m-%d')
                save_dir.mkdir(exist_ok=True, parents=True)
                # Save the config into the debug directory
                config_path = save_dir / "config.yaml"
                OmegaConf.save(config, config_path)
        return str(save_dir)
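    # Resulting checkpoint locations:
    #   make_log=True  -> wandb/run-<timestamp>-<id>/files/checkpoints/
    #   make_log=False -> wandb/debug/checkpoints/<yy-mm-dd>/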
    # Set up the output directories for saving MIDI results during inference
    def set_save_out(self):
        is_main = getattr(self, 'accelerator', None) is None or self.accelerator.is_main_process
        if is_main:
            # copy config files from the latest folder in wandb/debug/checkpoints, if present
            target_folder = Path('wandb/debug/checkpoints')
            if target_folder.exists():
                folders = sorted(target_folder.iterdir(), key=os.path.getmtime)
                if folders:
                    latest_folder = folders[-1]
                    for file in (f for f in latest_folder.iterdir() if f.is_file()):
                        target_file = self.save_dir / file.name
                        if not target_file.exists():
                            shutil.copy2(file, target_file)
            if self.infer_and_log:
                self.valid_out_dir = self.save_dir / 'valid_out'
                os.makedirs(self.valid_out_dir, exist_ok=True)
# Save the current model and optimizer state
def save_model(self, path):
if isinstance(self.model, DDP):
torch.save({'model': self.model.module.state_dict(), 'optim': self.optimizer.state_dict()}, path)
else:
torch.save({'model': self.model.state_dict(), 'optim': self.optimizer.state_dict()}, path)
# Generate the data loader for either training or validation datasets
    def generate_data_loader(self, dataset, shuffle=False, drop_last=False) -> DataLoader:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            drop_last=drop_last,
            collate_fn=None,
            pin_memory=True,
            num_workers=4,
            persistent_workers=True,
            prefetch_factor=2,
            worker_init_fn=None,
        )
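    # Note: collate_fn=None falls back to PyTorch's default_collate; the custom
    # collate_fn defined above is not wired in here.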
# Training function based on a given number of iterations
def accelerate_train_by_num_iter(self, num_iters):
pbar = tqdm(total=num_iters, desc='Training', unit='iteration', leave=False)
completed_steps = self.last_iter
# save init model
while completed_steps < num_iters:
total_loss = 0
current_loss = 0
for i, batch in enumerate(self.train_loader):
# gradient accumulation
with self.accelerator.accumulate(self.model):
                    # Time the training step (used for logging)
                    start_time = time.time()
                    # Train the model on a single batch
                    loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
                    total_loss += loss.detach().float()
                    current_loss = loss.detach().float()
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.unscale_gradients(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                    self.optimizer.step()
                    if self.scheduler is not None and not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.scheduler.step()
                    self.optimizer.zero_grad()
                    if self.accelerator.sync_gradients:
                        loss_value = loss.item()
                        completed_steps += 1
                        loss_dict['time'] = time.time() - start_time
                        loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
                        loss_dict = self._rename_dict(loss_dict, 'train')
                        self.training_loss.append(loss_value)
                        # update the progress bar on the main process
                        if self.accelerator.is_main_process:
                            pbar.update(1)
                            pbar.set_postfix(loss=loss_value, lr=self.optimizer.param_groups[0]['lr'])
# save iter1 checkpoint
if completed_steps == 1 and self.accelerator.is_main_process:
self.save_model(self.save_dir / f'iter{completed_steps}_loss{current_loss:.4f}.pt')
# Log training loss at the specified training cycle
if (completed_steps + 1) % self.iterations_per_training_cycle == 0 and self.make_log and self.accelerator.is_main_process:
wandb.log(loss_dict, step=completed_steps)
# Log training accuracy periodically
if (completed_steps + 1) % (self.iterations_per_training_cycle * 3) == 0 and self.make_log:
validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch, train=True)
train_metric_dict = self._get_train_accuracy(num_nonmask_tokens, num_tokens_by_feature, correct_guess_by_feature)
train_metric_dict.update(loss_dict)
train_metric_dict = self._rename_dict(train_metric_dict, 'train')
if self.accelerator.is_main_process:
wandb.log(train_metric_dict, step=completed_steps)
# delete variables to avoid memory leakages
del validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature, train_metric_dict
# Perform validation at the specified interval
if (completed_steps + 1) % self.iterations_per_validation_cycle == 0:
self.model.eval()
validation_loss, validation_acc, validation_metric_dict = self.validate()
validation_metric_dict['acc'] = validation_acc
validation_metric_dict = self._rename_dict(validation_metric_dict, 'valid')
if self.make_log and self.accelerator.is_main_process:
wandb.log(validation_metric_dict, step=completed_steps)
self.validation_loss.append(validation_loss)
self.validation_acc.append(validation_acc)
self.best_valid_loss = min(validation_loss, self.best_valid_loss)
# Perform inference and logging after a certain number of cycles
if (completed_steps + 1) % (self.num_cycles_for_inference * self.iterations_per_validation_cycle) == 0 and self.infer_and_log and self.accelerator.is_main_process:
self.inference_and_log(i, self.num_uncond_generation, self.num_cond_generation, self.num_max_seq_len)
# Save a model checkpoint periodically
if (completed_steps + 1) % (self.iterations_per_validation_cycle * self.num_cycles_for_model_checkpoint) == 0 and self.accelerator.is_main_process:
self.accelerator.print(f"Saving model checkpoint at iter {completed_steps}")
self.save_model(self.save_dir / f'iter{completed_steps}_loss{validation_loss:.4f}.pt')
self.model.train()
# delete variables to avoid memory leakages
del validation_acc, validation_metric_dict
# Save the model checkpoint at the end of each epoch
if self.accelerator.is_main_process:
print(f"Saving model checkpoint at iter {completed_steps}")
# Save the model state
self.save_model(self.save_dir / f"iter{completed_steps}_loss{current_loss:.4f}.pt")
        # Save the final model after training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            print("saving last checkpoint")
            self.save_model(self.save_dir / 'checkpoint_last.pt')
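    # Typical call (hypothetical driver code; the iteration budget normally
    # comes from the training config):
    #   trainer.accelerate_train_by_num_iter(config.train_params.num_iter)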
    # Accelerate variant of _get_loss_pred_from_single_batch (no callers in this file)
    def _accelerate_get_loss_pred_from_single_batch(self, batch):
"""
Computes the loss and predictions for a single batch of data.
Args:
batch: A batch of data, typically containing input sequences, targets, and masks.
Returns:
loss: The computed loss for the batch.
logits: The raw model predictions (logits).
loss_dict: A dictionary containing the total loss.
The method:
- Separates the input sequences and target sequences from the batch.
- Moves the data to the appropriate device.
- Applies mixed precision (FP16) if applicable.
- Computes the logits using the model and calculates the loss using the specified loss function.
"""
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with torch.autocast(device_type='cuda'):
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, target)  # teacher-forced, matching the fp16 branch
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict
def _train_by_single_batch(self, batch):
"""
Trains the model on a single batch of data.
Args:
batch: A batch of data, typically consisting of input sequences and corresponding targets.
Returns:
loss.item(): The total loss for this batch.
loss_dict: A dictionary containing information about the loss and other relevant metrics.
The method:
- Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
- Resets the optimizer's gradients.
- Depending on whether mixed precision (FP16) is used, it scales the loss and applies gradient clipping before stepping the optimizer.
- Updates the learning rate scheduler if applicable.
- Records the time taken for the training step and the current learning rate in the `loss_dict`.
"""
start_time = time.time()
loss, _, loss_dict = self._get_loss_pred_from_single_batch(batch)
self.optimizer.zero_grad()
if self.use_fp16:
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.optimizer.step()
if not isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and self.scheduler is not None:
self.scheduler.step()
loss_dict['time'] = time.time() - start_time
loss_dict['lr'] = self.optimizer.param_groups[0]['lr']
return loss.item(), loss_dict
def _get_loss_pred_from_single_batch(self, batch):
"""
Computes the loss and predictions for a single batch of data.
Args:
batch: A batch of data, typically containing input sequences, targets, and masks.
Returns:
loss: The computed loss for the batch.
logits: The raw model predictions (logits).
loss_dict: A dictionary containing the total loss.
The method:
- Separates the input sequences and target sequences from the batch.
- Moves the data to the appropriate device.
- Applies mixed precision (FP16) if applicable.
- Computes the logits using the model and calculates the loss using the specified loss function.
"""
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        if self.use_fp16:
            with self.accelerator.autocast():
                logits = self.model(input_seq, target)
                loss = self.loss_fn(logits, target, mask)
        else:
            logits = self.model(input_seq, target)  # teacher-forced, matching the fp16 branch
            loss = self.loss_fn(logits, target, mask)
        loss_dict = {'total': loss.item()}
        return loss, logits, loss_dict
def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
"""
Computes validation loss and accuracy from a single batch.
Args:
batch: A batch of data, typically containing input sequences, targets, and masks.
train (bool): Indicator whether the function is being used in training mode.
Returns:
validation_loss: Total validation loss for the batch.
num_tokens: The number of valid tokens in the batch.
loss_dict: A dictionary containing the loss and relevant metrics.
None: Placeholder for future implementation.
num_correct_guess: Number of correctly predicted tokens.
The method:
- Calls `_get_loss_pred_from_single_batch` to compute the loss and predictions.
- Computes token-level accuracy by comparing predicted tokens with the targets.
"""
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch)
        prob = torch.softmax(logits, dim=-1)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        num_tokens = torch.sum(mask)  # count only positions that contribute to the loss
        selected_tokens = torch.argmax(prob, dim=-1) * mask
        shifted_tgt_with_mask = target * mask
        # masked positions compare 0 == 0, so subtract them from the correct count
        num_correct_guess = torch.sum(selected_tokens == shifted_tgt_with_mask) - torch.sum(mask == 0)
        validation_loss = loss.item() * num_tokens
        num_correct_guess = num_correct_guess.item()
        return validation_loss, num_tokens, loss_dict, None, num_correct_guess
def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess):
"""
Computes training accuracy.
Args:
num_tokens: Total number of tokens processed.
num_tokens_by_feature: Number of tokens for each feature (not used here).
num_correct_guess: Number of correctly predicted tokens.
Returns:
Training accuracy, computed as the ratio of correct predictions to the total number of tokens.
"""
return num_correct_guess / num_tokens
def validate(self, external_loader=None):
"""
Validates the model on a dataset.
Args:
external_loader (DataLoader): If provided, an external DataLoader can be used for validation.
Returns:
total_validation_loss: Average validation loss over all batches.
total_num_correct_guess: Total number of correct predictions divided by the number of tokens (accuracy).
validation_metric_dict: Dictionary of validation metrics averaged over all batches.
The method:
- Iterates through the validation data loader, calculating the loss and accuracy for each batch.
- Aggregates the results over all batches and returns the overall validation metrics.
"""
if external_loader and isinstance(external_loader, DataLoader):
loader = external_loader
print('An arbitrary loader is used instead of Validation loader')
else:
loader = self.valid_loader
self.model.eval()
total_validation_loss = 0
total_num_correct_guess = 0
total_num_tokens = 0
validation_metric_dict = defaultdict(float)
with torch.inference_mode():
for batch in tqdm(loader, leave=False):
validation_loss, num_tokens, loss_dict, _, num_correct_guess = self._get_valid_loss_and_acc_from_batch(batch)
total_validation_loss += validation_loss
total_num_tokens += num_tokens
total_num_correct_guess += num_correct_guess
for key, value in loss_dict.items():
validation_metric_dict[key] += value * num_tokens
for key in validation_metric_dict.keys():
validation_metric_dict[key] /= total_num_tokens
return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_tokens, validation_metric_dict
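    # Aggregation above is token-weighted:
    #   avg_loss = sum_b(loss_b * n_b) / sum_b(n_b), where n_b = valid tokens in batch b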
def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None):
"""
Generates a MIDI file and logs output from the generated sequence.
Args:
generated_output: The sequence of notes generated by the model.
iter: The current iteration of the training process.
seed: The seed used for generating the sequence.
condition: Optional condition input for generating conditional output.
The method:
- Converts the generated output into a MIDI file and logs it.
- Optionally logs additional error metrics and figures for analysis.
"""
if condition is not None:
path_addition = "cond_"
else:
path_addition = ""
with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f:
pickle.dump(generated_output, f)
self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid")
if self.make_log:
log_dict = {}
log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png'))
log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3'))
wandb.log(log_dict, step=(iter+seed))
print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}")
return generated_output
@torch.inference_mode()
def inference_and_log(self, iter, num_uncond_generation=5, num_cond_generation=5, max_seq_len=10000):
"""
Generates and logs both unconditional and conditional output sequences.
Args:
iter: The current iteration.
num_uncond_generation: Number of unconditional sequences to generate.
num_cond_generation: Number of conditional sequences to generate.
max_seq_len: Maximum sequence length to generate.
The method:
- Generates unconditional and conditional sequences using the model's generation function.
- Converts the sequences into MIDI files and logs the generated results.
"""
        self.model.eval()
        # unwrap DDP/Accelerate wrapping so .generate is reachable on the raw model
        model = self.model.module if hasattr(self.model, 'module') else self.model
        for i in range(num_uncond_generation):
            try:
                start_time = time.time()
                uncond_generated_output = model.generate(
                    manual_seed=i, max_seq_len=max_seq_len, condition=None,
                    sampling_method=self.sampling_method, threshold=self.sampling_threshold,
                    temperature=self.sampling_temperature)
                if len(uncond_generated_output) == 0:
                    continue
                print(f"unconditional generation time_{iter}: {time.time() - start_time:.4f}")
                print(f"unconditional length of generated_output: {uncond_generated_output.shape[1]}")
                self._make_midi_from_generated_output(uncond_generated_output, iter, i, None)
            except Exception as e:
                print(e)
        condition_list = [x[1] for x in self.valid_set.data_list[:num_cond_generation]]
        for i in range(num_cond_generation):
            condition = self.valid_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
            try:
                start_time = time.time()
                generated_output = model.generate(
                    manual_seed=i, max_seq_len=max_seq_len, condition=condition,
                    sampling_method=self.sampling_method, threshold=self.sampling_threshold,
                    temperature=self.sampling_temperature)
                if len(generated_output) == 0:
                    continue
                print(f"conditional generation time_{iter}: {time.time() - start_time:.4f}")
                print(f"conditional length of generated_output: {generated_output.shape[1]}")
                self._make_midi_from_generated_output(generated_output, iter + num_uncond_generation, i, condition)
            except Exception as e:
                print(e)
def _rename_dict(self, adict, prefix='train'):
'''
Renames the keys in a dictionary by adding a prefix.
'''
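        # e.g. {'total': 0.42} -> {'train.total': 0.42}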
keys = list(adict.keys())
for key in keys:
adict[f'{prefix}.{key}'] = adict.pop(key)
return dict(adict)
class LanguageModelTrainer4REMI(LanguageModelTrainer):
def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)
    def _get_loss_pred_from_single_batch(self, batch, valid=False):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        autocast_ctx = self.accelerator.autocast() if self.use_fp16 else contextlib.nullcontext()
        with autocast_ctx:
            logits = self.model(input_seq, target)
            # during validation the vocab is passed so the loss can be split per feature
            total_loss, loss_dict = self.loss_fn(logits, target, mask, self.vocab if valid else None)
        if not valid:
            return total_loss, logits, {'total': total_loss.item()}
        loss_dict['total'] = total_loss.item()
        return total_loss, logits, loss_dict
def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
segment, mask, caption,encoded_caption = batch
mask = mask[:, :-1]
_, target = segment[:, :-1], segment[:, 1:]
loss, logits, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
prob = torch.softmax(logits, dim=-1)
num_nonmask_tokens = torch.sum(mask) # [b, t]
target = target.to(self.device) # [b, t]
mask = mask.to(self.device)
prob_with_mask = torch.argmax(prob, dim=-1) * mask # [b, t]
shifted_tgt_with_mask = target * mask # [b, t]
correct_guess_by_feature = defaultdict(int)
num_tokens_by_feature = defaultdict(int)
tokens_idx = prob_with_mask.flatten(0,1) # [b*t]
answers_idx = shifted_tgt_with_mask.flatten(0,1) # [b*t]
if self.vocab.encoding_scheme == 'remi':
eos_idx = 2
for feature in self.vocab.feature_list:
feature_mask = self.vocab.total_mask[feature].to(self.device) # [327,]
mask_for_target = feature_mask[answers_idx] # [b*t]
if feature == 'type': # because Bar token is 0, we need to add 1 to calculate accuracy
valid_pred = (tokens_idx+1) * mask_for_target
valid_answers = (answers_idx+1) * mask_for_target
eos_mask = valid_answers != eos_idx # because EOS is also working as a padding
correct_guess_by_feature[feature] += torch.sum(valid_pred[eos_mask] == valid_answers[eos_mask]).item() - torch.sum(mask_for_target[eos_mask] == 0).item()
num_tokens_by_feature[feature] += torch.sum(mask_for_target[eos_mask]).item()
else:
valid_pred = tokens_idx * mask_for_target # [b, t]
valid_answers = answers_idx * mask_for_target # [b, t]
correct_guess_by_feature[feature] += torch.sum(valid_pred == valid_answers).item() - torch.sum(mask_for_target == 0).item()
num_tokens_by_feature[feature] += torch.sum(mask_for_target).item()
validation_loss = loss.item() * num_nonmask_tokens.item()
return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature
def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
total_num_correct_guess = 0
total_num_tokens = 0
acc_dict = {}
for feature, num_correct_guess in num_correct_guess_by_feature.items():
if feature == 'type':
continue
total_num_correct_guess += num_correct_guess
total_num_tokens += num_tokens_by_feature[feature]
if num_tokens_by_feature[feature] == 0:
continue
acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / max(total_num_tokens, 1)
acc_dict['total_acc'] = total_accuracy
return acc_dict
def validate(self, external_loader=None):
'''
total_num_tokens: for calculating loss, nonmask tokens
total_num_valid_tokens: for calculating accuracy, valid tokens
'''
if external_loader and isinstance(external_loader, DataLoader):
loader = external_loader
print('An arbitrary loader is used instead of Validation loader')
else:
loader = self.valid_loader
self.model.eval()
total_validation_loss = 0
total_num_tokens = 0
total_num_valid_tokens = 0
total_num_correct_guess = 0
validation_metric_dict = defaultdict(float)
total_num_tokens_by_feature = defaultdict(int)
total_num_correct_guess_dict = defaultdict(int)
with torch.inference_mode():
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                # cap external loaders (e.g. train_loader) at the validation loader's length
                if num_iter == len(self.valid_loader) and loader is not self.valid_loader:
                    break
validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
total_validation_loss += validation_loss
total_num_tokens += num_nonmask_tokens.item()
for key, num_tokens in num_tokens_by_feature.items():
total_num_tokens_by_feature[key] += num_tokens
if key == 'type':
continue
total_num_valid_tokens += num_tokens # num tokens are all the same for each musical type, torch.sum(mask)
for key, num_correct_guess in num_correct_guess_by_feature.items():
total_num_correct_guess_dict[key] += num_correct_guess
if key == 'type':
continue
total_num_correct_guess += num_correct_guess
for key, value in loss_dict.items():
if key == 'total':
validation_metric_dict[key] += value * num_nonmask_tokens
else:
feature_name = key.split('_')[0]
validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]
for key in validation_metric_dict.keys():
if key == 'total':
validation_metric_dict[key] /= total_num_tokens
else:
feature_name = key.split('_')[0]
if total_num_tokens_by_feature[feature_name] == 0:
continue
validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]
for key in total_num_tokens_by_feature.keys():
num_tokens = total_num_tokens_by_feature[key]
num_correct = total_num_correct_guess_dict[key]
if num_tokens == 0:
continue
validation_metric_dict[f'{key}_acc'] = num_correct / num_tokens
return total_validation_loss / total_num_tokens, total_num_correct_guess / total_num_valid_tokens, validation_metric_dict
class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
def __init__(self, model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config):
super().__init__(model, optimizer, scheduler, loss_fn, midi_decoder, train_set, valid_set, save_dir, vocab, use_ddp, use_fp16, world_size, batch_size, infer_target_len, gpu_id, sampling_method, sampling_threshold, sampling_temperature, config)
'''
About ignore_token and conti_token:
During validation, tokens with this "conti" value are ignored when calculating accuracy or other metrics,
ensuring that repeated values don't unfairly skew the results.
This is especially relevant for features like beat, chord, tempo, and instrument where repeated tokens may have a specific musical meaning.
We used ignore_token and conti_token to fairly compare compound token based encoding with REMI encoding.
'''
def _get_num_valid_and_correct_tokens(self, prob, ground_truth, mask, ignore_token=None, conti_token=None):
valid_prob = torch.argmax(prob, dim=-1) * mask
valid_ground_truth = ground_truth * mask
if ignore_token is None and conti_token is None:
num_valid_tokens = torch.sum(mask)
num_correct_tokens = torch.sum(valid_prob == valid_ground_truth) - torch.sum(mask == 0)
elif ignore_token is not None and conti_token is None:
ignore_mask = valid_ground_truth != ignore_token # batch x seq_len
num_valid_tokens = torch.sum(ignore_mask)
num_correct_tokens = torch.sum(valid_prob[ignore_mask] == valid_ground_truth[ignore_mask]) # by using mask, the tensor becomes 1d
elif ignore_token is not None and conti_token is not None:
ignore_conti_mask = (valid_ground_truth != ignore_token) & (valid_ground_truth != conti_token)
num_valid_tokens = torch.sum(ignore_conti_mask)
num_correct_tokens = torch.sum(valid_prob[ignore_conti_mask] == valid_ground_truth[ignore_conti_mask])
return num_correct_tokens.item(), num_valid_tokens.item()
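        # e.g. when called with ignore_token=0 and conti_token=9999 (as for 'chord'
        # below), ground-truth entries equal to 0 (ignore/pad) or 9999 (CONTI) are
        # excluded from both the valid-token and correct-guess counts.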
    def _get_loss_pred_from_single_batch(self, batch, valid=False):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        input_seq = input_seq.to(self.device)
        target = target.to(self.device)
        mask = mask[:, :-1].to(self.device)
        encoded_caption = encoded_caption.to(self.device)
        autocast_ctx = self.accelerator.autocast() if self.use_fp16 else contextlib.nullcontext()
        with autocast_ctx:
            if self.config.use_diff:
                # diffusion-style training: the model also returns the masked indices
                # and masking probabilities used by the denoising loss
                (logits_dict, (masked_indices, p_mask)), input_dict = self.model(input_seq, target, context=encoded_caption)
                if self.config.use_dispLoss:
                    total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid,
                                                         input_dict=input_dict, lambda_weight=self.config.lambda_weight,
                                                         tau=self.config.tau)
                else:
                    total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, masked_indices, p_mask, valid)
            else:
                logits_dict, input_dict = self.model(input_seq, target, context=encoded_caption)
                total_loss, loss_dict = self.loss_fn(logits_dict, target, mask, valid)
        if valid:
            loss_dict['total'] = total_loss.item()
        else:
            loss_dict = {'total': total_loss.item()}
        return total_loss, logits_dict, loss_dict
def _get_valid_loss_and_acc_from_batch(self, batch, train=False):
'''
in this method, valid means handled with both ignore token and mask
when valid tokens with only mask, it is called num_nonmask_tokens
input_seq, target: batch x seq_len x num_features
mask: batch x seq_len, 0 for padding
prob: batch x seq_len x total_vocab_size
'''
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        # some model variants return (aux_ar_logits, logits_dict); unwrap when present
        if isinstance(logits_dict, tuple):
            aux_ar_logits, logits_dict = logits_dict
        probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
        input_seq = input_seq.to(self.device)
        target = add_conti_in_valid(target, self.config.nn_params.encoding_scheme).to(self.device)
        mask = mask[:, :-1].to(self.device)
        num_nonmask_tokens = torch.sum(mask)  # count only positions that contribute to the loss
correct_guess_by_feature = defaultdict(int)
num_tokens_by_feature = defaultdict(int)
for idx, key in enumerate(self.vocab.feature_list):
            if key in ('type', 'timesig'):
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
            elif key in ('chord', 'tempo', 'instrument', 'program', 'beat'):
                # NB's beat vocab has Ignore and CONTI tokens;
                # CP's beat vocab has Ignore and BAR tokens, and BAR is excluded for parity with NB
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
            else:
                num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(
                    probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=None)
correct_guess_by_feature[key] = num_correct_tokens
num_tokens_by_feature[key] = num_valid_tokens
validation_loss = total_loss.item() * num_nonmask_tokens.item()
return validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, correct_guess_by_feature
def _get_train_accuracy(self, num_tokens, num_tokens_by_feature, num_correct_guess_by_feature):
total_num_correct_guess = 0
total_num_tokens = 0
acc_dict = {}
for feature, num_correct_guess in num_correct_guess_by_feature.items():
if feature == 'type':
continue
total_num_correct_guess += num_correct_guess
            total_num_tokens += num_tokens_by_feature[feature]
            if num_tokens_by_feature[feature] == 0:
                continue
            acc_dict[f"{feature}_acc"] = num_correct_guess / num_tokens_by_feature[feature]
        total_accuracy = total_num_correct_guess / max(total_num_tokens, 1)
        acc_dict['total_acc'] = total_accuracy
return acc_dict
def validate(self, external_loader=None):
if external_loader and isinstance(external_loader, DataLoader):
loader = external_loader
print('An arbitrary loader is used instead of Validation loader')
else:
loader = self.valid_loader
self.model.eval()
total_validation_loss = 0
total_num_correct_guess = 0
total_num_tokens = 0
total_num_valid_tokens = 0
validation_metric_dict = defaultdict(float)
total_num_tokens_by_feature = defaultdict(int)
total_num_correct_guess_dict = defaultdict(int)
with torch.inference_mode():
'''
mask is used to calculate loss, accuracy
validation_loss: sum of loss for valid tokens conditioned on mask
num_nonmask_tokens: sum of tokens conditioned on mask
num_tokens_by_feature: sum of valid tokens(handle ignore) for each musical features
num_correct_guess_by_feature: sum of correct tokens(handle ignore) for each musical features
'''
            for num_iter, batch in enumerate(tqdm(loader, leave=False)):
                # cap external loaders (e.g. train_loader) at the validation loader's length
                if num_iter == len(self.valid_loader) and loader is not self.valid_loader:
                    break
validation_loss, num_nonmask_tokens, loss_dict, num_tokens_by_feature, num_correct_guess_by_feature = self._get_valid_loss_and_acc_from_batch(batch)
total_validation_loss += validation_loss
                total_num_tokens += num_nonmask_tokens.item()
for key, num_tokens in num_tokens_by_feature.items():
total_num_tokens_by_feature[key] += num_tokens
if key == 'type': # because cp and nb have different number of type tokens, we don't want to calculate accuracy for type token
continue
total_num_valid_tokens += num_tokens # num tokens are all the same for each musical type, torch.sum(mask)
for key, num_correct_guess in num_correct_guess_by_feature.items():
total_num_correct_guess_dict[key] += num_correct_guess
if key == 'type':
continue
total_num_correct_guess += num_correct_guess
                for key, value in loss_dict.items():
                    if key == 'total':
                        validation_metric_dict[key] += value * num_nonmask_tokens.item()
                    else:
                        feature_name = key.split('_')[0]
                        validation_metric_dict[key] += value * num_tokens_by_feature[feature_name]
for key in validation_metric_dict.keys():
if key == 'total':
validation_metric_dict[key] /= total_num_tokens
else:
feature_name = key.split('_')[0]
if total_num_tokens_by_feature[feature_name] == 0:
continue
validation_metric_dict[key] /= total_num_tokens_by_feature[feature_name]
            for key in total_num_tokens_by_feature.keys():
                num_tokens = total_num_tokens_by_feature[key]
                num_correct = total_num_correct_guess_dict[key]
                if num_tokens == 0:
                    continue
                validation_metric_dict[f'{key}_acc'] = num_correct / num_tokens
        # the +1 terms keep the divisions safe when no tokens were counted
        return total_validation_loss / (total_num_tokens + 1), total_num_correct_guess / (1 + total_num_valid_tokens), validation_metric_dict
def _make_midi_from_generated_output(self, generated_output, iter, seed, condition=None):
if self.config.data_params.first_pred_feature != 'type' and self.config.nn_params.encoding_scheme == 'nb':
generated_output = reverse_shift_and_pad_for_tensor(generated_output, self.config.data_params.first_pred_feature)
if condition is not None:
path_addition = "cond_"
else:
path_addition = ""
# save generated_output as pickle
with open(self.valid_out_dir / f"{path_addition}generated_output_{iter}_seed_{seed}.pkl", 'wb') as f:
pickle.dump(generated_output, f)
self.midi_decoder(generated_output, self.valid_out_dir / f"{path_addition}midi_decoded_{iter}_seed_{seed}.mid")
if self.make_log and self.infer_and_log:
log_dict = {}
log_dict[f'{path_addition}gen_score'] = wandb.Image(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.png'))
log_dict[f'{path_addition}gen_audio'] = wandb.Audio(str(self.valid_out_dir / f'{path_addition}midi_decoded_{iter}_seed_{seed}.mp3'))
wandb.log(log_dict, step=(iter+seed))
print(f"{path_addition}inference is logged: Iter {iter} / seed {seed}")