1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

@@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 8,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
@@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
print("stacked_logits_probs", stacked_logits_probs.clone())
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (1, int(num_transfer_tokens[:, step]))).to(logit.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=logit.device)
# update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# stop early once all tokens have been unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # masked_indices, p_mask: (B*T) x num_sub_tokens
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask every position, turning training into the fully parallel case
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# every position now holds an embedding (either the mask embedding or the ground-truth token embedding)
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
class DiffusionDecoder(SubDecoderClass):
def __init__(
self,
prediction_order:list,
vocab:LangTokenVocab,
sub_decoder_depth:int,
dim:int,
heads:int,
dropout:float,
sub_decoder_enricher_use:bool,
MASK_IDX:int = 126336,
denoising_steps:int = 6,
eps:float = 1e-3,
method:str = 'low-confidence', # or random or auto-regressive
):
'''
The strength of cross-attention and UniAudio-style self-attention lies in feeding the output of the main decoder (the hidden vec) directly into the sub-decoder.
Because the main decoder's output represents the whole sequence, it carries rich enough information to decode the sub-tokens even in parallel.
Hence both architectures that use the main decoder's output directly outperform the original self-attention sub-decoder.
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
self.sub_decoder_enricher_use = sub_decoder_enricher_use
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
nn.init.zeros_(self.pos_enc.weight)
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the mask token; its index is 126336, which is not in the vocab
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
self.MASK_idx = MASK_IDX
self.denoising_steps = denoising_steps
self.eps = eps
self.method = method
self.input_norm = nn.LayerNorm(dim)
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
causal_mask = generate_SA_mask(len(prediction_order))
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
self.register_buffer('causal_mask', causal_mask)
self.register_buffer('causal_ca_mask', causal_ca_mask)
# get depth of the sub-decoder
if sub_decoder_depth > 1:
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
else:
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
if sub_decoder_enricher_use:
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
# simplified version of the forward process in diffusion model
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
b, l = reshaped_input_ids.shape
t = torch.rand(b, device=input_ids.device)
p_mask = (1 - eps) * t + eps
p_mask = p_mask[:, None].repeat(1, l)
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
# 126336 is used for the [MASK] token; note that this token is not in the vocab
if mask_idx is not None:
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
else:
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # fall back to 126336 as the [MASK] token id
return noisy_batch, masked_indices, p_mask
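# Worked example of the linear schedule above: with eps=1e-3 and a sampled t of 0.5,
# p_mask = (1 - 1e-3) * 0.5 + 1e-3 = 0.5005, so roughly half of the sub-token positions
# are masked; t near 0 masks almost nothing, t near 1 masks almost everything.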
def _apply_window_on_hidden_vec(self, hidden_vec):
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
window_size = 1
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
return new_hidden_vec
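# Shape walkthrough for the default window_size=1: the unfold reduces to a reshape,
# hidden_vec (B x T x d_model) becomes (B*T) x 1 x d_model, and prepending
# enricher_BOS_emb yields the (B*T) x 2 x d_model memory used by the enricher.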
def _apply_pos_enc(self, tgt):
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
return tgt_pos
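# Teacher-forcing helper: embeds the ground-truth tokens of every feature except the last,
# so together with the BOS embedding the memory holds num_sub_tokens entries shifted right
# relative to the prediction targets.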
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
for _, feature in enumerate(self.prediction_order[:-1]):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
return memory_tensor
# return a tensor fully filled with the mask-token embedding (the all-noise state)
def _get_noisy_tensor(self, target_shape):
new_target = torch.zeros(target_shape).to(self.device)
# fill all the elements in the tensor with the embedding of the mask token
new_target[:, :, :] = self.diffusion_mask_emb
return new_target
# prepare the embedding of the target; unlike _prepare_token_embedding_for_teacher_forcing, this embeds every feature in prediction_order
def _prepare_embedding(self, memory_list, target):
for _, feature in enumerate(self.prediction_order):
feature_idx = self.vocab.feature_list.index(feature)
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
memory_list.append(feature_emb_reshape)
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
return memory_tensor
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
memory_list = [] # used for key and value in cross attention
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
if add_BOS:
if target is not None: # training
memory_list.append(BOS_emb)
else: # inference
memory_list.append(BOS_emb[-1:, :, :])
else:
pass
return memory_list
def _get_num_transfer_tokens(self, mask_index, steps):
'''
In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
the expected number of tokens transitioned at each step should be consistent.
This function is designed to precompute the number of tokens that need to be transitioned at each step.
'''
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
for i in range(mask_num.size(0)):
num_transfer_tokens[i, :remainder[i]] += 1
return num_transfer_tokens
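# Illustrative example: 7 masked sub-tokens with steps=3 give base=2 and remainder=1,
# so num_transfer_tokens = [[3, 2, 2]] and the one extra token is unmasked in the first step.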
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
sampled_token_dict = {}
logits_dict = {}
candidate_token_embeddings = {}
candidate_token_probs = {}
b,t,d = hidden_vec.shape # B x T x d_model
# print("*"*8)
logits_list = []
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_list.append(logit)
for idx, feature in enumerate(self.prediction_order):
logit = logits_list[idx] # B x T x vocab_size
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if step==0 and force_decode:
if feature == 'velocity':
sampled_token = torch.tensor([2]).to(logit.device)
prob = torch.tensor([1.0]).to(logit.device)
else:
prob = torch.tensor([0.0]).to(logit.device)
# print(feature, sampled_token, prob)
sampled_token_dict[feature] = sampled_token
logits_dict[feature] = logit
candidate_token_probs[feature] = prob
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
# print("sampled_token_dict", sampled_token_dict)
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
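# forward covers both paths: training (target is given, one masked forward pass over the
# whole sequence) and inference (target is None, iterative denoising over
# self.denoising_steps steps, unmasking num_transfer_tokens positions per step).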
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_features (token ids)
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# indicates which positions are still masked; True (1) means the token has not been unmasked yet
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
# add attribute control here
stored_logits_dict = {}
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising loop
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
# torch.profiler.ProfilerActivity.CUDA],
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)