import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
import json

from tqdm.auto import tqdm

from . import transformer_utils
from . import sub_decoder_zoo
from x_transformers.x_transformers import LayerIntermediates, AbsolutePositionalEmbedding
from data_representation.vocab_utils import LangTokenVocab

class AmadeusModelWrapper(nn.Module):
    def __init__(
        self,
        *,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        '''
        This class wraps the three main components of AmadeusModel:
        the input embedding layer, the main transformer decoder, and the sub-decoder.
        '''
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.get_vocab_size()
        self.start_token = vocab.sos_token if hasattr(vocab, 'sos_token') else None
        self.end_token = vocab.eos_token if hasattr(vocab, 'eos_token') else None
        self.input_length = input_length
        self.prediction_order = prediction_order
        self._get_input_embedder(input_embedder_name, vocab, dropout, dim)
        self._get_main_decoder(main_decoder_name, input_length, dim, heads, depth, dropout)
        self._get_sub_decoder(sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout)
        self.bos_token_hidden = None

    def _get_input_embedder(self, input_embedder_name, vocab, dropout, dim):
        self.emb_dropout = nn.Dropout(dropout)
        self.input_embedder = getattr(transformer_utils, input_embedder_name)(
            vocab=vocab,
            dim_model=dim
        )

    def _get_main_decoder(self, main_decoder_name, input_length, dim, heads, depth, dropout):
        self.pos_enc = AbsolutePositionalEmbedding(dim, input_length)
        self.main_norm = nn.LayerNorm(dim)
        self.main_decoder = getattr(transformer_utils, main_decoder_name)(
            dim=dim,
            depth=depth,
            heads=heads,
            dropout=dropout
        )

    def _get_sub_decoder(self, sub_decoder_name, prediction_order, vocab, sub_decoder_depth, sub_decoder_enricher_use, dim, heads, dropout):
        self.sub_decoder = getattr(sub_decoder_zoo, sub_decoder_name)(
            prediction_order=prediction_order,
            vocab=vocab,
            dim=dim,
            sub_decoder_depth=sub_decoder_depth,
            heads=heads,
            dropout=dropout,
            sub_decoder_enricher_use=sub_decoder_enricher_use
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        embedding = self.input_embedder(input_seq) + self.pos_enc(input_seq)
        embedding = self.emb_dropout(embedding)
        hidden_vec, layer_inter = self.main_decoder(embedding, train=True, context=context)  # B x T x d_model
        hidden_vec = self.main_norm(hidden_vec)
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': target, 'bos_token_hidden': self.bos_token_hidden}
        logits = self.sub_decoder(input_dict)
        # Pick the intermediate layer closest to one third of the total depth
        num_layers = len(layer_inter.layer_hiddens)
        idx = round(num_layers / 3)
        idx = min(max(idx, 0), num_layers - 1)
        input_dict['hidden_vec'] = layer_inter.layer_hiddens[idx]
        return logits, input_dict
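
    # Illustrative only (hypothetical shapes; the real feature count comes from
    # the vocab's encoding scheme). A compound-token batch of 2 sequences of
    # length 1024 with 8 sub-tokens per position would flow through as:
    #
    #   input_seq = torch.randint(0, 10, (2, 1024, 8))   # B x T x num_features
    #   target    = torch.randint(0, 10, (2, 1024, 8))
    #   logits, input_dict = model(input_seq, target)
    #
    # where input_dict['hidden_vec'] holds the ~depth/3 intermediate layer
    # rather than the final hidden states fed to the sub-decoder.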

class AmadeusModelAutoregressiveWrapper(nn.Module):
    def __init__(self, net: AmadeusModelWrapper):
        '''
        Initializes an autoregressive wrapper around the AmadeusModelWrapper,
        which allows sequential token generation.

        Arguments:
        - net: The wrapped AmadeusModelWrapper that performs the token generation.
        '''
        super().__init__()
        self.net = net
        # self.prediction_order = net.prediction_order
        # self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
        self.attribute2idx_after = {'pitch': 0, 'duration': 1, 'velocity': 2, 'type': 3,
                                    'beat': 4, 'chord': 5, 'tempo': 6, 'instrument': 7}
        self.attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                              'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
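
    # Index bookkeeping (illustrative): `attribute2idx` is the feature order of
    # the condition tensors, while `attribute2idx_after` is the order the
    # sub-decoder predicts in. A condition step is remapped destination <-
    # source, e.g. rearranged[..., 3] = step[..., 0] for 'type' and
    # rearranged[..., 0] = step[..., 5] for 'pitch' (see generate() below).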

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        return self.net(input_seq, target, context=context)

    def _prepare_inference(self, start_token, manual_seed, condition=None, num_target_measures=4):
        '''
        Prepares the initial tokens for autoregressive inference. If a manual seed is provided,
        it sets the seed for reproducibility. If a condition is given, it selects a subset of
        the tokens based on criteria specific to the encoding scheme.

        Arguments:
        - start_token: The token that represents the start of a sequence.
        - manual_seed: A seed value for reproducibility in inference (if greater than 0).
        - condition: An optional tensor used for conditional generation, from which a
          prefix is selected according to the encoding scheme.
        - num_target_measures: Number of measures to keep from the conditioning sequence.

        Returns:
        - total_out: A tensor containing the initial tokens for inference, padded to ensure
          compatibility with the model.
        '''
        if manual_seed > 0:
            torch.manual_seed(manual_seed)

        total_out = []
        if condition is None:
            # Use the start token if no condition is given
            total_out.extend(start_token)
        else:
            # Extract the prefix of the sequence depending on the encoding scheme (remi, cp, or nb)
            if self.net.vocab.encoding_scheme == 'remi':
                type_boundaries = self.net.vocab.remi_vocab_boundaries_by_key['type']
                # vocab idx -> 0:SOS, 1:EOS, 2:Bar_without_time_signature, ... where_type_ends:Bar_time_signature_end, ...
                measure_bool = (2 <= condition) & (condition < type_boundaries[1])  # between Bar_ts_start and Bar_ts_end
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'cp':
                # Find the vocab indices that mark the start and end of a measure
                beat_event2idx = self.net.vocab.event2idx['beat']
                for event, idx in beat_event2idx.items():
                    if event == 0:
                        continue
                    if event == 'Bar':
                        start_idx = idx
                    elif event.startswith('Beat'):
                        end_idx = idx
                        break
                measure_bool = (condition[:, 1] >= start_idx) & (condition[:, 1] < end_idx)  # measure tokens
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
            elif self.net.vocab.encoding_scheme == 'nb':
                measure_bool = (condition[:, 0] == 2) | (condition[:, 0] >= 5)  # empty measure or where a new measure starts
                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()

            if conditional_input_len == 0:
                # Fall back to a fixed-length prefix if no measure boundary was found
                conditional_input_len = 50

            selected_tokens = condition[:conditional_input_len].tolist()
            total_out.extend(selected_tokens)

        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
        return total_out
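
    # Illustrative: with num_target_measures=4, the prefix is cut at the
    # boundary token that opens the 5th detected measure, so the condition
    # contributes exactly four full measures (falling back to 50 tokens when
    # no boundary is found at index 0).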

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None, context=None, condition_step=None):
        '''
        Runs one step of autoregressive generation by taking the input sequence, embedding it,
        passing it through the main decoder, and generating logits and a sampled token.

        Arguments:
        - input_seq: The input sequence tensor to be embedded and processed.
        - cache: Optional cache for attention mechanisms to avoid recomputation.
        - sampling_method: Sampling strategy used to select the next token.
        - threshold: Optional threshold value for sampling methods that require it.
        - temperature: Controls the randomness of predictions (higher temperature increases randomness).
        - bos_hidden_vec: Hidden state of the first generated step, reused by the sub-decoder.
        - context: Optional context embedding passed to the main decoder.
        - condition_step: Optional per-step attribute condition for controlled generation.

        Returns:
        - logits: The predicted logits for the next token.
        - sampled_token: The token sampled from the logits.
        - intermediates: Intermediate states from the main decoder, useful for caching.
        - hidden_vec: The normalized hidden state of the last time step.
        '''
        embedding = self.net.input_embedder(input_seq) + self.net.pos_enc(input_seq)
        embedding = self.net.emb_dropout(embedding)

        # Run through the main decoder and normalize
        hidden_vec, intermediates = self.net.main_decoder(embedding, cache, context_embedding=context)  # B x T x d_model
        hidden_vec = self.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # Keep only the last time step

        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}

        # Generate the next token
        logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
        return logits, sampled_token, intermediates, hidden_vec

    def _update_total_out(self, total_out, sampled_token):
        '''
        Updates the output sequence with the newly sampled token. Depending on the encoding scheme,
        it either appends the token directly or processes feature-based sampling.

        Arguments:
        - total_out: The tensor containing the previously generated tokens.
        - sampled_token: The newly generated token to be appended.

        Returns:
        - total_out: Updated output tensor with the newly generated token.
        - sampled_token: The processed sampled token.
        '''
        if self.net.vocab.encoding_scheme == 'remi':
            # For remi encoding, directly append the sampled token
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            # Handle other encoding schemes by concatenating features
            sampled_token_list = []
            for key in self.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)
            if len(sampled_token.shape) == 2:
                total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=1)
            else:
                total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)

        return total_out, sampled_token
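
    # Shape bookkeeping (illustrative, for compound encodings): sampled_token
    # arrives as a dict of per-feature tensors, is concatenated into a single
    # feature axis, and total_out grows along the time dimension:
    # (B, T, num_features) -> (B, T + 1, num_features).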

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
        '''
        Autoregressively generates a sequence of tokens by repeatedly sampling the next token
        until the desired maximum sequence length is reached or the end token is encountered.

        Arguments:
        - manual_seed: A seed value for reproducibility in inference.
        - max_seq_len: The maximum length of the generated sequence.
        - condition: An optional conditioning sequence to start generation from.
        - num_target_measures: Number of measures to keep from the conditioning sequence.
        - sampling_method: The method used to sample the next token (e.g., greedy, top-k).
        - threshold: Optional threshold for sampling (used in methods like top-p sampling).
        - temperature: Controls the randomness of the token sampling process.
        - batch_size: The number of sequences to generate in parallel.
        - context: Optional context embedding passed to the main decoder.
        - attr_list: Optional list of attribute names to condition on; all other
          attributes of the condition are masked out.

        Returns:
        - total_out: The generated sequence of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        if attr_list is None:
            total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
        else:
            total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
            # For attribute-controlled generation, keep only the specified attributes
            # in the condition; all others are masked with the id 126336
            condition_filtered = condition.clone().unsqueeze(0)
            for attr, idx in self.attribute2idx.items():
                if attr not in attr_list:
                    condition_filtered[:, :, idx] = 126336

        cache = LayerIntermediates()
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        bos_hidden_vec = None
        hidden_vec_list = []
        token_time_list = []
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]
            # Generate the next token and update the cache
            time_start = time.time()
            if attr_list is not None:
                # Take the condition entry for the current time step
                condition_filtered = condition_filtered.to(self.net.device)
                condition_step = condition_filtered[:, total_out.shape[1] - 1:total_out.shape[1], :]
                # Rearrange the features from condition order into prediction order
                # (destination <- source, e.g. rearranged[..., 3] = step[..., 0] for 'type')
                condition_step_rearranged = torch.zeros_like(condition_step)
                for attr, idx in self.attribute2idx.items():
                    new_idx = self.attribute2idx_after[attr]
                    condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
                _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
            else:
                _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
            time_end = time.time()
            token_time_list.append(time_end - time_start)
            if bos_hidden_vec is None:
                bos_hidden_vec = hidden_vec
            hidden_vec_list.append(hidden_vec)
            # Update attention cache to handle autoregressive generation
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]
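            # (The slice above keeps at most input_length - 1 cached key/value
            # timesteps, so once the new token is appended the effective
            # context never exceeds input_length.)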

            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)

            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break
        pbar.close()

        # Optionally dump hidden states and per-token timings for analysis:
        # save_path = 'hidden/diffnoaug_hidden_vec.pt'
        # save_time_path = 'hidden/diff_noaug_token_time.json'
        # if os.path.exists(save_path):
        #     # Load existing list and append
        #     hidden_vec_all = torch.load(save_path, map_location="cpu")
        #     hidden_vec_all.extend(hidden_vec_list)
        #     torch.save(hidden_vec_all, save_path)
        # else:
        #     torch.save(hidden_vec_list, save_path)

        # if os.path.exists(save_time_path):
        #     # Load existing list and append
        #     token_time_all = json.load(open(save_time_path, 'r'))
        #     token_time_all = token_time_all['token_time_list']
        #     token_time_all.extend(token_time_list)
        #     average_time = sum(token_time_all) / len(token_time_all)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_all
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)
        # else:
        #     average_time = sum(token_time_list) / len(token_time_list)
        #     data = {
        #         'average_time': average_time,
        #         'token_time_list': token_time_list
        #     }
        #     json.dump(data, open(save_time_path, 'w'), indent=4)

        return total_out

    def generate_batch(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
        '''
        Autoregressively generates a batch of sequences by repeatedly sampling the next token
        until the desired maximum sequence length is reached or the end token is encountered.

        Arguments:
        - manual_seed: A seed value for reproducibility in inference.
        - max_seq_len: The maximum length of the generated sequence.
        - condition: An optional conditioning sequence to start generation from.
        - num_target_measures: Number of measures to keep from the conditioning sequence.
        - sampling_method: The method used to sample the next token (e.g., greedy, top-k).
        - threshold: Optional threshold for sampling (used in methods like top-p sampling).
        - temperature: Controls the randomness of the token sampling process.
        - batch_size: The number of sequences to generate in parallel.
        - context: Optional context embedding passed to the main decoder.
        - attr_list: Accepted for interface compatibility; attribute-controlled
          generation is currently only handled in generate() with batch_size == 1.

        Returns:
        - total_out: The generated sequences of tokens as a tensor.
        '''
        # Prepare the starting sequence for inference
        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
        # total_out (1, 1, num) -> (batch_size, 1, num)
        total_out = total_out.repeat(batch_size, 1, 1)
        # If a condition is provided, run one initial step
        if condition is not None:
            _, _, cache, _ = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
        else:
            cache = LayerIntermediates()

        # Continue generating tokens until the maximum sequence length is reached
        pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
        while total_out.shape[1] < max_seq_len:
            pbar.update(1)
            input_tensor = total_out[:, -self.net.input_length:]

            # Generate the next token and update the cache
            _, sampled_token, cache, _ = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)

            # Update attention cache to handle autoregressive generation
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.net.input_length - 1):, :] for t in inter.cached_kv]

            # Update the generated output with the new token
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)

            # Stop if the end token is reached
            if sampled_token.tolist() == self.net.end_token[0]:
                break
        pbar.close()

        return total_out


class AmadeusModel(nn.Module):
    def __init__(
        self,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        '''
        This class combines the wrapper classes and initializes the full AmadeusModel,
        which can perform autoregressive sequence generation for symbolic music.

        Arguments:
        - vocab: Vocabulary used for tokenization of the symbolic music data.
        - input_length: Length of the input sequence in tokens.
        - prediction_order: Order in which features are predicted in a sequence, used for compound shift.
        - input_embedder_name: Name of the input embedding model to be used (e.g., one-hot embedding or learned embeddings).
        - main_decoder_name: Name of the main transformer decoder model used for generating the hidden representations for compound tokens.
        - sub_decoder_name: Name of the sub-decoder, which processes the hidden states and decodes the sub-tokens inside the compound tokens.
        - sub_decoder_depth: Depth (number of layers) of the sub-decoder.
        - sub_decoder_enricher_use: Whether to use an additional enricher module in the sub-decoder to refine representations.
        - dim: Dimensionality of the model (hidden size of the transformer layers).
        - heads: Number of attention heads in the transformer layers.
        - depth: Number of layers in the main decoder.
        - dropout: Dropout rate for all layers in the model.
        '''
        super().__init__()
        decoder = AmadeusModelWrapper(
            vocab=vocab,
            input_length=input_length,
            prediction_order=prediction_order,
            input_embedder_name=input_embedder_name,
            main_decoder_name=main_decoder_name,
            sub_decoder_name=sub_decoder_name,
            sub_decoder_depth=sub_decoder_depth,
            sub_decoder_enricher_use=sub_decoder_enricher_use,
            dim=dim,
            heads=heads,
            depth=depth,
            dropout=dropout
        )
        self.decoder = AmadeusModelAutoregressiveWrapper(
            net=decoder
        )

    def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
        return self.decoder(input_seq, target, context=context)

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
        if batch_size == 1:
            return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
        else:
            return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
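
# A minimal usage sketch (hypothetical hyperparameters; the embedder / decoder /
# sub-decoder names are assumptions and must match classes actually exposed by
# transformer_utils and sub_decoder_zoo):
#
#   model = AmadeusModel(
#       vocab=vocab,                      # a LangTokenVocab instance
#       input_length=1024,
#       prediction_order=['type', 'beat', 'chord', 'tempo',
#                         'instrument', 'pitch', 'duration', 'velocity'],
#       input_embedder_name='...',
#       main_decoder_name='...',
#       sub_decoder_name='...',
#       sub_decoder_depth=2,
#       sub_decoder_enricher_use=True,
#       dim=512, heads=8, depth=12, dropout=0.1,
#   )
#   tokens = model.generate(manual_seed=42, max_seq_len=2048,
#                           sampling_method=..., threshold=0.9, temperature=1.0)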


class AmadeusModel4Encodec(AmadeusModel):
    def __init__(
        self,
        vocab: LangTokenVocab,
        input_length: int,
        prediction_order: list,
        input_embedder_name: str,
        main_decoder_name: str,
        sub_decoder_name: str,
        sub_decoder_depth: int,
        sub_decoder_enricher_use: bool,
        dim: int,
        heads: int,
        depth: int,
        dropout: float
    ):
        super().__init__(
            vocab=vocab,
            input_length=input_length,
            prediction_order=prediction_order,
            input_embedder_name=input_embedder_name,
            main_decoder_name=main_decoder_name,
            sub_decoder_name=sub_decoder_name,
            sub_decoder_depth=sub_decoder_depth,
            sub_decoder_enricher_use=sub_decoder_enricher_use,
            dim=dim,
            heads=heads,
            depth=depth,
            dropout=dropout
        )

    def _prepare_inference(self, start_token, manual_seed, condition=None):
        if manual_seed > 0:
            torch.manual_seed(manual_seed)
        total_out = []
        if condition is None:
            total_out.extend(start_token)
        else:
            if self.decoder.net.vocab.encoding_scheme == 'remi':
                selected_tokens = condition[:1500].tolist()
            else:
                selected_tokens = condition[:500].tolist()
            total_out.extend(selected_tokens)
        total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.decoder.net.device)
        return total_out

    def _update_total_out(self, total_out, sampled_token):
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0)], dim=-1)
        else:
            sampled_token_list = []
            for key in self.decoder.net.vocab.feature_list:
                sampled_token_list.append(sampled_token[key])
            sampled_token = torch.cat(sampled_token_list, dim=-1)  # B(1) x num_features
            total_out = torch.cat([total_out, sampled_token.unsqueeze(0).unsqueeze(0)], dim=1)
        return total_out, sampled_token

    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1):
        embedding = self.decoder.net.input_embedder(input_seq) + self.decoder.net.pos_enc(input_seq)
        embedding = self.decoder.net.emb_dropout(embedding)
        hidden_vec, intermediates = self.decoder.net.main_decoder(embedding, cache)  # B x T x d_model
        hidden_vec = self.decoder.net.main_norm(hidden_vec)
        hidden_vec = hidden_vec[:, -1:]  # B x 1 x d_model
        input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None}
        if self.decoder.net.vocab.encoding_scheme == 'remi':
            feature_class_idx = (input_seq.shape[1] - 1) % 4
            feature_type = self.decoder.net.vocab.feature_list[feature_class_idx]
            logits, sampled_token = self.decoder.net.sub_decoder.run_one_step(input_dict, sampling_method, threshold, temperature, feature_type)
        else:
            logits, sampled_token = self.decoder.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
        return logits, sampled_token, intermediates
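
    # Note (assumption): for remi-style Encodec token streams the features
    # repeat with period 4, so (current_length - 1) % 4 selects which feature
    # head of the sub-decoder should sample the next token.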

    @torch.inference_mode()
    def generate(self, manual_seed, max_seq_len, condition=None, sampling_method=None, threshold=None, temperature=1):
        total_out = self._prepare_inference(self.decoder.net.start_token, manual_seed, condition)
        if condition is not None:
            _, _, cache = self._run_one_step(total_out[:, -self.decoder.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        else:
            cache = LayerIntermediates()
        while total_out.shape[1] < max_seq_len:
            input_tensor = total_out[:, -self.decoder.net.input_length:]
            _, sampled_token, cache = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            for inter in cache.attn_intermediates:
                inter.cached_kv = [t[..., -(self.decoder.net.input_length - 1):, :] for t in inter.cached_kv]  # B x num_heads x T x d_head
            total_out, sampled_token = self._update_total_out(total_out, sampled_token)
            if sampled_token.tolist() == self.decoder.net.end_token[0]:
                break
        return total_out