import torch
import torch.profiler
import torch.nn as nn

from x_transformers import Decoder

from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
from .sub_decoder_utils import *
from .sampling_utils import sample, sample_with_prob, sample_with_prob_fast, top_p_sampling, typical_sampling, eta_sampling

from data_representation.vocab_utils import LangTokenVocab

class SingleProjection(nn.Module):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        This sub-decoder is used for REMI-based models
        '''
        super().__init__()
        vocab_size = vocab.get_vocab_size()
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=1):
        hidden_vec = input_dict['hidden_vec']
        target = input_dict['target']
        # ---- Generate (Inference) ---- #
        if target is None:
            logits = self.proj(hidden_vec[:, -1:])
            sampled_token = sample(logits, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            return logits, sampled_token
        # ---- Training ---- #
        logits = self.proj(hidden_vec)
        return logits

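# A minimal usage sketch of SingleProjection (illustrative only; the shapes and the
# `sampling_method` value below are assumptions, not taken from this repository's configs):
#
#   proj = SingleProjection(prediction_order=['total'], vocab=vocab, sub_decoder_depth=0,
#                           dim=512, heads=8, dropout=0.1, sub_decoder_enricher_use=False)
#   # training: 'target' is provided, so logits over the whole sequence are returned
#   logits = proj({'hidden_vec': hidden_vec, 'target': target})
#   # inference: 'target' is None, so only the last position is projected and sampled
#   logits, token = proj({'hidden_vec': hidden_vec, 'target': None},
#                        sampling_method='top_p', threshold=0.9)
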
class SubDecoderClass(nn.Module):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        This is the base class for all sub-decoders
        '''
        super().__init__()
        self.prediction_order = prediction_order
        self.vocab = vocab
        self.vocab_size = vocab.get_vocab_size()
        # make layers
        self._make_emb_layer(vocab, dim)
        self._make_projection_layer(vocab, dim)
        self._make_nonlinear_layer()

    @property
    def device(self):
        return next(self.parameters()).device

    def _make_emb_layer(self, vocab, dim):
        self.emb_layer = MultiEmbedding(
            vocab=vocab,
            dim_model=dim
        )

    # def _make_projection_layer(self, vocab, dim):
    #     vocab_sizes = vocab.get_vocab_size()
    #     self.hidden2logit = nn.ModuleDict({
    #         f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
    #     })

    def _make_nonlinear_layer(self):
        pass

    def _make_projection_layer(self, vocab, dim):
        vocab_sizes = vocab.get_vocab_size()
        self.vocab_sizes = vocab_sizes
        self.max_vocab_size = max(vocab_sizes.values())
        self.projection_keys = list(vocab_sizes.keys())  # For index order

        # Keep the original Linear layers (so existing state_dict checkpoints still match)
        self.hidden2logit = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
        })

        # # Build the padded weights used for block-parallel projection
        # weight_list = []
        # bias_list = []

        # for key in self.projection_keys:
        #     layer = self.hidden2logit[f"layer_{key}"]
        #     w = layer.weight
        #     b = layer.bias

        #     # pad to max_vocab_size
        #     w_padded = F.pad(w, (0, 0, 0, self.max_vocab_size - w.shape[0]))
        #     b_padded = F.pad(b, (0, self.max_vocab_size - b.shape[0]))

        #     weight_list.append(w_padded.unsqueeze(0))  # (1, Vmax, D)
        #     bias_list.append(b_padded.unsqueeze(0))    # (1, Vmax)

        # self.register_buffer("proj_weight", torch.cat(weight_list, dim=0))  # (F, Vmax, D)
        # self.register_buffer("proj_bias", torch.cat(bias_list, dim=0))      # (F, Vmax)

class FeedForward(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        FeedForward sub-decoder is used for compound tokens like CP or NB.
        We followed the original sub-decoder proposed in the paper "Compound Word Transformer",
        however the embedding size for each sub-token or musical feature is the same in our implementation.
        The reason is that we did not find any significant difference in the performance of the model.

        There are two types of decoding style for the FeedForward sub-decoder
        (an illustrative sketch of both orders follows this class):
        1. Partial-sequential prediction: Predict the type token first and then predict all the sub-tokens in parallel (original CP)
        2. Fully-sequential prediction: Predict all the sub-tokens sequentially
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

    def _make_projection_layer(self, vocab, dim):
        vocab_sizes = vocab.get_vocab_size()
        self.hidden2logit = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
        })
        self.catvec2hidden = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
        })

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']
        target = input_dict['target']

        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            for feature in self.prediction_order:
                if isinstance(feature, str):
                    logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
                    logits_dict[feature] = logit
                    sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                    sampled_token_dict[feature] = sampled_token
                    feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)  # B x T x emb_size
                    catvec = torch.cat([hidden_vec, feature_emb.unsqueeze(0)], dim=-1)
                    hidden_vec = self.catvec2hidden[f"layer_{feature}"](catvec)
                else:
                    assert feature == self.prediction_order[-1], "Parallel prediction should be the last feature"
                    for par_feature in feature:
                        logit = self.hidden2logit[f"layer_{par_feature}"](hidden_vec)
                        logits_dict[par_feature] = logit
                        sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                        sampled_token_dict[par_feature] = sampled_token
            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        for feature in self.prediction_order:
            if isinstance(feature, str):
                logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
                logits_dict[feature] = logit
                feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., self.vocab.feature_list.index(feature)])  # B x T x emb_size
                catvec = torch.cat([hidden_vec, feature_emb], dim=-1)
                hidden_vec = self.catvec2hidden[f"layer_{feature}"](catvec)
            else:
                assert feature == self.prediction_order[-1], "Parallel prediction should be the last feature"
                for par_feature in feature:
                    logit = self.hidden2logit[f"layer_{par_feature}"](hidden_vec)
                    logits_dict[par_feature] = logit
        return logits_dict

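# Illustrative sketch of the two decoding styles described in the FeedForward docstring.
# The feature names below are hypothetical placeholders; the actual names come from vocab.feature_list.
#
#   # 1. Partial-sequential (CP-style): 'type' is predicted first, then the remaining
#   #    sub-tokens are predicted in parallel as one tuple (which must be the last element).
#   prediction_order = ['type', ('beat', 'pitch', 'duration', 'velocity')]
#
#   # 2. Fully-sequential: every entry is a string, so each sub-token is predicted in turn
#   #    and its embedding is concatenated back into the hidden vector via catvec2hidden.
#   prediction_order = ['type', 'beat', 'pitch', 'duration', 'velocity']
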
class Parallel(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        Parallel sub-decoder is used for parallel prediction of multiple sub-tokens or musical features.
        This method is proposed in the paper "Multitrack Music Transformer".
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']
        target = input_dict['target']

        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            for feature in self.prediction_order:
                logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)  # B x T x vocab_size
                logits_dict[feature] = logit
                sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                sampled_token_dict[feature] = sampled_token
            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        for feature in self.prediction_order:
            logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
            logits_dict[feature] = logit
        return logits_dict

class RNN(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        RNN sub-decoder is used for sequential prediction of multiple sub-tokens or musical features.
        This method is similar to the method proposed in "PianoTree VAE".
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
        self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}

        self.pos_enc = nn.Embedding(len(prediction_order), dim)
        nn.init.zeros_(self.pos_enc.weight)

        self.decoding_rnn = nn.GRU(
            input_size=dim,
            hidden_size=dim,
            num_layers=sub_decoder_depth,
            dropout=dropout,
            batch_first=True)

    def _apply_pos_enc(self, tgt, apply_type='last'):
        if apply_type == 'all':
            pos = torch.arange(tgt.shape[1]).to(tgt.device)
            pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
            tgt_pos = tgt + self.pos_enc(pos.long())
        elif apply_type == 'last':
            pos = torch.arange(tgt.shape[1]).to(tgt.device)
            pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
            pos_emb = self.pos_enc(pos.long())
            # zero out the pos_emb except for the last token
            pos_emb[:, :-1, :] = 0
            tgt_pos = tgt + pos_emb
        return tgt_pos

    def _prepare_token_embedding_for_teacher_forcing(self, input_seq, target):
        for feature in self.prediction_order[:-1]:
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            input_seq = torch.cat([input_seq, feature_emb_reshape], dim=1)
        return input_seq

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
        target = input_dict['target']  # B x T x num_sub_tokens-1
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], -1)).unsqueeze(1)  # (B*T) x 1 x d_model
        input_seq = hidden_vec_reshape  # (B*T) x 1 x d_model

        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            h_0 = input_seq[:, 0, :].unsqueeze(0)  # 1 x (B*T) x d_model
            input_seq = self._apply_pos_enc(input_seq, apply_type='all')  # (B*T) x 1 x d_model
            for idx, feature in enumerate(self.prediction_order):
                input_seq, _ = self.decoding_rnn(input_seq, h_0)  # input_seq: (B*T) x (idx+1) x hidden_size, h_n: num_layers x (B*T) x hidden_size
                logit = self.hidden2logit[f"layer_{feature}"](input_seq[:, -1, :])  # (B*T) x vocab_size
                logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
                logits_dict[feature] = logit
                sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                sampled_token_dict[feature] = sampled_token
                if idx == len(self.prediction_order)-1:
                    return logits_dict, sampled_token_dict
                feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)  # B x T x emb_size
                feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
                input_seq = torch.cat([input_seq, feature_emb_reshape], dim=1)  # (B*T) x (idx+2) x d_model
                input_seq = self._apply_pos_enc(input_seq, apply_type='last')  # (B*T) x (idx+2) x d_model
            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        input_seq = self._prepare_token_embedding_for_teacher_forcing(input_seq, target)  # (B*T) x len(prediction_order) x d_model
        # initial hidden state has no positional encoding
        h0 = input_seq[:, 0, :].unsqueeze(0)  # 1 x (B*T) x d_model
        h0 = h0.contiguous()
        # apply positional encoding
        input_seq = self._apply_pos_enc(input_seq, apply_type='all')  # (B*T) x len(prediction_order) x d_model
        # get output using rnn
        output, _ = self.decoding_rnn(input_seq, h0)  # (B*T) x len(prediction_order) x d_model
        output = output.reshape((hidden_vec.shape[0], hidden_vec.shape[1], len(self.prediction_order), -1))  # B x T x len(prediction_order) x d_model
        for idx, feature in enumerate(self.prediction_order):
            logit = self.hidden2logit[f"layer_{feature}"](output[:, :, idx, :])  # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict

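# Sketch of the teacher-forcing layout used by RNN.forward above (shapes follow the code;
# f_1..f_K are placeholders for the entries of self.prediction_order):
#
#   input_seq = [ hidden_vec_t, emb(f_1), ..., emb(f_{K-1}) ]   # (B*T) x K x d_model
#   h0        = hidden_vec_t                                    # initial GRU state, no positional encoding
#   output[k] -> hidden2logit[f"layer_{f_{k+1}}"]               # GRU position k predicts sub-token f_{k+1}
#
# During training each GRU step sees the ground-truth embedding of the previous sub-token,
# while at inference the sampled embedding is appended step by step instead.
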
class SelfAttention(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        This sub-decoder is used for sequential prediction of multiple sub-tokens or musical features.
        This method is similar to the method proposed in "UniAudio", but differs in how the sequence of sub-tokens is built.
        UniAudio adds the output of the main decoder (the hidden vec) directly to the embedding of each sub-token,
        while our method puts the hidden vec in the input sequence so that the attention mechanism can learn the relationship between the hidden vec and the sub-tokens.
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
        self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}

        self.pos_enc = nn.Embedding(1 + len(prediction_order), dim)
        nn.init.zeros_(self.pos_enc.weight)

        self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)

        window_size = 1  # number of previous outputs of the main decoder to be used in the sub-decoder
        causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
        self.register_buffer('causal_mask', causal_mask)

        self.transformer_decoder = Decoder(
            dim = dim,
            depth = sub_decoder_depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def _apply_pos_enc(self, tgt, apply_type='last'):
        if apply_type == 'all':
            pos = torch.arange(tgt.shape[1]).to(tgt.device)
            pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
            tgt_pos = tgt + self.pos_enc(pos.long())
        elif apply_type == 'last':
            pos = torch.arange(tgt.shape[1]).to(tgt.device)
            pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
            pos_emb = self.pos_enc(pos.long())  # (B*T) x (window_size + BOS + num_sub_tokens-1) x dim
            # zero out the pos_emb except for the last token
            pos_emb[:, :-1, :] = 0
            tgt_pos = tgt + pos_emb
        return tgt_pos

    def _prepare_input_seq_list(self, hidden_vec_reshape, target=None):
        input_seq_list = []
        input_seq_list.append(hidden_vec_reshape)
        BOS_emb = self.sub_decoder_BOS_emb.unsqueeze(0).repeat(hidden_vec_reshape.shape[0], 1, 1)  # (B*T) x 1 x d_model
        if target is None:  # inference
            input_seq_list.append(BOS_emb[-1:, :, :])
        else:  # training
            input_seq_list.append(BOS_emb)
        return input_seq_list

    def _prepare_token_embedding_for_teacher_forcing(self, input_seq_list, target):
        for feature in self.prediction_order[:-1]:
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            input_seq_list.append(feature_emb_reshape)
        memory_tensor = torch.cat(input_seq_list, dim=1)  # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
        return memory_tensor

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
        target = input_dict['target']  # B x T x num_sub_tokens
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
        input_seq_list = self._prepare_input_seq_list(hidden_vec_reshape, target)

        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            input_seq_tensor = torch.cat(input_seq_list, dim=1)  # (B*T) x (window_size + BOS) x d_model
            pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all')  # (B*T) x (window_size + BOS) x d_model
            for idx, feature in enumerate(self.prediction_order):
                output = self.transformer_decoder(pos_target_tensor)
                logit = self.hidden2logit[f"layer_{feature}"](output[:, -1:])
                logits_dict[feature] = logit.reshape((1, 1, -1))  # 1 x 1 x vocab_size
                sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                sampled_token_dict[feature] = sampled_token
                if idx == len(self.prediction_order)-1:
                    return logits_dict, sampled_token_dict
                feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
                feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
                input_seq_list.append(feature_emb_reshape)
                input_seq_tensor = torch.cat(input_seq_list, dim=1)
                pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='last')
            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        # preparing for training
        input_seq_tensor = self._prepare_token_embedding_for_teacher_forcing(input_seq_list, target)  # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
        pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all')  # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
        # get output using self-attention
        output = self.transformer_decoder(pos_target_tensor)
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict

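# Sketch of the sub-decoder input sequence built by SelfAttention above (window_size = 1;
# f_1..f_K are placeholders for self.prediction_order):
#
#   [ hidden_vec_t, BOS, emb(f_1), ..., emb(f_{K-1}) ]   # (B*T) x (window_size + 1 + K-1) x d_model
#
# During training the whole sequence is fed at once and, via feature_order_in_output, the slot
# holding BOS reads out f_1 while the slot holding emb(f_k) reads out f_{k+1}. During inference
# the sequence grows by one sampled embedding per step and only the last position is projected.
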
class SelfAttentionUniAudio(SelfAttention):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        UniAudio version of the self-attention sub-decoder.
        Through our experiments, we found that it performs better than our proposed self-attention sub-decoder
        and shows comparable performance to the cross-attention sub-decoder.
        However, NMT still shows better performance than UniAudio.
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

    def _prepare_token_embedding_for_teacher_forcing(self, hidden_vec_reshape, target):
        input_seq_list = []
        # append zero vector
        input_seq_list.append(torch.zeros(hidden_vec_reshape.shape[0], 1, hidden_vec_reshape.shape[2]).to(self.device))
        for feature in self.prediction_order[:-1]:
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            input_seq_list.append(feature_emb_reshape)

        feature_tensor = torch.cat(input_seq_list, dim=1)  # (B*T) x num_sub_tokens x d_model
        # Ensure hidden_vec_reshape and feature_tensor have the same shape
        assert hidden_vec_reshape.shape == feature_tensor.shape, f"Shapes of hidden_vec_reshape and feature_tensor do not match: {hidden_vec_reshape.shape} vs {feature_tensor.shape}"
        # Sum hidden_vec_reshape and feature_tensor in the last dimension
        memory_tensor = hidden_vec_reshape + feature_tensor
        return memory_tensor

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
        target = input_dict['target']  # B x T x num_sub_tokens
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
        hidden_vec_reshape = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model

        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            pos_target_tensor = self._apply_pos_enc(hidden_vec_reshape, apply_type='all')  # (B*T) x num_sub_tokens x d_model
            for idx, feature in enumerate(self.prediction_order):
                output = self.transformer_decoder(pos_target_tensor)
                logit = self.hidden2logit[f"layer_{feature}"](output[:, -1:])
                logits_dict[feature] = logit.reshape((1, 1, -1))  # 1 x 1 x vocab_size
                sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                sampled_token_dict[feature] = sampled_token
                if idx == len(self.prediction_order)-1:
                    return logits_dict, sampled_token_dict
                feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
                feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
                pos_target_tensor = torch.cat([pos_target_tensor[:, :idx+1, :], feature_emb_reshape + pos_target_tensor[:, idx+1:idx+2, :], pos_target_tensor[:, idx+2:, :]], dim=1)

            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        # preparing for training
        input_seq_tensor = self._prepare_token_embedding_for_teacher_forcing(hidden_vec_reshape, target)  # (B*T) x num_sub_tokens x d_model
        pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all')  # (B*T) x num_sub_tokens x d_model
        # get output using self-attention
        output = self.transformer_decoder(pos_target_tensor)
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict

class CrossAttention(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        '''
        The power of cross-attention and UniAudio-style self-attention lies in using the output of the main decoder (the hidden vec) directly in the sub-decoder.
        As the output of the main decoder is the representation of the whole sequence,
        it contains richer information which can even decode out sub-tokens in a parallel manner.
        So both architectures that use the output of the main decoder in a direct way show better performance than the original self-attention sub-decoder.
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
        self.sub_decoder_enricher_use = sub_decoder_enricher_use
        self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}

        self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
        nn.init.zeros_(self.pos_enc.weight)

        self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
        if sub_decoder_enricher_use:
            self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
        causal_mask = generate_SA_mask(len(prediction_order))
        causal_ca_mask = generate_CA_mask(len(prediction_order), len(prediction_order)).to(self.device)
        self.register_buffer('causal_mask', causal_mask)
        self.register_buffer('causal_ca_mask', causal_ca_mask)

        if sub_decoder_depth > 1:
            self.sub_decoder_layers = nn.Sequential(
                *[DecoderLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)]
            )
        else:
            self.sub_decoder_layers = nn.Sequential(DecoderLayer(dim=dim, num_heads=heads, dropout=dropout))
        if sub_decoder_enricher_use:
            self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))

    def _apply_window_on_hidden_vec(self, hidden_vec):
        BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1)  # (B*T) x 1 x d_model
        # through our experiments, we found that the size of the window doesn't affect the performance of the model much
        window_size = 1
        zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device)  # B x (window_size-1) x d_model
        cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1)  # B x (window_size-1+T) x d_model
        new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3)  # B x T x window_size x d_model
        new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1))  # (B*T) x window_size x d_model
        new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1)  # (B*T) x (window_size+1) x d_model
        return new_hidden_vec

    def _apply_pos_enc(self, tgt):
        pos = torch.arange(tgt.shape[1]).to(tgt.device)  # num_sub_tokens
        pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)  # (B*T) x num_sub_tokens
        tgt_pos = tgt + self.pos_enc(pos.long())  # (B*T) x num_sub_tokens x d_model
        return tgt_pos

    def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
        for _, feature in enumerate(self.prediction_order[:-1]):
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            memory_list.append(feature_emb_reshape)
        memory_tensor = torch.cat(memory_list, dim=1)  # (B*T) x (BOS + num_sub_tokens-1) x d_model
        return memory_tensor

    def _prepare_memory_list(self, hidden_vec, target=None):
        memory_list = []  # used for key and value in cross attention
        BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1)  # (B*T) x 1 x d_model
        if target is not None:  # training
            memory_list.append(BOS_emb)
        else:  # inference
            memory_list.append(BOS_emb[-1:, :, :])
        return memory_list

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
        target = input_dict['target']

        # apply window on hidden_vec for enricher
        if self.sub_decoder_enricher_use:
            window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec)  # (B*T) x window_size x d_model
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
        input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model
        input_seq_pos = self._apply_pos_enc(input_seq)
        # prepare memory
        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target)
        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            memory_tensor = torch.cat(memory_list, dim=1)  # (B*T) x 1 x d_model
            old_memory_tensor = memory_tensor
            for idx, feature in enumerate(self.prediction_order):
                feature_pos = self.feature_order_in_output[feature]
                if self.sub_decoder_enricher_use:
                    input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec[-1:]}
                    input_dict = self.feature_enricher_layers(input_dict)
                    memory_tensor = input_dict['input_seq']
                CA_attn_mask = generate_CA_mask(input_seq_pos.shape[1], memory_tensor.shape[1]).to(self.device)
                input_dict = {'input_seq': input_seq_pos[-1:], 'memory': memory_tensor, 'memory_mask': CA_attn_mask}
                input_dict = self.sub_decoder_layers(input_dict)
                attn_output = input_dict['input_seq']
                logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
                logit = logit.reshape((1, 1, -1))  # 1 x 1 x vocab_size
                logits_dict[feature] = logit
                sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
                sampled_token_dict[feature] = sampled_token
                if idx == len(self.prediction_order)-1:
                    return logits_dict, sampled_token_dict
                feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
                feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
                memory_list.append(feature_emb_reshape)
                memory_tensor = torch.cat(memory_list, dim=1)  # (B*T) x (BOS + idx+1) x d_model
            return logits_dict, sampled_token_dict

        # ---- Training ---- #
        memory_tensor = self._prepare_token_embedding_for_teacher_forcing(memory_list, target)  # (B*T) x (BOS + num_sub_tokens-1) x d_model
        # apply feature enricher to memory
        if self.sub_decoder_enricher_use:
            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
            input_dict = self.feature_enricher_layers(input_dict)
            memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
        # implement sub decoder cross attention
        input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
        input_dict = self.sub_decoder_layers(input_dict)
        attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
        # get prob
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict

class Flatten4Encodec(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool
    ):
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
        hidden_vec = input_dict['hidden_vec']

        # ---- Training ---- #
        logits_tensor = torch.zeros(hidden_vec.shape[0], hidden_vec.shape[1], 2049).to(self.device)
        for idx, feature_type in enumerate(self.prediction_order):
            # ::4 means that we only use the first token in each group of 4 tokens,
            # so the chosen positions will be: 0, 4, 8, 12, ...
            # 1::4 means that we only use the second token in each group of 4 tokens,
            # so the chosen positions will be: 1, 5, 9, 13, ...
            separated_hidden_vec = hidden_vec[:, idx::4, :]
            logit = self.hidden2logit[f"layer_{feature_type}"](separated_hidden_vec)
            logits_tensor[:, idx::4, :] = logit
            # prob_dict[feature_type] = prob
        return logits_tensor

    def run_one_step(self, input_dict, sampling_method=None, threshold=None, temperature=None, feature_type=None):
        # ---- Generate (Inference) ---- #
        hidden_vec = input_dict['hidden_vec']
        logit = self.hidden2logit[f"layer_{feature_type}"](hidden_vec[:, -1:])
        sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        return logit, sampled_token

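# Sketch of the codebook interleaving assumed by Flatten4Encodec above: the flattened sequence
# cycles through 4 codebooks, so with self.prediction_order = [cb_0, cb_1, cb_2, cb_3]
# (placeholder names) the positions handled by each projection head are
#
#   cb_0 -> 0, 4, 8, ...   cb_1 -> 1, 5, 9, ...   cb_2 -> 2, 6, 10, ...   cb_3 -> 3, 7, 11, ...
#
# which is exactly what hidden_vec[:, idx::4, :] selects for idx = 0..3.
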
class DiffusionDecoder(SubDecoderClass):
    def __init__(
        self,
        prediction_order:list,
        vocab:LangTokenVocab,
        sub_decoder_depth:int,
        dim:int,
        heads:int,
        dropout:float,
        sub_decoder_enricher_use:bool,
        MASK_IDX:int = 126336,
        denoising_steps:int = 8,
        eps:float = 1e-3,
        method:str = 'low-confidence',  # or 'random' or 'auto-regressive'
    ):
        '''
        The power of cross-attention and UniAudio-style self-attention lies in using the output of the main decoder (the hidden vec) directly in the sub-decoder.
        As the output of the main decoder is the representation of the whole sequence,
        it contains richer information which can even decode out sub-tokens in a parallel manner.
        So both architectures that use the output of the main decoder in a direct way show better performance than the original self-attention sub-decoder.
        '''
        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
        self.sub_decoder_enricher_use = sub_decoder_enricher_use
        self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}

        self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
        nn.init.zeros_(self.pos_enc.weight)

        self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
        self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True)  # embedding of the mask token; its index is 126336, which is not in the vocab
        nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
        self.MASK_idx = MASK_IDX
        self.denoising_steps = denoising_steps
        self.eps = eps
        self.method = method

        self.input_norm = nn.LayerNorm(dim)

        self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))

        if sub_decoder_enricher_use:
            self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
        causal_mask = generate_SA_mask(len(prediction_order))
        causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
        self.register_buffer('causal_mask', causal_mask)
        self.register_buffer('causal_ca_mask', causal_ca_mask)

        # get depth of the sub-decoder
        if sub_decoder_depth > 1:
            self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
        else:
            self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
        if sub_decoder_enricher_use:
            self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))

    # simplified version of the forward process in the diffusion model
    def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
        reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1]))  # (B*T) x num_sub_tokens
        b, l = reshaped_input_ids.shape
        t = torch.rand(b, device=input_ids.device)
        p_mask = (1 - eps) * t + eps
        p_mask = p_mask[:, None].repeat(1, l)

        masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
        # 126336 is used for the [MASK] token; note that this token is not in the vocab
        if mask_idx is not None:
            noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
        else:
            noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids)  # 126336 is used for the [MASK] token
        return noisy_batch, masked_indices, p_mask

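    # A small numeric sketch of the forward (masking) process above, with made-up values:
    # for eps = 1e-3 and a sampled t = 0.5, every sub-token of that row is masked
    # independently with probability p_mask = (1 - 1e-3) * 0.5 + 1e-3 ≈ 0.5005, so a row of
    # 5 sub-tokens is expected to have roughly 2-3 of them replaced by the [MASK] index
    # (126336 by default).
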
    def _apply_window_on_hidden_vec(self, hidden_vec):
        BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1)  # (B*T) x 1 x d_model
        # through our experiments, we found that the size of the window doesn't affect the performance of the model much
        window_size = 1
        zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device)  # B x (window_size-1) x d_model
        cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1)  # B x (window_size-1+T) x d_model
        new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3)  # B x T x window_size x d_model
        new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1))  # (B*T) x window_size x d_model
        new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1)  # (B*T) x (window_size+1) x d_model
        return new_hidden_vec

    def _apply_pos_enc(self, tgt):
        pos = torch.arange(tgt.shape[1]).to(tgt.device)  # num_sub_tokens
        pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)  # (B*T) x num_sub_tokens
        tgt_pos = tgt + self.pos_enc(pos.long())  # (B*T) x num_sub_tokens x d_model
        return tgt_pos

    def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
        for _, feature in enumerate(self.prediction_order[:-1]):
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            memory_list.append(feature_emb_reshape)
        memory_tensor = torch.cat(memory_list, dim=1)  # (B*T) x (BOS + num_sub_tokens-1) x d_model
        return memory_tensor

    # return a tensor filled with the mask-token embedding
    def _get_noisy_tensor(self, target_shape):
        new_target = torch.zeros(target_shape).to(self.device)
        # fill all the elements in the tensor with the embedding of the mask token
        new_target[:, :, :] = self.diffusion_mask_emb
        return new_target

    # prepare the embedding of the target
    def _prepare_embedding(self, memory_list, target):
        for _, feature in enumerate(self.prediction_order):
            feature_idx = self.vocab.feature_list.index(feature)
            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx])  # B x T x emb_size
            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1))  # (B*T) x 1 x emb_size
            memory_list.append(feature_emb_reshape)
        memory_tensor = torch.cat(memory_list, dim=1)  # (B*T) x (BOS + num_sub_tokens) x d_model
        return memory_tensor

    def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
        memory_list = []  # used for key and value in cross attention
        BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1)  # (B*T) x 1 x d_model
        if add_BOS:
            if target is not None:  # training
                memory_list.append(BOS_emb)
            else:  # inference
                memory_list.append(BOS_emb[-1:, :, :])
        return memory_list

    def _get_num_transfer_tokens(self, mask_index, steps):
        '''
        In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
        Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
        the expected number of tokens transitioned at each step should be consistent.

        This function is designed to precompute the number of tokens that need to be transitioned at each step.
        '''
        mask_num = mask_index.sum(dim=1, keepdim=True)
        base = mask_num // steps
        remainder = mask_num % steps

        num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

        for i in range(mask_num.size(0)):
            num_transfer_tokens[i, :remainder[i]] += 1

        return num_transfer_tokens

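    # Worked example for _get_num_transfer_tokens above: with 7 masked sub-tokens and
    # steps = 3, base = 7 // 3 = 2 and remainder = 7 % 3 = 1, so the schedule becomes
    # [3, 2, 2] -- the remainder is spread over the earliest steps, and the per-step
    # counts always sum back to the number of masked positions.
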
    def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False, step=None):
        sampled_token_dict = {}
        logits_dict = {}
        candidate_token_embeddings = {}
        candidate_token_probs = {}
        b, t, d = hidden_vec.shape  # B x T x d_model
        # print("*"*8)
        logits_list = []
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
            logits_list.append(logit)
        for idx, feature in enumerate(self.prediction_order):
            logit = logits_list[idx]  # B x T x vocab_size
            sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            if step == 0 and force_decode:
                if feature == 'velocity':
                    sampled_token = torch.tensor([2]).to(logit.device)
                    prob = torch.tensor([1.0]).to(logit.device)
                else:
                    prob = torch.tensor([0.0]).to(logit.device)
            # print(feature, sampled_token, prob)
            sampled_token_dict[feature] = sampled_token
            logits_dict[feature] = logit
            candidate_token_probs[feature] = prob
            feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
            feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
            candidate_token_embeddings[feature] = feature_emb_reshape
        stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1))  # (B*T) x num_sub_tokens
        stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d))  # (B*T) x num_sub_tokens x d_model
        # print("sampled_token_dict", sampled_token_dict)
        return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings

    def sample_from_logits_fast(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None):
        sampled_token_dict = {}
        logits_dict = {}
        candidate_token_embeddings = {}
        candidate_token_probs = {}

        b, t, d = hidden_vec.shape  # (B, T, D)
        F = len(self.projection_keys)
        Vmax = self.max_vocab_size

        # === 1. Collect the output position of every feature === #
        feature_pos_list = [self.feature_order_in_output[f] for f in self.projection_keys]

        # === 2. Gather the per-feature vectors from attn_output -> (B, F, D) === #
        attn_features = torch.stack(
            [attn_output[:, pos, :] for pos in feature_pos_list], dim=1
        )  # (B, F, D)

        # === 3. Batched matrix multiplication: einsum runs all Linear projections in parallel === #
        # attn_features: (B, F, D)
        # proj_weight:   (F, Vmax, D)
        # proj_bias:     (F, Vmax)
        # output:        (B, F, Vmax)
        # NOTE: proj_weight / proj_bias are the padded block-parallel buffers whose construction
        #       is currently commented out in _make_projection_layer.
        logits = torch.einsum("bfd,fvd->bfv", attn_features, self.proj_weight) + self.proj_bias

        # === 4. Truncate each feature's logits to its original vocab size === #
        logits_list = []
        logits_dict_by_feature = {
            feature: logits[:, i, :self.vocab_sizes[feature]]
            for i, feature in enumerate(self.projection_keys)
        }
        for i, feature in enumerate(self.projection_keys):
            vocab_size = self.vocab_sizes[feature]
            logits_list.append(logits[:, i, :vocab_size])  # (B, vocab_size)
        for idx, feature in enumerate(self.prediction_order):
            logit = logits_dict_by_feature[feature].unsqueeze(0)  # B x T x vocab_size
            sampled_token, prob = sample_with_prob_fast(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
            # print(feature, sampled_token, prob)
            sampled_token_dict[feature] = sampled_token.squeeze(0)  # B x T
            logits_dict[feature] = logit
            candidate_token_probs[feature] = prob
            feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
            feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
            candidate_token_embeddings[feature] = feature_emb_reshape
        stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1))  # (B*T) x num_sub_tokens
        stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d))  # (B*T) x num_sub_tokens x d_model

        return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings

    def choose_tokens(self, hidden_vec, step, method, stacked_logits_probs, num_transfer_tokens):
        if method == 'low-confidence':
            _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:, step]), dim=-1)
        elif method == 'random':
            indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(hidden_vec.device)
        elif method == 'auto-regressive':
            indices = torch.tensor([[step]], device=hidden_vec.device)
        return indices

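    # Illustrative sketch of the 'low-confidence' branch above (made-up numbers): if one row of
    # stacked_logits_probs is [0.9, 0.2, -inf, 0.6] (the -inf marks an already-decoded slot) and
    # num_transfer_tokens[:, step] is 2, torch.topk keeps the two most confident still-masked
    # positions, so indices = [0, 3] and those two sub-tokens are unmasked at this step.
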
    def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
        target = input_dict['target']  # B x T x num_sub_tokens

        # apply window on hidden_vec for enricher
        if self.sub_decoder_enricher_use:
            window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec)  # (B*T) x window_size x d_model
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
        input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model
        input_seq_pos = input_seq
        # input_seq_pos = self._apply_pos_enc(input_seq)  # (B*T) x num_sub_tokens x d_model
        # prepare memory
        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
        # ---- Generate (Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            b, t, d = hidden_vec.shape  # B x T x d_model
            l = len(self.prediction_order)  # num_sub_tokens
            memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
            all_noise_tensor = memory_tensor.clone()  # (B*T) x num_sub_tokens x d_model

            # indicates the position of the mask token; True means that the token has been masked
            masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
            num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
            # denoising loop
            stored_logits_dict = {}
            stored_probs_dict = {}
            for step in range(self.denoising_steps):
                # normalize the memory tensor
                # memory_tensor = self.layer_norm(memory_tensor)  # (B*T) x num_sub_tokens x d_model
                if self.sub_decoder_enricher_use:
                    input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
                    input_dict = self.feature_enricher_layers(input_dict)
                    memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
                input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
                # input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
                input_dict = self.sub_decoder_layers(input_dict)
                attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
                candidate_token_probs = {}
                sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)

                # set prob of the already-decoded tokens to -inf
                stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
                # indices = self.choose_tokens(hidden_vec, step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
                indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
                # breakpoint()
                # update the masked history
                for i in range(b*t):
                    for j in range(l):
                        if j in indices[i]:
                            masked_history[i][j] = False
                            stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
                            stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
                expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x num_sub_tokens x d_model
                memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
                # breakpoint()
            # print("stored_probs_dict", stored_probs_dict)
            # print("sampled_token_dict", sampled_token_dict)
            return stored_logits_dict, sampled_token_dict

        # ---- Training ---- #
        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx)  # (B*T) x num_sub_tokens
        memory_tensor = self._prepare_embedding(memory_list, target)  # (B*T) x num_sub_tokens x d_model
        # apply layer norm

        extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x num_sub_tokens x d_model
        if worst_case:  # mask all, turning it into parallel prediction
            extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
        memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
        if self.sub_decoder_enricher_use:
            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
            input_dict = self.feature_enricher_layers(input_dict)
            memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
        input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
        input_dict = self.sub_decoder_layers(input_dict)
        attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
        # get prob
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict, (masked_indices, p_mask)

def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
|
||
logits_dict = {}
|
||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||
target = input_dict['target'] #B x T x d_model
|
||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||
|
||
# apply window on hidden_vec for enricher
|
||
if self.sub_decoder_enricher_use:
|
||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||
|
||
if bos_hidden_vec is None: # start of generation
|
||
if target is None:
|
||
bos_hidden_vec = input_seq_pos
|
||
else:
|
||
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
|
||
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
|
||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
|
||
|
||
else:
|
||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||
|
||
# input_seq_pos = input_seq
|
||
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
|
||
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
|
||
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
|
||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||
# prepare memory
|
||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||
# ---- Generate(Inference) ---- #
|
||
if target is None:
|
||
sampled_token_dict = {}
|
||
b,t,d = hidden_vec.shape # B x T x d_model
|
||
l = len(self.prediction_order) # num_sub_tokens
|
||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||
|
||
# indicate the position of the mask token,1 means that the token hsa been masked
|
||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||
# denoising c
|
||
stored_logits_dict = {}
|
||
stored_probs_dict = {}
|
||
for step in range(self.denoising_steps):
|
||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||
# nomalize the memory tensor
|
||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||
if self.sub_decoder_enricher_use:
|
||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||
input_dict = self.feature_enricher_layers(input_dict)
|
||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||
input_dict = self.sub_decoder_layers(input_dict)
|
||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||
candidate_token_probs = {}
|
||
candidate_token_embeddings = {}
|
||
for idx, feature in enumerate(self.prediction_order):
|
||
feature_pos = self.feature_order_in_output[feature]
|
||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||
logits_dict[feature] = logit
|
||
sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||
# print(idx,feature,sampled_token,probs)
|
||
sampled_token_dict[feature] = sampled_token
|
||
candidate_token_probs[feature] = probs
|
||
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
|
||
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
|
||
candidate_token_embeddings[feature] = feature_emb_reshape
|
||
|
||
                stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens
                stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))

                # set the confidence of already-decoded positions to -inf so they are never re-selected
                stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
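                # NOTE: three unmasking strategies are supported here.
                #   'low-confidence'  - commit the positions whose sampled tokens have the highest confidence this step
                #   'random'          - commit randomly chosen positions among the sub-token slots
                #   'auto-regressive' - commit exactly one position per step, following the sub-token order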
                if self.method == 'low-confidence':
                    _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:, step]), dim=-1)
                elif self.method == 'random':
                    indices = torch.randint(0, stacked_logits_probs.shape[-1], (b*t, int(num_transfer_tokens[:, step]))).to(logit.device)
                elif self.method == 'auto-regressive':
                    indices = torch.tensor([[step]], device=logit.device)
                # update the masked history
                for i in range(b*t):
                    for j in range(l):
                        if j in indices[i]:
                            masked_history[i][j] = False
                            stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
                            stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
                expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
                memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
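                # NOTE: positions that are still masked keep their original noise embedding, while newly
                # committed positions are replaced by the embedding of the token just sampled for them,
                # so the next denoising step conditions on everything decoded so far.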
            return stored_logits_dict, sampled_token_dict

        # ---- Training ---- #
        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # masked_indices: (B*T) x num_sub_tokens
        memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x num_sub_tokens x d_model
        # layer norm (currently disabled)
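        # NOTE: masked_indices marks which sub-tokens were corrupted by the forward process and p_mask is
        # the per-position masking probability; both are returned to the caller, presumably so the loss
        # can be restricted to masked positions and reweighted by 1 / p_mask, as is common for
        # masked-diffusion objectives.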
        extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
        if worst_case: # mask every position, turning training into fully parallel prediction
            extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
        memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
        memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
        # memory_tensor now holds one embedding per sub-token position
        # memory_tensor = self.layer_norm(memory_tensor)
        # apply feature enricher to memory
        if self.sub_decoder_enricher_use:
            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
            input_dict = self.feature_enricher_layers(input_dict)
            memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
        # sub-decoder cross-attention
        # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
        # inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
        # inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
        # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
        input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
        input_dict = self.sub_decoder_layers(input_dict)
        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
        # project each sub-token position to its feature vocabulary
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict, (masked_indices, p_mask)

    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
        logits_dict = {}
        hidden_vec = input_dict['hidden_vec'] # B x T x d_model
        target = input_dict['target'] # B x T x num_sub_tokens token ids, or None at inference
        bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder

        # apply window on hidden_vec for enricher
        if self.sub_decoder_enricher_use:
            window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
        input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
        input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
        if bos_hidden_vec is None: # start of generation
            if target is None:
                bos_hidden_vec = input_seq_pos
            else:
                bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
                bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
                bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
        else:
            bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model

        # input_seq_pos = input_seq
        input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
        boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
        input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
        # input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
        # input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model

        # prepare memory
        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)

        # ---- Generate(Inference) ---- #
        if target is None:
            sampled_token_dict = {}
            b, t, d = hidden_vec.shape # B x T x d_model
            l = len(self.prediction_order) # num_sub_tokens
            memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
            all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model

            # masked_history marks which sub-token positions are still masked; True means the token has not been decoded yet
            masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.bool)
            num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
            # denoising loop
            stored_logits_dict = {}
            stored_probs_dict = {}
            # with torch.profiler.profile(
            #     activities=[
            #         torch.profiler.ProfilerActivity.CPU,
            #         torch.profiler.ProfilerActivity.CUDA],
            #     record_shapes=True,
            #     profile_memory=True,
            #     with_stack=True
            # ) as prof:
            for step in range(self.denoising_steps):
                memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
                # normalize the memory tensor
                # memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
                if self.sub_decoder_enricher_use:
                    input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
                    input_dict = self.feature_enricher_layers(input_dict)
                    memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
                # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
                input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
                input_dict = self.sub_decoder_layers(input_dict)
                attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
                candidate_token_probs = {}

                sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(
                    attn_output, hidden_vec,
                    sampling_method=sampling_method, threshold=threshold, temperature=temperature,
                    force_decode=Force_decode,
                    step=step)
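                # NOTE: sample_from_logits is assumed to mirror the inline per-feature sampling loop used
                # earlier in this file: it projects each sub-token position, samples a token together with
                # its probability, and returns the stacked confidences plus the embeddings of the sampled
                # tokens for the re-masking update below.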
                # set the confidence of already-decoded positions to -inf so they are never re-selected
                stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
                if self.method == 'low-confidence':
                    _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:, step]), dim=-1)
                elif self.method == 'random':
                    indices = torch.randint(0, stacked_logits_probs.shape[-1], (b*t, int(num_transfer_tokens[:, step]))).to(stacked_logits_probs.device)
                elif self.method == 'auto-regressive':
                    indices = torch.tensor([[step]], device=stacked_logits_probs.device)
                # update the masked history
                for i in range(b*t):
                    for j in range(l):
                        if j in indices[i]:
                            masked_history[i][j] = False
                            stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
                expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
                memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
            # print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
            # print(sampled_token_dict)
            return stored_logits_dict, sampled_token_dict

        # ---- Training ---- #
        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # masked_indices: (B*T) x num_sub_tokens
        memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x num_sub_tokens x d_model
        # layer norm (currently disabled)

        extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
        if worst_case: # mask every position, turning training into fully parallel prediction
            extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
        memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
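        # NOTE: with worst_case=True every sub-token position receives the mask embedding, so the
        # sub-decoder sees no ground-truth sub-tokens at all and must predict all of them in parallel;
        # this is presumably used as a lower-bound / ablation setting.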
        memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
        # memory_tensor now holds one embedding per sub-token position
        # memory_tensor = self.layer_norm(memory_tensor)
        # apply feature enricher to memory
        if self.sub_decoder_enricher_use:
            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
            input_dict = self.feature_enricher_layers(input_dict)
            memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
        # sub-decoder cross-attention
        input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
        input_dict = self.sub_decoder_layers(input_dict)
        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
        # project each sub-token position to its feature vocabulary
        for idx, feature in enumerate(self.prediction_order):
            feature_pos = self.feature_order_in_output[feature]
            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
            logits_dict[feature] = logit
        return logits_dict, (masked_indices, p_mask)
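
    # Usage sketch (illustrative only; key names come from this forward, shapes and sampling arguments are assumed):
    #   training : logits_dict, (masked_indices, p_mask) = module({'hidden_vec': h, 'target': tgt, 'bos_token_hidden': bos})
    #   inference: logits_dict, sampled_token_dict = module({'hidden_vec': h_last, 'target': None, 'bos_token_hidden': bos},
    #                                                       sampling_method=..., threshold=..., temperature=1.0)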