1127 update to latest
@@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass):
         causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
         self.register_buffer('causal_mask', causal_mask)
 
+        # self.transformer_decoder = Decoder(
+        #                     dim = dim,
+        #                     depth = sub_decoder_depth,
+        #                     heads = heads,
+        #                     attn_dropout = dropout,
+        #                     ff_dropout = dropout,
+        #                     attn_flash = True)
         self.transformer_decoder = Decoder(
                             dim = dim,
                             depth = sub_decoder_depth,
                             heads = heads,
                             attn_dropout = dropout,
                             ff_dropout = dropout,
-                            attn_flash = True)
+                            attn_flash = True,
+                            use_rmsnorm=True,
+                            ff_swish = True, # set this to True
+                            ff_glu = True, # set to true to use for all feedforwards
+                            )
         # add final dropout
         print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
         self._apply_xavier_init()
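Note on the hunk above: use_rmsnorm, ff_swish, ff_glu and attn_flash are standard flags of lucidrains' x-transformers (RMSNorm in place of LayerNorm, a Swish-gated GLU feedforward, and flash attention). Below is a minimal standalone sketch of the same configuration with placeholder sizes and dropout values, not this repo's actual hyperparameters.

# Minimal sketch (not this repo's code): the new x-transformers Decoder flags
# applied to a dummy hidden-state tensor. Sizes below are placeholders.
import torch
from x_transformers import Decoder

decoder = Decoder(
    dim=512,
    depth=2,
    heads=8,
    attn_dropout=0.1,
    ff_dropout=0.1,
    attn_flash=True,    # flash-attention kernels when available
    use_rmsnorm=True,   # RMSNorm instead of LayerNorm
    ff_swish=True,      # Swish activation in the feedforward
    ff_glu=True,        # gated feedforward (with ff_swish this gives SwiGLU)
)

x = torch.randn(2, 16, 512)   # B x T x dim
out = decoder(x)              # B x T x dim, causal self-attention
print(out.shape)              # torch.Size([2, 16, 512])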
@@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass):
         dropout:float,
         sub_decoder_enricher_use:bool,
         MASK_IDX:int = 126336,
-        denoising_steps:int = 6,
+        denoising_steps:int = 8,
         eps:float = 1e-3,
         method:str = 'low-confidence', # or random or auto-regressive
         ):
@@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass):
             logits_dict[feature] = logit
         return logits_dict, (masked_indices, p_mask)
 
-class DiffusionDecoder(SubDecoderClass):
+class DiffusionDecoderV2(SubDecoderClass):
     def __init__(
         self,
         prediction_order:list,
@@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass):
         dropout:float,
         sub_decoder_enricher_use:bool,
         MASK_IDX:int = 126336,
-        denoising_steps:int = 6,
+        denoising_steps:int = 8,
         eps:float = 1e-3,
         method:str = 'low-confidence', # or random or auto-regressive
         ):
@@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass):
 
         self.input_norm = nn.LayerNorm(dim)
 
-        self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
+        self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
 
         if sub_decoder_enricher_use:
             self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
@@ -1138,14 +1149,21 @@ class DiffusionDecoder(SubDecoderClass):
         self.register_buffer('causal_mask', causal_mask)
         self.register_buffer('causal_ca_mask', causal_ca_mask)
 
         # get depth of the sub-decoder
         if sub_decoder_depth > 1:
-            self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
+            self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
         else:
-            self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
+            self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
         if sub_decoder_enricher_use:
             self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
 
+        self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
+                                            vocab=vocab,
+                                            sub_decoder_depth=1,
+                                            dim=dim,
+                                            heads=heads,
+                                            dropout=dropout,
+                                            sub_decoder_enricher_use=False)
 
     # simplified version of the forward process in diffusion model
     def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
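The body of _forward_process is outside this hunk. For orientation, the usual "simplified forward process" in masked-diffusion decoders (the MASK_IDX default 126336 matches LLaDA's mask id) samples a per-sequence masking ratio and replaces tokens with the mask id. The sketch below is that standard recipe under stated assumptions, not necessarily identical to this repo's implementation; the function name and return convention are assumptions.

# Hedged sketch of a standard masked-diffusion forward process (LLaDA-style).
import torch

def forward_process(input_ids, eps=1e-3, mask_idx=126336):
    b, t = input_ids.shape
    # one diffusion time per sequence, mapped to a masking probability in [eps, 1]
    time = torch.rand(b, device=input_ids.device)
    p_mask = (1 - eps) * time + eps
    p_mask = p_mask[:, None].expand(b, t)                        # B x T

    # mask each token independently with probability p_mask
    masked_indices = torch.rand(b, t, device=input_ids.device) < p_mask
    noisy_ids = torch.where(masked_indices,
                            torch.full_like(input_ids, mask_idx),
                            input_ids)
    # p_mask is returned so the training loss can be re-weighted by 1 / p_mask
    return noisy_ids, masked_indices, p_mask

# usage
ids = torch.randint(0, 100, (2, 8))
noisy_ids, masked_indices, p_mask = forward_process(ids)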
@@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass):
         # print("sampled_token_dict", sampled_token_dict)
         return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
 
-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec'] # B x T x d_model
+        copy_input_dict = input_dict.copy()
+
         target = input_dict['target'] #B x T x d_model
         bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
 
@@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass):
         memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
         # ---- Generate(Inference) ---- #
         if target is None:
+            if aux_ar: # inference with auxiliary auto-regressive decoder
+                aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
+                # print("aux_ar_logits", aux_ar_logits)
+                return aux_ar_logits, sampled_token_dict
             sampled_token_dict = {}
             b,t,d = hidden_vec.shape # B x T x d_model
             l = len(self.prediction_order) # num_sub_tokens
@@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass):
             logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
             logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
             logits_dict[feature] = logit
-        return logits_dict, (masked_indices, p_mask)
+        # get aux ar decoder logits
+        aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
+        return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
+        # return logits_dict, (masked_indices, p_mask)
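For context on denoising_steps (bumped from 6 to 8) and method='low-confidence': at inference the decoder starts from masked slots and reveals them over several steps, keeping the most confident predictions first and re-masking the rest. The sketch below illustrates that generic low-confidence remasking loop against a toy model interface; the function and the toy model are hypothetical and only the step count and mask-id idea mirror the diff above.

# Illustrative low-confidence remasking loop (a sketch, not DiffusionDecoderV2).
import torch
import torch.nn.functional as F

@torch.no_grad()
def low_confidence_unmask(model, ids, mask_idx, denoising_steps=8):
    # `model(ids)` is assumed to return logits of shape B x T x V
    ids = ids.clone()
    for step in range(denoising_steps):
        masked = ids == mask_idx
        if not masked.any():
            break
        logits = model(ids)                                    # B x T x V
        conf, pred = F.softmax(logits, dim=-1).max(dim=-1)     # both B x T
        # never re-commit positions that are already revealed
        conf = torch.where(masked, conf, torch.full_like(conf, -1.0))

        remaining = denoising_steps - step
        for b in range(ids.shape[0]):
            n_masked = int(masked[b].sum())
            k = max(1, n_masked // remaining)        # reveal ~1/remaining of them
            top = conf[b].topk(k).indices            # most confident masked slots
            ids[b, top] = pred[b, top]
    return ids

# usage with a toy "model": random logits over a 100-token vocab, mask id 100
toy = lambda x: torch.randn(x.shape[0], x.shape[1], 100)
out = low_confidence_unmask(toy, torch.full((2, 8), 100), mask_idx=100)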