# MIDIFoundationModel/Amadeus/sub_decoder_zoo.py
import torch
import torch.profiler
import torch.nn as nn

from .custom_x_transformers import Decoder
from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
from .sub_decoder_utils import *
from .sampling_utils import sample, sample_with_prob, sample_with_prob_fast, top_p_sampling, typical_sampling, eta_sampling
from data_representation.vocab_utils import LangTokenVocab

class SingleProjection(nn.Module):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
This sub-decoder is used for REMI-based models: a single linear projection over the flattened vocabulary.
'''
super().__init__()
vocab_size = vocab.get_vocab_size()
self.proj = nn.Linear(dim, vocab_size)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=1):
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
# ---- Generate(Inference) ---- #
if target is None:
logits = self.proj(hidden_vec[:, -1:])
sampled_token = sample(logits, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
return logits, sampled_token
# ---- Training ---- #
logits = self.proj(hidden_vec)
return logits
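# Hedged usage sketch for SingleProjection (the tensor sizes, prediction_order, sampling_method value
# and the vocab object are illustrative assumptions, not fixed by this file):
#   proj = SingleProjection(prediction_order=['remi'], vocab=vocab, sub_decoder_depth=0,
#                           dim=512, heads=8, dropout=0.1, sub_decoder_enricher_use=False)
#   logits = proj({'hidden_vec': h, 'target': tgt})                              # training: B x T x vocab_size
#   logits, tok = proj({'hidden_vec': h, 'target': None}, sampling_method='top_p',
#                      threshold=0.9, temperature=1.0)                           # inference: last position only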
class SubDecoderClass(nn.Module):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
super().__init__()
'''
This is the base class for all sub-decoders
'''
self.prediction_order = prediction_order
self.vocab = vocab
self.vocab_size = vocab.get_vocab_size()
# make layers
self._make_emb_layer(vocab, dim)
self._make_projection_layer(vocab, dim)
self._make_nonlinear_layer()
@property
def device(self):
return next(self.parameters()).device
def _make_emb_layer(self, vocab, dim):
self.emb_layer = MultiEmbedding(
vocab=vocab,
dim_model=dim
)
# def _make_projection_layer(self, vocab, dim):
# vocab_sizes = vocab.get_vocab_size()
# self.hidden2logit = nn.ModuleDict({
# f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
# })
def _make_nonlinear_layer(self):
pass
def _make_projection_layer(self, vocab, dim):
vocab_sizes = vocab.get_vocab_size()
self.vocab_sizes = vocab_sizes
self.max_vocab_size = max(vocab_sizes.values())
self.projection_keys = list(vocab_sizes.keys()) # For index order
# Keep the original Linear layers (so that the state_dict keys still match)
self.hidden2logit = nn.ModuleDict({
f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
})
# # Build padded weights for block-parallel projection (kept for reference; currently disabled)
# weight_list = []
# bias_list = []
# for key in self.projection_keys:
# layer = self.hidden2logit[f"layer_{key}"]
# w = layer.weight
# b = layer.bias
# # pad to max_vocab_size
# w_padded = F.pad(w, (0, 0, 0, self.max_vocab_size - w.shape[0]))
# b_padded = F.pad(b, (0, self.max_vocab_size - b.shape[0]))
# weight_list.append(w_padded.unsqueeze(0)) # (1, Vmax, D)
# bias_list.append(b_padded.unsqueeze(0)) # (1, Vmax)
# self.register_buffer("proj_weight", torch.cat(weight_list, dim=0)) # (F, Vmax, D)
# self.register_buffer("proj_bias", torch.cat(bias_list, dim=0)) # (F, Vmax)
class FeedForward(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
The FeedForward sub-decoder is used for compound tokens such as CP or NB.
We follow the original sub-decoder proposed in "Compound Word Transformer";
however, in our implementation the embedding size is the same for every sub-token (musical feature),
since we did not find any significant difference in model performance.
There are two decoding styles for the FeedForward sub-decoder:
1. Partial-sequential prediction: predict the type token first, then predict all remaining sub-tokens in parallel (original CP)
2. Fully-sequential prediction: predict all sub-tokens sequentially
'''
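# Hedged example of the two decoding styles described above (the feature names are placeholders;
# real names come from vocab.feature_list):
#   fully-sequential:   prediction_order = ['type', 'beat', 'pitch', 'duration', 'velocity']
#   partial-sequential: prediction_order = ['type', ('beat', 'pitch', 'duration', 'velocity')]
# In the partial-sequential form the trailing tuple is decoded in parallel, which is why forward()
# asserts that any non-str entry must be the last element of prediction_order.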
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
def _make_projection_layer(self, vocab, dim):
vocab_sizes = vocab.get_vocab_size()
self.hidden2logit = nn.ModuleDict({
f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
})
self.catvec2hidden = nn.ModuleDict({
f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
})
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
for feature in self.prediction_order:
if isinstance(feature, str):
logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
logits_dict[feature] = logit
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token) # B x T x emb_size
catvec = torch.cat([hidden_vec, feature_emb.unsqueeze(0)], dim=-1)
hidden_vec = self.catvec2hidden[f"layer_{feature}"](catvec)
else:
assert feature == self.prediction_order[-1], "Parallel prediction should be the last feature"
for par_feature in feature:
logit = self.hidden2logit[f"layer_{par_feature}"](hidden_vec)
logits_dict[par_feature] = logit
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[par_feature] = sampled_token
return logits_dict, sampled_token_dict
# ---- Training ---- #
for feature in self.prediction_order:
if isinstance(feature, str):
logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
logits_dict[feature] = logit
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., self.vocab.feature_list.index(feature)]) # B x T x emb_size
catvec = torch.cat([hidden_vec, feature_emb], dim=-1)
hidden_vec = self.catvec2hidden[f"layer_{feature}"](catvec)
else:
assert feature == self.prediction_order[-1], "Parallel prediction should be the last feature"
for par_feature in feature:
logit = self.hidden2logit[f"layer_{par_feature}"](hidden_vec)
logits_dict[par_feature] = logit
return logits_dict
class Parallel(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
The Parallel sub-decoder predicts all sub-tokens or musical features in parallel,
as proposed in the paper "Multitrack Music Transformer".
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
for feature in self.prediction_order:
logit = self.hidden2logit[f"layer_{feature}"](hidden_vec) # B x T x vocab_size
logits_dict[feature] = logit
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
return logits_dict, sampled_token_dict
# ---- Training ---- #
for feature in self.prediction_order:
logit = self.hidden2logit[f"layer_{feature}"](hidden_vec)
logits_dict[feature] = logit
return logits_dict
class RNN(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
The RNN sub-decoder predicts multiple sub-tokens or musical features sequentially,
similar to the approach used in "PianoTree VAE".
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.decoding_rnn = nn.GRU(
input_size=dim,
hidden_size=dim,
num_layers=sub_decoder_depth,
dropout=dropout,
batch_first=True)
def _apply_pos_enc(self, tgt, apply_type='last'):
if apply_type == 'all':
pos = torch.arange(tgt.shape[1]).to(tgt.device)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
tgt_pos = tgt + self.pos_enc(pos.long())
elif apply_type == 'last':
pos = torch.arange(tgt.shape[1]).to(tgt.device)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
pos_emb = self.pos_enc(pos.long())
# zero out the pos_emb except for the last token
pos_emb[:, :-1, :] = 0
tgt_pos = tgt + pos_emb
return tgt_pos
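# Illustrative sketch of the two apply_type modes above: 'all' adds a positional embedding to every
# step, while 'last' zeroes everything except the final position, so during generation only the
# newly appended sub-token embedding receives a positional offset.
#   tgt = torch.zeros(4, 3, dim)                   # (B*T) x 3 x d_model (sizes assumed)
#   self._apply_pos_enc(tgt, apply_type='all')     # adds pos_enc(0), pos_enc(1), pos_enc(2)
#   self._apply_pos_enc(tgt, apply_type='last')    # adds pos_enc(2) to the last step only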
def _prepare_token_embedding_for_teacher_forcing(self, input_seq, target):
for feature in self.prediction_order[:-1]:
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
input_seq = torch.cat([input_seq, feature_emb_reshape], dim=1)
return input_seq
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub_tokens-1
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], -1)).unsqueeze(1) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape # (B*T) x 1 x d_model
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
h_0 = input_seq[:, 0, :].unsqueeze(0) # 1 x (B*T) x d_model
input_seq = self._apply_pos_enc(input_seq, apply_type='all') # (B*T) x 1 x d_model
for idx, feature in enumerate(self.prediction_order):
input_seq, _ = self.decoding_rnn(input_seq, h_0) # input_seq: (B*T) x (idx+1) x hidden_size, h_n: num_layers x (B*T) x hidden_size
logit = self.hidden2logit[f"layer_{feature}"](input_seq[:, -1, :]) # (B*T) x vocab_size
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
if idx == len(self.prediction_order)-1:
return logits_dict, sampled_token_dict
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
input_seq = torch.cat([input_seq, feature_emb_reshape], dim=1) # (B*T) x (idx+2) x d_model
input_seq = self._apply_pos_enc(input_seq, apply_type='last') # (B*T) x (idx+2) x d_model
return logits_dict, sampled_token_dict
# ---- Training ---- #
input_seq = self._prepare_token_embedding_for_teacher_forcing(input_seq, target) # (B*T) x len(prediction_order) x d_model
# initial hidden state has no positional encoding
h0 = input_seq[:, 0, :].unsqueeze(0) # 1 x (B*T) x d_model
h0 = h0.contiguous()
# apply positional encoding
input_seq = self._apply_pos_enc(input_seq, apply_type='all') # (B*T) x len(prediction_order) x d_model
# get output using rnn
output, _ = self.decoding_rnn(input_seq, h0) # (B*T) x len(prediction_order) x d_model
output = output.reshape((hidden_vec.shape[0], hidden_vec.shape[1], len(self.prediction_order), -1)) # B x T x len(prediction_order) x d_model
for idx, feature in enumerate(self.prediction_order):
logit = self.hidden2logit[f"layer_{feature}"](output[:, :, idx, :]) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict
class SelfAttention(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
This sub-decoder performs sequential prediction of multiple sub-tokens or musical features.
It is similar to the method proposed in "UniAudio", but differs in how the sub-token sequence is built:
UniAudio adds the main decoder's output (the hidden vec) directly to each sub-token embedding,
while our method places the hidden vec in the input sequence itself,
so that the attention mechanism can learn the relationship between the hidden vec and the sub-tokens.
'''
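# Sequence-layout sketch for the two self-attention variants (hedged; window_size is fixed to 1 below):
#   SelfAttention (ours):        [hidden_vec, BOS, emb(f_1), ..., emb(f_{n-1})]  -> hidden_vec is an extra token
#   SelfAttentionUniAudio below: [hidden_vec + 0, hidden_vec + emb(f_1), ...]    -> hidden_vec is added per position
# Both variants are decoded left-to-right over the sub-token axis.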
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(1 + len(prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
window_size = 1 # number of previous output of the main decoder to be used in the sub-decoder
causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
self.register_buffer('causal_mask', causal_mask)
# self.transformer_decoder = Decoder(
# dim = dim,
# depth = sub_decoder_depth,
# heads = heads,
# attn_dropout = dropout,
# ff_dropout = dropout,
# attn_flash = True)
self.transformer_decoder = Decoder(
dim = dim,
depth = sub_decoder_depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
use_rmsnorm=True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# apply custom init and add extra dropout inside the x-transformers decoder
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def _apply_pos_enc(self, tgt, apply_type='last'):
if apply_type == 'all':
pos = torch.arange(tgt.shape[1]).to(tgt.device)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
tgt_pos = tgt + self.pos_enc(pos.long())
elif apply_type == 'last':
pos = torch.arange(tgt.shape[1]).to(tgt.device)
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1)
pos_emb = self.pos_enc(pos.long()) # (B*T) x (window_size + BOS + num_sub_tokens-1) x dim
# zero out the pos_emb except for the last token
pos_emb[:, :-1, :] = 0
tgt_pos = tgt + pos_emb
return tgt_pos
def _prepare_input_seq_list(self, hidden_vec_reshape, target=None):
input_seq_list = []
input_seq_list.append(hidden_vec_reshape)
BOS_emb = self.sub_decoder_BOS_emb.unsqueeze(0).repeat(hidden_vec_reshape.shape[0], 1, 1) # (B*T) x 1 x d_model
if target is None:
input_seq_list.append(BOS_emb[-1:, :, :])
else: # training
input_seq_list.append(BOS_emb)
return input_seq_list
def _prepare_token_embedding_for_teacher_forcing(self, input_seq_list, target):
for feature in self.prediction_order[:-1]:
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
input_seq_list.append(feature_emb_reshape)
memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
return memory_tensor
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub_tokens
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq_list = self._prepare_input_seq_list(hidden_vec_reshape, target)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
input_seq_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS) x d_model
pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all') # (B*T) x (window_size + BOS) x d_model
for idx, feature in enumerate(self.prediction_order):
output = self.transformer_decoder(pos_target_tensor)
logit = self.hidden2logit[f"layer_{feature}"](output[:, -1:])
logits_dict[feature] = logit.reshape((1, 1, -1)) # 1 x 1 x vocab_size
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
if idx == len(self.prediction_order)-1:
return logits_dict, sampled_token_dict
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
input_seq_list.append(feature_emb_reshape)
input_seq_tensor = torch.cat(input_seq_list, dim=1)
pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='last')
return logits_dict, sampled_token_dict
# ---- Training ---- #
# preparing for training
input_seq_tensor = self._prepare_token_embedding_for_teacher_forcing(input_seq_list, target) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all') # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
# get output using self-attention
output = self.transformer_decoder(pos_target_tensor)
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict
class SelfAttentionUniAudio(SelfAttention):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
'''
UniAudio-style variant of the self-attention sub-decoder.
In our experiments it outperforms our proposed self-attention sub-decoder
and is comparable to the cross-attention sub-decoder.
However, NMT still performs better than UniAudio overall.
'''
def _prepare_token_embedding_for_teacher_forcing(self, hidden_vec_reshape, target):
input_seq_list = []
# append zero vector
input_seq_list.append(torch.zeros(hidden_vec_reshape.shape[0], 1, hidden_vec_reshape.shape[2]).to(self.device))
for feature in self.prediction_order[:-1]:
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
input_seq_list.append(feature_emb_reshape)
feature_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x num_sub-tokens x d_model
# Ensure hidden_vec_reshape and feature_tensor have the same shape
assert hidden_vec_reshape.shape == feature_tensor.shape, f"Shapes of hidden_vec_reshape and feature_tensor do not match: {hidden_vec_reshape.shape} vs {feature_tensor.shape}"
# Sum hidden_vec_reshape and feature_tensor in the last dimension
memory_tensor = hidden_vec_reshape + feature_tensor
return memory_tensor
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub-tokens
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
hidden_vec_reshape = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub-tokens x d_model
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
pos_target_tensor = self._apply_pos_enc(hidden_vec_reshape, apply_type='all') # (B*T) x (window_size + BOS) x d_model
for idx, feature in enumerate(self.prediction_order):
output = self.transformer_decoder(pos_target_tensor)
logit = self.hidden2logit[f"layer_{feature}"](output[:, -1:])
logits_dict[feature] = logit.reshape((1, 1, -1)) # 1 x 1 x vocab_size
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
if idx == len(self.prediction_order)-1:
return logits_dict, sampled_token_dict
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
pos_target_tensor = torch.cat([pos_target_tensor[:, :idx+1, :], feature_emb_reshape + pos_target_tensor[:, idx+1:idx+2, :], pos_target_tensor[:, idx+2:, :]], dim=1)
return logits_dict, sampled_token_dict
# ---- Training ---- #
# preparing for training
input_seq_tensor = self._prepare_token_embedding_for_teacher_forcing(hidden_vec_reshape, target) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
pos_target_tensor = self._apply_pos_enc(input_seq_tensor, apply_type='all') # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
# get output using self-attention
output = self.transformer_decoder(pos_target_tensor)
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict
class CrossAttention(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
'''
The strength of cross-attention (and UniAudio-style self-attention) lies in using the output of the main decoder, the hidden vec, directly in the sub-decoder.
Because the main decoder's output is a representation of the whole sequence,
it carries rich enough information to decode the sub-tokens even in a parallel manner.
Hence both architectures that consume the main decoder's output directly outperform the original self-attention sub-decoder.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causl_ca_mask = generate_CA_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causl_ca_mask)
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(
*[DecoderLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)]
)
else:
self.sub_decoder_layers = nn.Sequential(DecoderLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
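# Shape walkthrough for the windowing above (window_size = 1 as hard-coded in the method):
#   hidden_vec: B x T x D  --unfold/reshape-->  (B*T) x 1 x D
#   prepending enricher_BOS_emb yields (B*T) x 2 x D, i.e. [BOS, h_t] per main-decoder step.
# With a larger window_size, the left zero-padding would expose the previous window_size-1 hidden
# states of the main decoder to the feature enricher as well.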
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
if target is not None: # training
memory_list.append(BOS_emb)
else: # inference
memory_list.append(BOS_emb[-1:, :, :])
return memory_list
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target']
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq)
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x 1 x d_model
old_memory_tensor = memory_tensor
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec[-1:]}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq']
CA_attn_mask = generate_CA_mask(input_seq_pos.shape[1], memory_tensor.shape[1]).to(self.device)
input_dict = {'input_seq': input_seq_pos[-1:], 'memory': memory_tensor, 'memory_mask': CA_attn_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq']
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((1, 1, -1)) # 1 x 1 x vocab_size
logits_dict[feature] = logit
sampled_token,prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
sampled_token_dict[feature] = sampled_token
if idx == len(self.prediction_order)-1:
return logits_dict, sampled_token_dict
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + idx+1) x d_model
return logits_dict, sampled_token_dict
# ---- Training ---- #
memory_tensor = self._prepare_token_embedding_for_teacher_forcing(memory_list, target) # (B*T) x (BOS + num_sub_tokens-1) x d_model
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict
class Flatten4Encodec(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool
):
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
hidden_vec = input_dict['hidden_vec']
# ---- Training ---- #
logits_tensor = torch.zeros(hidden_vec.shape[0], hidden_vec.shape[1], 2049).to(self.device)
for idx, feature_type in enumerate(self.prediction_order):
# ::4 means that we only use the first token in each 4 tokens
# so the chosen tokens will be: 0, 4, 8, 12, ...
# 1::4 means that we only use the second token in each 4 tokens
# so the chosen tokens will be: 1, 5, 9, 13, ...
separated_hidden_vec = hidden_vec[:, idx::4, :]
logit = self.hidden2logit[f"layer_{feature_type}"](separated_hidden_vec)
logits_tensor[:, idx::4, :] = logit
# prob_dict[feature_type] = prob
return logits_tensor
def run_one_step(self, input_dict, sampling_method=None, threshold=None, temperature=None, feature_type=None):
# ---- Generate(Inference) ---- #
hidden_vec = input_dict['hidden_vec']
logit = self.hidden2logit[f"layer_{feature_type}"](hidden_vec[:, -1:])
sampled_token = sample(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
return logit, sampled_token
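# Interleaving sketch for Flatten4Encodec (hedged; the hard-coded stride of 4 and the 2049-way
# output suggest 4 Encodec codebooks with a 2048-entry codebook plus one special token):
#   flat position: 0  1  2  3  4  5  6  7 ...
#   codebook id:   0  1  2  3  0  1  2  3 ...
# forward() projects hidden_vec[:, idx::4, :] with the idx-th head and writes the logits back into
# the matching interleaved slots of logits_tensor; run_one_step() decodes a single codebook at a time.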
class DiffusionDecoder(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
'''
Discrete-diffusion (LLaDA-style) sub-decoder: all sub-tokens start from the mask embedding and are
iteratively unmasked over `denoising_steps` reverse steps, conditioned on the main decoder's hidden vec through cross-attention.
Like the cross-attention sub-decoder, its strength lies in using the main decoder's output directly:
since that output represents the whole sequence, it is informative enough to decode the sub-tokens in parallel.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the [MASK] token; its index (126336) is not in the vocab
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
self.MASK_idx = MASK_IDX
self.denoising_steps = denoising_steps
self.eps = eps
self.method = method
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
# simplified version of the forward process in diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
b, l = reshaped_input_ids.shape
t = torch.rand(b, device=input_ids.device)
p_mask = (1 - eps) * t + eps
p_mask = p_mask[:, None].repeat(1, l)
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# the [MASK] token index (126336 by default) is deliberately outside the vocab
if mask_idx is not None:
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
else:
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # fall back to the default [MASK] index
return noisy_batch, masked_indices, p_mask
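# Worked example of the forward (noising) process above, with eps = 1e-3:
#   per row, t ~ U(0, 1) and p_mask = (1 - eps) * t + eps ≈ t, so each sub-token of that row is
#   masked independently with probability ≈ t (t = 0.5 masks about half of the sub-tokens,
#   t close to 1 masks nearly all of them), i.e. a linear noise schedule.
#   p_mask is returned so the caller can re-weight the loss per token (LLaDA-style); exactly how it
#   is used is up to the training loop.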
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
# return a tensor filled entirely with the mask-token embedding
def _get_noisy_tensor(self, target_shape):
new_target = torch.zeros(target_shape).to(self.device)
# fill all the elements in the tensor with the embedding of the mask token
new_target[:, :, :] = self.diffusion_mask_emb
return new_target
# prepare the embeddings of the target sub-tokens
def _prepare_embedding(self, memory_list, target):
for _, feature in enumerate(self.prediction_order):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
if add_BOS:
if target is not None: # training
memory_list.append(BOS_emb)
else: # inference
memory_list.append(BOS_emb[-1:, :, :])
else:
pass
return memory_list
def _get_num_transfer_tokens(self, mask_index, steps):
'''
In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
the expected number of tokens transitioned at each step should be consistent.
This function is designed to precompute the number of tokens that need to be transitioned at each step.
'''
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
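# Worked example for _get_num_transfer_tokens: with 7 masked sub-tokens in a row and
# denoising_steps = 3, base = 2 and remainder = 1, so num_transfer_tokens = [3, 2, 2] for that row.
# The counts always sum back to the number of masked tokens, and the leftover tokens are unmasked
# in the earliest steps.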
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b,t,d = hidden_vec.shape # B x T x d_model
# print("*"*8)
logits_list = []
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_list.append(logit)
for idx, feature in enumerate(self.prediction_order):
logit = logits_list[idx] # B x T x vocab_siz
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if step==0 and force_decode:
if feature == 'velocity':
sampled_token = torch.tensor([2]).to(logit.device)
prob = torch.tensor([1.0]).to(logit.device)
else:
prob = torch.tensor([0.0]).to(logit.device)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens x vocab_size
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
def sample_from_logits_fast(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b, t, d = hidden_vec.shape # (B, T, D)
F = len(self.projection_keys)
Vmax = self.max_vocab_size
# === 1. Collect the output position of every feature === #
feature_pos_list = [self.feature_order_in_output[f] for f in self.projection_keys]
# === 2. Gather each feature's position from attn_output -> (B, F, D) === #
attn_features = torch.stack(
[attn_output[:, pos, :] for pos in feature_pos_list], dim=1
) # (B, F, D)
# === 3. Run all Linear projections in parallel with one batched matmul (einsum) === #
# attn_features: (B, F, D)
# proj_weight: (F, Vmax, D)
# proj_bias: (F, Vmax)
# output: (B, F, Vmax)
logits = torch.einsum("bfd,fvd->bfv", attn_features, self.proj_weight) + self.proj_bias
# === 4. Truncate each feature's logits to its original vocab size === #
logits_list = []
logits_dict_by_feature = {
feature: logits[:, i, :self.vocab_sizes[feature]]
for i, feature in enumerate(self.projection_keys)
}
for i, feature in enumerate(self.projection_keys):
vocab_size = self.vocab_sizes[feature]
logits_list.append(logits[:, i, :vocab_size]) # (B, vocab_size)
for idx, feature in enumerate(self.prediction_order):
logit = logits_dict_by_feature[feature].unsqueeze(0) # B x T x vocab_size
sampled_token, prob = sample_with_prob_fast(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token.squeeze(0) # B x T
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens x vocab_size
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
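# Hedged note: sample_from_logits_fast relies on self.proj_weight ((num_features, Vmax, D)) and
# self.proj_bias ((num_features, Vmax)), which are only built in the commented-out block of
# _make_projection_layer above. A minimal sketch for building them lazily (an assumption, not part
# of the original code):
#   import torch.nn.functional as fn
#   w = torch.stack([fn.pad(self.hidden2logit[f"layer_{k}"].weight,
#                           (0, 0, 0, self.max_vocab_size - self.vocab_sizes[k])) for k in self.projection_keys])
#   b = torch.stack([fn.pad(self.hidden2logit[f"layer_{k}"].bias,
#                           (0, self.max_vocab_size - self.vocab_sizes[k])) for k in self.projection_keys])
#   self.register_buffer("proj_weight", w); self.register_buffer("proj_bias", b)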
def choose_tokens(self, hidden_vec, step, method, stacked_logits_probs, num_transfer_tokens):
if method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:, step]), dim=-1)
elif method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (int(num_transfer_tokens[:, step]),), device=stacked_logits_probs.device)
elif method == 'auto-regressive':
indices = torch.tensor([[step]], device=hidden_vec.device)
return indices
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# masked_history marks which sub-tokens are still masked; True means the token has not been decoded yet
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
# add attribute control here
stored_logits_dict = {}
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising loop
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set the probability of already-decoded tokens to -inf so they are never selected again
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (int(num_transfer_tokens[:, step]),), device=stacked_logits_probs.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=hidden_vec.device)
# update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# stop early once every sub-token has been unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask everything, which reduces training to fully parallel decoding
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# all is embedding
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
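# Hedged sketch of how the training return value is typically consumed (the actual loss lives in the
# training loop and cross_entropy_per_subtoken is a hypothetical helper, so this is an assumption
# rather than the repository's code):
#   logits_dict, (masked_indices, p_mask) = sub_decoder(input_dict)
#   per_token_ce = cross_entropy_per_subtoken(logits_dict, target)          # un-reduced, (B*T) x num_sub_tokens
#   loss = (per_token_ce[masked_indices] / p_mask[masked_indices]).mean()   # LLaDA-style 1/p_mask weighting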
class DiffusionDecoderV2(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
'''
Variant of the diffusion sub-decoder that uses TransformerLayerV2 blocks and adds an auxiliary
self-attention (autoregressive) sub-decoder. As with DiffusionDecoder, all sub-tokens start from the
mask embedding and are iteratively unmasked over `denoising_steps` reverse steps, conditioned on the
main decoder's hidden vec, whose sequence-level representation is rich enough to decode sub-tokens in parallel.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the [MASK] token; its index (126336) is not in the vocab
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
self.MASK_idx = MASK_IDX
self.denoising_steps = denoising_steps
self.eps = eps
self.method = method
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
vocab=vocab,
sub_decoder_depth=1,
dim=dim,
heads=heads,
dropout=dropout,
sub_decoder_enricher_use=False)
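# Hedged note on aux_ar_decoder: it is a depth-1 SelfAttention sub-decoder built next to the
# diffusion path, and forward() below exposes an `aux_ar` flag, which suggests an auxiliary
# autoregressive branch (e.g. an extra training objective). How its outputs are combined with the
# diffusion logits is decided by the caller, not by this __init__.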
# simplified version of the forward process in diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
b, l = reshaped_input_ids.shape
t = torch.rand(b, device=input_ids.device)
p_mask = (1 - eps) * t + eps
p_mask = p_mask[:, None].repeat(1, l)
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# the [MASK] token index (126336 by default) is deliberately outside the vocab
if mask_idx is not None:
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
else:
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # fall back to the default [MASK] index
return noisy_batch, masked_indices, p_mask
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
# return a tensor filled entirely with the mask-token embedding
def _get_noisy_tensor(self, target_shape):
new_target = torch.zeros(target_shape).to(self.device)
# fill all the elements in the tensor with the embedding of the mask token
new_target[:, :, :] = self.diffusion_mask_emb
return new_target
# prepare the embeddings of the target sub-tokens
def _prepare_embedding(self, memory_list, target):
for _, feature in enumerate(self.prediction_order):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
        if add_BOS:
            if target is not None: # training
                memory_list.append(BOS_emb)
            else: # inference
                memory_list.append(BOS_emb[-1:, :, :])
return memory_list
def _get_num_transfer_tokens(self, mask_index, steps):
        '''
        In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
        Because LLaDA employs a linear noise schedule (as defined in its Eq. (8)),
        the expected number of tokens transitioned at each step is the same.
        This function precomputes how many tokens need to be transitioned at each step.
        '''
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
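    # Worked example: 7 masked sub-tokens unmasked over steps = 3 gives base = 2 and
    # remainder = 1, so num_transfer_tokens = [3, 2, 2]; the counts always sum to the number
    # of masked positions, with the remainder spread over the earliest steps.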
    def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False, step=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b,t,d = hidden_vec.shape # B x T x d_model
# print("*"*8)
logits_list = []
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_list.append(logit)
for idx, feature in enumerate(self.prediction_order):
            logit = logits_list[idx] # B x T x vocab_size
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if step==0 and force_decode:
if feature == 'velocity':
sampled_token = torch.tensor([2]).to(logit.device)
prob = torch.tensor([1.0]).to(logit.device)
else:
prob = torch.tensor([0.0]).to(logit.device)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
        stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
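    # Note on the returned tuple (shapes assume inference with B = T = 1): sampled_token_dict and
    # logits_dict map each feature to its sampled id and full logits; candidate_token_probs holds
    # the per-feature confidence of the sampled id; stacked_logits_probs packs those confidences as
    # (B*T) x num_sub_tokens for the low-confidence remasking rule; stacked_token_embeddings packs
    # the corresponding embeddings as (B*T) x num_sub_tokens x d_model.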
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
copy_input_dict = input_dict.copy()
        target = input_dict['target'] # B x T x num_sub_tokens token ids, or None at inference
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
                bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
if aux_ar: # inference with auxiliary auto-regressive decoder
aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
# print("aux_ar_logits", aux_ar_logits)
return aux_ar_logits, sampled_token_dict
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
            # boolean mask over sub-token positions; True (1) means the token is still masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
# add attribute control here
stored_logits_dict = {}
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
            # denoising loop
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
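            # Denoising loop overview (description only): each step re-embeds the current partially
            # unmasked sequence, cross-attends it against the boosted hidden states, samples every
            # position, then commits only num_transfer_tokens[:, step] newly sampled positions
            # selected by self.method (highest confidence for 'low-confidence', random positions,
            # or position == step for 'auto-regressive'); all other positions are re-masked for the
            # next step.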
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# print("stacked_logits_probs", stacked_logits_probs.clone())
                if self.method == 'low-confidence':
                    _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:, step]), dim=-1)
                elif self.method == 'random':
                    indices = torch.randint(0, stacked_logits_probs.shape[-1], (1, int(num_transfer_tokens[:, step])), device=hidden_vec.device)
                elif self.method == 'auto-regressive':
                    indices = torch.tensor([[step]], device=hidden_vec.device)
                # update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# skip if all tokens are unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # masked_indices, p_mask: (B*T) x num_sub_tokens
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
        if worst_case: # mask all sub-tokens, so training degenerates to fully parallel prediction
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
        # memory_tensor now holds sub-token embeddings; masked positions carry the diffusion mask embedding
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
# get aux ar decoder logits
aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
# return logits_dict, (masked_indices, p_mask)
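    # Usage sketch (hypothetical variable names; sampling arguments depend on sampling_utils):
    #   training:  (logits_dict, aux_ar_logits), (masked_indices, p_mask) = sub_decoder(
    #                  {'hidden_vec': h, 'target': tgt, 'bos_token_hidden': bos})
    #   inference: stored_logits_dict, sampled_token_dict = sub_decoder(
    #                  {'hidden_vec': h, 'target': None, 'bos_token_hidden': bos},
    #                  sampling_method=method, threshold=th, temperature=1.0)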