1127 update to latest

FelixChan committed 2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions


@@ -347,13 +347,24 @@ class SelfAttention(SubDecoderClass):
causal_mask = generate_causality_mask_on_window(size=window_size + len(prediction_order), window_size=window_size)
self.register_buffer('causal_mask', causal_mask)
# self.transformer_decoder = Decoder(
# dim = dim,
# depth = sub_decoder_depth,
# heads = heads,
# attn_dropout = dropout,
# ff_dropout = dropout,
# attn_flash = True)
self.transformer_decoder = Decoder(
dim = dim,
depth = sub_decoder_depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
attn_flash = True,
use_rmsnorm=True,
ff_swish = True, # use Swish (SiLU) activation in the feedforwards
ff_glu = True, # use gated (GLU) feedforwards throughout
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
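
The three new flags passed to the x-transformers Decoder switch its blocks from LayerNorm plus plain feedforwards to RMSNorm plus Swish-gated (SwiGLU) feedforwards. As a rough illustration of what that feedforward computes, here is a minimal PyTorch sketch; the class name and layout are illustrative, not the library's internals, and nn.RMSNorm needs PyTorch 2.4+:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForwardSketch(nn.Module):
    # roughly what ff_glu=True + ff_swish=True enable inside each Decoder block
    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner = dim * mult
        self.norm = nn.RMSNorm(dim)               # counterpart of use_rmsnorm=True
        self.proj_in = nn.Linear(dim, inner * 2)  # value and gate in one projection
        self.proj_out = nn.Linear(inner, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj_in(self.norm(x)).chunk(2, dim=-1)
        x = value * F.silu(gate)                   # Swish/SiLU gating (SwiGLU)
        return self.proj_out(self.dropout(x))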
@@ -713,7 +724,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # alternatives: 'random', 'auto-regressive'
):
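
denoising_steps (bumped here from 6 to 8) sets how many unmask-and-repredict passes the decoder makes at inference, so each step commits a smaller slice of the masked sub-tokens. One common way to split the work, assumed here rather than taken from the repo, is an even schedule with the remainder spread over the earliest steps:

def tokens_to_reveal_per_step(num_masked: int, denoising_steps: int) -> list[int]:
    # assumed schedule: reveal roughly num_masked / denoising_steps positions per step
    base, extra = divmod(num_masked, denoising_steps)
    return [base + (1 if i < extra else 0) for i in range(denoising_steps)]

# example: 11 masked sub-tokens with denoising_steps=8 -> [2, 2, 2, 1, 1, 1, 1, 1]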
@@ -1091,7 +1102,7 @@ class DiffusionDecoder(SubDecoderClass):
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
class DiffusionDecoder(SubDecoderClass):
class DiffusionDecoderV2(SubDecoderClass):
def __init__(
self,
prediction_order:list,
@@ -1102,7 +1113,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
denoising_steps:int = 8,
eps:float = 1e-3,
method:str = 'low-confidence', # alternatives: 'random', 'auto-regressive'
):
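
The default method 'low-confidence' matches the remasking strategy used in masked-diffusion LMs: at each denoising step, commit the predictions the model is most confident in and leave the rest masked for the next pass. A hedged sketch of that selection for a single sequence follows; the helper and its tensor names are illustrative, not the repo's code:

import torch

def low_confidence_select(logits: torch.Tensor, still_masked: torch.Tensor, keep: int):
    # logits: T x vocab, still_masked: T (bool), keep: number of positions to commit this step
    probs = logits.softmax(dim=-1)
    conf, tokens = probs.max(dim=-1)                        # argmax token and its probability
    conf = conf.masked_fill(~still_masked, float('-inf'))   # only rank positions still masked
    positions = conf.topk(keep).indices
    return tokens[positions], positions                     # token ids to commit and where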
@@ -1129,7 +1140,7 @@ class DiffusionDecoder(SubDecoderClass):
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
self.feature_boost_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
@@ -1138,14 +1149,21 @@ class DiffusionDecoder(SubDecoderClass):
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
self.sub_decoder_layers = nn.Sequential(*[TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
self.sub_decoder_layers = nn.Sequential(TransformerLayerV2(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
self.aux_ar_decoder = SelfAttention(prediction_order=prediction_order,
vocab=vocab,
sub_decoder_depth=1,
dim=dim,
heads=heads,
dropout=dropout,
sub_decoder_enricher_use=False)
# simplified version of the forward process in the diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
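
The hunk cuts off the body of _forward_process, but MASK_IDX and the (masked_indices, p_mask) values returned from forward point to the usual masked-diffusion corruption step: draw a per-sequence masking level and replace that fraction of tokens with the mask id. A sketch under that assumption, not the repo's actual body:

import torch

def forward_process_sketch(input_ids: torch.Tensor, mask_idx: int, eps: float = 1e-3):
    # input_ids: B x T integer tokens; returns noisy ids, bool mask, per-token mask probability
    b, t = input_ids.shape
    level = (1.0 - eps) * torch.rand(b, device=input_ids.device) + eps  # masking level ~ U(eps, 1)
    p_mask = level[:, None].expand(b, t)                                # B x T
    masked_indices = torch.rand(b, t, device=input_ids.device) < p_mask
    noisy_ids = torch.where(masked_indices, mask_idx, input_ids)
    return noisy_ids, masked_indices, p_mask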
@@ -1273,9 +1291,11 @@ class DiffusionDecoder(SubDecoderClass):
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None, aux_ar=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
copy_input_dict = input_dict.copy()
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
@@ -1307,6 +1327,10 @@ class DiffusionDecoder(SubDecoderClass):
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
if aux_ar: # inference with auxiliary auto-regressive decoder
aux_ar_logits, sampled_token_dict = self.aux_ar_decoder(copy_input_dict, sampling_method='auto-regressive', threshold=threshold, temperature=temperature, condition_step=condition_step)
# print("aux_ar_logits", aux_ar_logits)
return aux_ar_logits, sampled_token_dict
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
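
The new aux_ar flag lets inference skip the denoising loop entirely and defer to the single-layer auto-regressive SelfAttention decoder registered in __init__. A conceptual call-site sketch, where the decoder and batch variables are hypothetical:

# assuming `decoder` is a DiffusionDecoderV2 and `batch` carries hidden_vec / bos_token_hidden with target=None
out_diffusion = decoder(batch, sampling_method='low-confidence')  # iterative denoising path
aux_ar_logits, sampled = decoder(batch, aux_ar=True)              # one-pass auto-regressive fallback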
@@ -1420,4 +1444,7 @@ class DiffusionDecoder(SubDecoderClass):
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
# get aux ar decoder logits
aux_ar_logits = self.aux_ar_decoder(copy_input_dict, target) # B x T
return (logits_dict, aux_ar_logits), (masked_indices, p_mask)
# return logits_dict, (masked_indices, p_mask)
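
On the training side the forward now returns both the diffusion logits and the auxiliary AR logits, so the caller presumably combines two losses. Below is a hedged sketch for a single feature's logits; the 1/p_mask weighting and the auxiliary loss weight are assumptions in the style of LLaDA-type objectives, not taken from this commit:

import torch.nn.functional as F

def combined_loss_sketch(diff_logits, ar_logits, targets, masked_indices, p_mask, aux_weight=0.1):
    # diff_logits / ar_logits: B x T x vocab, targets: B x T (long), masked_indices / p_mask: B x T
    ce = F.cross_entropy(diff_logits.transpose(1, 2), targets, reduction='none')  # B x T
    diff_loss = (ce * masked_indices / p_mask).sum() / masked_indices.sum().clamp(min=1)
    ar_loss = F.cross_entropy(ar_logits.transpose(1, 2), targets)  # plain cross-entropy on all positions
    return diff_loss + aux_weight * ar_loss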