import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import time
import json
from . import transformer_utils
from . import sub_decoder_zoo
from x_transformers.x_transformers import LayerIntermediates, AbsolutePositionalEmbedding
from data_representation.vocab_utils import LangTokenVocab
import os


class AmadeusModelWrapper(nn.Module):
    def __init__(
        self,
        *,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        '''
        This class wraps the three main components of the AmadeusModel: the input embedding layer,
        the main transformer decoder, and the sub-decoder.
        '''
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.get_vocab_size()
        self.start_token = vocab.sos_token if hasattr(vocab, 'sos_token') else None
        self.end_token = vocab.eos_token if hasattr(vocab, 'eos_token') else None
        self.input_length = input_length
        self.prediction_order = prediction_order
        self._get_input_embedder(input_embedder_name, vocab, dropout, dim)
        self._get_main_decoder(main_decoder_name, input_length, dim, heads, depth, dropout)
        self._get_sub_decoder(sub_decoder_name, prediction_order, vocab, sub_decoder_depth,
                              sub_decoder_enricher_use, dim, heads, dropout)
        self.bos_token_hidden = None

    def _get_input_embedder(self, input_embedder_name, vocab, dropout, dim):
        self.emb_dropout = nn.Dropout(dropout)
        self.input_embedder = getattr(transformer_utils, input_embedder_name)(
            vocab=vocab,
            dim_model=dim
        )

    def _get_main_decoder(self, main_decoder_name, input_length, dim, heads, depth, dropout):
        self.pos_enc = AbsolutePositionalEmbedding(dim, input_length)
        self.main_norm = nn.LayerNorm(dim)
        self.main_decoder = getattr(transformer_utils, main_decoder_name)(
            dim=dim,
            depth=depth,
            heads=heads,
            dropout=dropout
        )

    def _get_sub_decoder(self, sub_decoder_name, prediction_order, vocab, sub_decoder_depth,
                         sub_decoder_enricher_use, dim, heads, dropout):
        self.sub_decoder = getattr(sub_decoder_zoo, sub_decoder_name)(
            prediction_order=prediction_order,
            vocab=vocab,
            dim=dim,
            sub_decoder_depth=sub_decoder_depth,
            heads=heads,
            dropout=dropout,
            sub_decoder_enricher_use=sub_decoder_enricher_use
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        embedding = self.input_embedder(input_seq) + self.pos_enc(input_seq)
        embedding = self.emb_dropout(embedding)
        hidden_vec, layer_inter = self.main_decoder(embedding, train=True, context=context)  # B x T x d_model
        hidden_vec = self.main_norm(hidden_vec)
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': target,
                      'bos_token_hidden': self.bos_token_hidden}
        logits = self.sub_decoder(input_dict)
        # Pick the intermediate layer closest to one third of the total decoder depth
        num_layers = len(layer_inter.layer_hiddens)
        idx = round(num_layers / 3)
        idx = min(max(idx, 0), num_layers - 1)
        input_dict['hidden_vec'] = layer_inter.layer_hiddens[idx]
        return logits, input_dict


class AmadeusModelAutoregressiveWrapper(nn.Module):
    def __init__(self, net: AmadeusModelWrapper):
        '''
        Initializes an autoregressive wrapper around the AmadeusModelWrapper,
        which allows sequential token generation.

        Arguments:
        - net: The nested music transformer model that performs the token generation.
        '''
        super().__init__()
        self.net = net

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        return self.net(input_seq, target, context=context)

    def _prepare_inference(self, start_token, manual_seed, condition=None, num_target_measures=4):
        '''
        Prepares the initial tokens for autoregressive inference. If a manual seed is provided,
        it sets the seed for reproducibility. If a condition is given, it selects a subset of the
        tokens based on criteria tied to the encoding scheme.

        Arguments:
        - start_token: The token that represents the start of a sequence.
        - manual_seed: A seed value for reproducibility in inference (if greater than 0).
        - condition: An optional tensor used for conditional generation, which helps select a
          portion of the input tokens based on the encoding scheme.

        Returns:
        - total_out: A tensor containing the initial tokens for inference, padded to ensure
          compatibility with the model.
        '''
        if manual_seed > 0:
            torch.manual_seed(manual_seed)
        total_out = []
        if condition is None:
            # Use the start token if no condition is given
            total_out.extend(start_token)
        else:
            # Extract the portion of the sequence depending on encoding scheme (remi, cp, or nb)
            if self.net.vocab.encoding_scheme == 'remi':
                type_boundaries = self.net.vocab.remi_vocab_boundaries_by_key['type']
                # vocab idx -> 0:SOS, 1:EOS, 2:Bar_without_time_signature, ...
                # where_type_ends:Bar_time_signature_end, ...
                measure_bool = (2 <= condition) & (condition < type_boundaries[1])  # between Bar_ts_start and Bar_ts_end
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'cp':
                # Find the start and end of the measure token range in the beat vocabulary
                beat_event2idx = self.net.vocab.event2idx['beat']
                for event, idx in beat_event2idx.items():
                    if event == 0:
                        continue
                    if event == 'Bar':
                        start_idx = idx
                    elif event.startswith('Beat'):
                        end_idx = idx
                        break
                measure_bool = (condition[:, 1] >= start_idx) & (condition[:, 1] < end_idx)  # measure tokens
                # measure_bool = (condition[:, 1] == 1)  # measure tokens
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'nb':
                measure_bool = (condition[:, 0] == 2) | (condition[:, 0] >= 5)  # Empty measure or where a new measure starts
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            if conditional_input_len == 0:
                conditional_input_len = 50
            selected_tokens = condition[:conditional_input_len].tolist()
            total_out.extend(selected_tokens)
        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
        return total_out

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None,
                      temperature=1, bos_hidden_vec=None, context=None):
        '''
        Runs one step of autoregressive generation by taking the input sequence, embedding it,
        passing it through the main decoder, and generating logits and a sampled token.

        Arguments:
        - input_seq: The input sequence tensor to be embedded and processed.
        - cache: Optional cache for attention mechanisms to avoid recomputation.
        - sampling_method: Sampling strategy used to select the next token.
        - threshold: Optional threshold value for sampling methods that require it.
        - temperature: Controls the randomness of predictions (higher temperature increases randomness).

        Returns:
        - logits: The predicted logits for the next token.
        - sampled_token: The token sampled from the logits.
        - intermidiates: Intermediate states from the main decoder, useful for caching.
        '''
        embedding = self.net.input_embedder(input_seq) + self.net.pos_enc(input_seq)
        embedding = self.net.emb_dropout(embedding)
        # Run through the main decoder and normalize
        hidden_vec, intermidiates = self.net.main_decoder(embedding, cache, context_embedding=context)  # B x T x d_model
        hidden_vec = self.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # Keep only the last time step
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None,
                      'bos_token_hidden': bos_hidden_vec}
        # Generate the next token
        logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
        return logits, sampled_token, intermidiates, hidden_vec

    def _update_total_out(self, total_out, sampled_token):
        '''
        Updates the output sequence with the newly sampled token. Depending on the encoding scheme,
        it either appends the token directly or processes feature-based sampling.

        Arguments:
        - total_out: The tensor containing the previously generated tokens.
        - sampled_token: The newly generated token to be appended.

        Returns:
        - total_out: Updated output tensor with the newly generated token.
        - sampled_token: The processed sampled token.
        '''
        if self.net.vocab.encoding_scheme == 'remi':
            # For remi encoding, directly append the sampled token
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            # Handle other encoding schemes by concatenating features
            sampled_token_list = []
            for key in self.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)
            # print(total_out.shape)
            if len(sampled_token.shape) == 2:
                # Batched compound token: B x num_features
                total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=1)
            else:
                # Single compound token: num_features
                total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
        return total_out, sampled_token

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4,
                 sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
        '''
        Autoregressively generates a sequence of tokens by repeatedly sampling the next token
        until the desired maximum sequence length is reached or the end token is encountered.

        Arguments:
        - manual_seed: A seed value for reproducibility in inference.
        - max_seq_len: The maximum length of the generated sequence.
        - condition: An optional conditioning sequence to start generation from.
        - sampling_method: The method used to sample the next token (e.g., greedy, top-k).
        - threshold: Optional threshold for sampling (used in methods like top-p sampling).
        - temperature: Controls the randomness of the token sampling process.
        - batch_size: The number of sequences to generate in parallel.

        Returns:
        - total_out: The generated sequence of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
        # If a condition is provided, run one initial step to warm up the cache
        if condition is not None:
            _, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:],
                                                cache=LayerIntermediates(),
                                                sampling_method=sampling_method,
                                                threshold=threshold,
                                                temperature=temperature,
                                                context=context)
        else:
            cache = LayerIntermediates()
        # Continue generating tokens until the maximum sequence length is reached
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        bos_hidden_vec = None
        hidden_vec_list = []
        token_time_list = []
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]
            # Generate the next token and update the cache
            time_start = time.time()
            _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor,
                                                                     cache=cache,
                                                                     sampling_method=sampling_method,
                                                                     threshold=threshold,
                                                                     temperature=temperature,
                                                                     bos_hidden_vec=bos_hidden_vec,
                                                                     context=context)
            time_end = time.time()
            token_time_list.append(time_end - time_start)
            if bos_hidden_vec is None:
                bos_hidden_vec = hidden_vec
            hidden_vec_list.append(hidden_vec)
            # Update attention cache to handle autoregressive generation
            # (keep only the last input_length - 1 cached positions)
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)
            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break
        # append hidden_vec to pkl
        # save_path = 'hidden/diffnoaug_hidden_vec.pt'
        # save_time_path = 'hidden/diff_noaug_token_time.json'
        # if os.path.exists(save_path):
        #     # Load existing list and append
        #     hidden_vec_all = torch.load(save_path, map_location="cpu")
        #     hidden_vec_all.extend(hidden_vec_list)
        #     torch.save(hidden_vec_all, save_path)
        # else:
        #     torch.save(hidden_vec_list, save_path)
        # if os.path.exists(save_time_path):
        #     # Load existing list and append
        #     token_time_all = json.load(open(save_time_path, 'r'))
        #     token_time_all = token_time_all['token_time_list']
        #     token_time_all.extend(token_time_list)
        #     average_time = sum(token_time_all) / len(token_time_all)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_all
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)
        # else:
        #     average_time = sum(token_time_list) / len(token_time_list)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_list
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)
        return total_out

    def generate_batch(self, manual_seed, max_seq_len, condition=None, num_target_measures=4,
                       sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
        '''
        Autoregressively generates a batch of token sequences by repeatedly sampling the next token
        until the desired maximum sequence length is reached or the end token is encountered.

        Arguments:
        - manual_seed: A seed value for reproducibility in inference.
        - max_seq_len: The maximum length of the generated sequence.
        - condition: An optional conditioning sequence to start generation from.
        - sampling_method: The method used to sample the next token (e.g., greedy, top-k).
        - threshold: Optional threshold for sampling (used in methods like top-p sampling).
        - temperature: Controls the randomness of the token sampling process.
        - batch_size: The number of sequences to generate in parallel.
        Returns:
        - total_out: The generated sequence of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
        # total_out (1, 1, num) -> (bs, 1, num)
        total_out = total_out.repeat(batch_size, 1, 1)
        # If a condition is provided, run one initial step to warm up the cache
        if condition is not None:
            _, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:],
                                                cache=LayerIntermediates(),
                                                sampling_method=sampling_method,
                                                threshold=threshold,
                                                temperature=temperature,
                                                context=context)
        else:
            cache = LayerIntermediates()
        # Continue generating tokens until the maximum sequence length is reached
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]
            # Generate the next token and update the cache
            _, sampled_token, cache, _ = self._run_one_step(input_tensor,
                                                            cache=cache,
                                                            sampling_method=sampling_method,
                                                            threshold=threshold,
                                                            temperature=temperature,
                                                            context=context)
            # Update attention cache to handle autoregressive generation
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)
            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break
        return total_out


class AmadeusModel(nn.Module):
    def __init__(
        self,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        '''
        This class combines the wrapper classes and initializes the full AmadeusModel model,
        which can perform autoregressive sequence generation for symbolic music.

        Arguments:
        - vocab: Vocabulary used for tokenization of the symbolic music data.
        - input_length: Length of the input sequence in tokens.
        - prediction_order: Defines the order in which features are predicted in a sequence; used for compound shift.
        - input_embedder_name: Name of the input embedding model to be used (e.g., one-hot embedding or learned embeddings).
        - main_decoder_name: Name of the main transformer decoder model used for generating the hidden representations for compound tokens.
        - sub_decoder_name: Name of the sub-decoder, which processes the hidden states and decodes the sub-tokens inside the compound tokens.
        - sub_decoder_depth: Depth (number of layers) of the sub-decoder.
        - sub_decoder_enricher_use: Whether to use an additional enricher module in the sub-decoder to refine representations.
        - dim: Dimensionality of the model (hidden size of the transformer layers).
        - heads: Number of attention heads in the transformer layers.
        - depth: Number of layers in the main decoder.
        - dropout: Dropout rate for all layers in the model.
        '''
        super().__init__()
        decoder = AmadeusModelWrapper(
            vocab=vocab,
            input_length=input_length,
            prediction_order=prediction_order,
            input_embedder_name=input_embedder_name,
            main_decoder_name=main_decoder_name,
            sub_decoder_name=sub_decoder_name,
            sub_decoder_depth=sub_decoder_depth,
            sub_decoder_enricher_use=sub_decoder_enricher_use,
            dim=dim,
            heads=heads,
            depth=depth,
            dropout=dropout
        )
        self.decoder = AmadeusModelAutoregressiveWrapper(
            net=decoder
        )

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        return self.decoder(input_seq, target, context=context)

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4,
                 sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
        if batch_size == 1:
            return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures,
                                         sampling_method, threshold, temperature, context=context)
        else:
            return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures,
                                               sampling_method, threshold, temperature, batch_size,
                                               context=context)


class AmadeusModel4Encodec(AmadeusModel):
    def __init__(
        self,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        super().__init__(
            vocab=vocab,
            input_length=input_length,
            prediction_order=prediction_order,
            input_embedder_name=input_embedder_name,
            main_decoder_name=main_decoder_name,
            sub_decoder_name=sub_decoder_name,
            sub_decoder_depth=sub_decoder_depth,
            sub_decoder_enricher_use=sub_decoder_enricher_use,
            dim=dim,
            heads=heads,
            depth=depth,
            dropout=dropout
        )

    def _prepare_inference(self, start_token, manual_seed, condition=None):
        if manual_seed > 0:
            torch.manual_seed(manual_seed)
        total_out = []
        if condition is None:
            total_out.extend(start_token)
        else:
            if self.decoder.net.vocab.encoding_scheme == 'remi':
                selected_tokens = condition[:1500].tolist()
            else:
                selected_tokens = condition[:500].tolist()
            total_out.extend(selected_tokens)
        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.decoder.net.device)
        return total_out

    def _update_total_out(self, total_out, sampled_token):
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            sampled_token_list = []
            for key in self.decoder.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)  # B(1) x num_features
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
        return total_out, sampled_token

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1):
        embedding = self.decoder.net.input_embedder(input_seq) + self.decoder.net.pos_enc(input_seq)
        embedding = self.decoder.net.emb_dropout(embedding)
        hidden_vec, intermidiates = self.decoder.net.main_decoder(embedding, cache)  # B x T x d_model
        hidden_vec = self.decoder.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # B x 1 x d_model
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None}
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            # For remi, each position predicts one of the four sub-token features in turn
            feature_class_idx = (input_seq.shape[1] - 1) % 4
            feature_type = self.decoder.net.vocab.feature_list[feature_class_idx]
            logits, sampled_token = self.decoder.net.sub_decoder.run_one_step(input_dict, sampling_method,
                                                                              threshold, temperature, feature_type)
        else:
            logits, sampled_token = self.decoder.net.sub_decoder(input_dict, sampling_method,
                                                                 threshold, temperature)
        return logits, sampled_token, intermidiates

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, sampling_method=None,
                 threshold=None, temperature=1):
        total_out = self._prepare_inference(self.decoder.net.start_token, manual_seed, condition)
        if condition is not None:
            _, _, cache = self._run_one_step(total_out[:, -self.decoder.net.input_length:],
                                             cache=LayerIntermediates(),
                                             sampling_method=sampling_method,
                                             threshold=threshold,
                                             temperature=temperature)
        else:
            cache = LayerIntermediates()
        while total_out.shape[1] < max_seq_len:
            input_tensor = total_out[:, -self.decoder.net.input_length:]
            _, sampled_token, cache = self._run_one_step(input_tensor,
                                                         cache=cache,
                                                         sampling_method=sampling_method,
                                                         threshold=threshold,
                                                         temperature=temperature)
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.decoder.net.input_length - 1):, :]
                                   for t in inter.cached_kv]  # B x num_heads x T x d_head
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)
            if sampled_token.tolist() == self.decoder.net.end_token[0]:
                break
        return total_out
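

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept commented out so it does not run on import).
# This is an illustration only: the component names 'EmbeddingLayer',
# 'TransformerDecoder', and 'SubDecoder' are hypothetical placeholders and
# must be replaced with classes actually exported by `transformer_utils` and
# `sub_decoder_zoo`; `vocab` must be a constructed LangTokenVocab, and the
# feature names in `prediction_order` as well as the 'top_p' sampling method
# are assumptions about the surrounding codebase.
#
# if __name__ == "__main__":
#     model = AmadeusModel(
#         vocab=vocab,                               # a LangTokenVocab instance
#         input_length=1024,
#         prediction_order=['type', 'beat', 'pitch', 'duration'],  # hypothetical
#         input_embedder_name='EmbeddingLayer',      # hypothetical
#         main_decoder_name='TransformerDecoder',    # hypothetical
#         sub_decoder_name='SubDecoder',             # hypothetical
#         sub_decoder_depth=2,
#         sub_decoder_enricher_use=True,
#         dim=512,
#         heads=8,
#         depth=12,
#         dropout=0.1,
#     )
#     # Unconditional generation with a fixed seed; condition/context may be
#     # passed instead to continue from an existing token sequence.
#     generated = model.generate(
#         manual_seed=42,
#         max_seq_len=2048,
#         sampling_method='top_p',                   # assumed to be supported by the sub-decoder
#         threshold=0.9,
#         temperature=1.0,
#     )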