1029 add octuple
This commit is contained in:
@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 8,
|
||||
denoising_steps:int = 6,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
# print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
elif self.method == 'random':
|
||||
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
|
||||
elif self.method == 'auto-regressive':
|
||||
indices = torch.tensor([[step]], device=logit.device)
|
||||
# undate the masked history
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
|
||||
# skip if all tokens are unmasked
|
||||
if not expand_masked_history.any():
|
||||
# print("All tokens have been unmasked. Ending denoising process.")
|
||||
break
|
||||
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
|
||||
# get final sampled tokens by embedding the unmasked tokens
|
||||
sampled_token_dict = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
# print("Final sampled tokens:")
|
||||
# print(sampled_token_dict)
|
||||
# print(condition_step)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask all ,turn into parallel
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# all is embedding
|
||||
# memory_tensor = self.layer_norm(memory_tensor)
|
||||
# apply feature enricher to memory
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# implement sub decoder cross attention
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
class DiffusionDecoder(SubDecoderClass):
|
||||
def __init__(
|
||||
self,
|
||||
prediction_order:list,
|
||||
vocab:LangTokenVocab,
|
||||
sub_decoder_depth:int,
|
||||
dim:int,
|
||||
heads:int,
|
||||
dropout:float,
|
||||
sub_decoder_enricher_use:bool,
|
||||
MASK_IDX:int = 126336,
|
||||
denoising_steps:int = 6,
|
||||
eps:float = 1e-3,
|
||||
method:str = 'low-confidence', # or random or auto-regressive
|
||||
):
|
||||
'''
|
||||
The power of Cross-attention and UniAudio style Self-attention lies in that using the output of the main decoder or hidden vec directly in the sub-decoder
|
||||
As the output of the main decoder is the representation of the whole sequence,
|
||||
it contains richer information which can even decode out sub-tokens in a parallel manner
|
||||
So both architectures using the output of the main decoder in a direct way show better performance than the original self-attention sub-decoder
|
||||
'''
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
self.sub_decoder_enricher_use = sub_decoder_enricher_use
|
||||
self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
|
||||
|
||||
self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
|
||||
nn.init.zeros_(self.pos_enc.weight)
|
||||
|
||||
self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
|
||||
self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of mask token,idx is 126336,which is not in vocab
|
||||
nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
|
||||
self.MASK_idx = MASK_IDX
|
||||
self.denoising_steps = denoising_steps
|
||||
self.eps = eps
|
||||
self.method = method
|
||||
|
||||
self.input_norm = nn.LayerNorm(dim)
|
||||
|
||||
self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
if sub_decoder_enricher_use:
|
||||
self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
|
||||
causal_mask = generate_SA_mask(len(prediction_order))
|
||||
causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
|
||||
self.register_buffer('causal_mask', causal_mask)
|
||||
self.register_buffer('causal_ca_mask', causal_ca_mask)
|
||||
|
||||
# get depth of the sub-decoder
|
||||
if sub_decoder_depth > 1:
|
||||
self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
|
||||
else:
|
||||
self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
|
||||
if sub_decoder_enricher_use:
|
||||
self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
|
||||
|
||||
|
||||
# simplified version of the forward process in diffusion model
|
||||
def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
|
||||
reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
|
||||
b, l = reshaped_input_ids.shape
|
||||
t = torch.rand(b, device=input_ids.device)
|
||||
p_mask = (1 - eps) * t + eps
|
||||
p_mask = p_mask[:, None].repeat(1, l)
|
||||
|
||||
masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
|
||||
# 126336 is used for [MASK] token,attention that this token is not in the vocab
|
||||
if mask_idx is not None:
|
||||
noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
|
||||
else:
|
||||
noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids)# 126336 is used for [MASK] token in
|
||||
return noisy_batch, masked_indices, p_mask
|
||||
|
||||
|
||||
def _apply_window_on_hidden_vec(self, hidden_vec):
|
||||
BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
|
||||
# through our experiments, we found that the size of the window doesn't affect the performance of the model much
|
||||
window_size = 1
|
||||
zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
|
||||
cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
|
||||
new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
|
||||
new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
|
||||
new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
|
||||
return new_hidden_vec
|
||||
|
||||
def _apply_pos_enc(self, tgt):
|
||||
pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
|
||||
pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
|
||||
tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
|
||||
return tgt_pos
|
||||
|
||||
def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
|
||||
for _, feature in enumerate(self.prediction_order[:-1]):
|
||||
feature_idx = self.vocab.feature_list.index(feature)
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
|
||||
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
|
||||
memory_list.append(feature_emb_reshape)
|
||||
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
|
||||
return memory_tensor
|
||||
|
||||
# return a tensor
|
||||
def _get_noisy_tensor(self, target_shape):
|
||||
new_target = torch.zeros(target_shape).to(self.device)
|
||||
# fill all the elements in the tensor with the embedding of the mask token
|
||||
new_target[:, :, :] = self.diffusion_mask_emb
|
||||
return new_target
|
||||
|
||||
# prepare the embedding of the target,
|
||||
def _prepare_embedding(self, memory_list, target):
|
||||
for _, feature in enumerate(self.prediction_order):
|
||||
feature_idx = self.vocab.feature_list.index(feature)
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
|
||||
feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
|
||||
memory_list.append(feature_emb_reshape)
|
||||
memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
|
||||
return memory_tensor
|
||||
|
||||
|
||||
def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
|
||||
memory_list = [] # used for key and value in cross attention
|
||||
BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
|
||||
if add_BOS is true:
|
||||
if target is not None: # training
|
||||
memory_list.append(BOS_emb)
|
||||
else: # inference
|
||||
memory_list.append(BOS_emb[-1:, :, :])
|
||||
else:
|
||||
pass
|
||||
return memory_list
|
||||
|
||||
def _get_num_transfer_tokens(self, mask_index, steps):
|
||||
'''
|
||||
In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
|
||||
Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
|
||||
the expected number of tokens transitioned at each step should be consistent.
|
||||
|
||||
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
||||
'''
|
||||
mask_num = mask_index.sum(dim=1,keepdim=True)
|
||||
base = mask_num // steps
|
||||
remainder = mask_num % steps
|
||||
|
||||
num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
|
||||
|
||||
for i in range(mask_num.size(0)):
|
||||
num_transfer_tokens[i, :remainder[i]] += 1
|
||||
|
||||
return num_transfer_tokens
|
||||
|
||||
def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
|
||||
sampled_token_dict = {}
|
||||
logits_dict = {}
|
||||
candidate_token_embeddings = {}
|
||||
candidate_token_probs = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
# print("*"*8)
|
||||
logits_list = []
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_list.append(logit)
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
logit = logits_list[idx] # B x T x vocab_siz
|
||||
sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
if step==0 and force_decode:
|
||||
if feature == 'velocity':
|
||||
sampled_token = torch.tensor([2]).to(logit.device)
|
||||
prob = torch.tensor([1.0]).to(logit.device)
|
||||
else:
|
||||
prob = torch.tensor([0.0]).to(logit.device)
|
||||
# print(feature, sampled_token, prob)
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
logits_dict[feature] = logit
|
||||
candidate_token_probs[feature] = prob
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
|
||||
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
|
||||
candidate_token_embeddings[feature] = feature_emb_reshape
|
||||
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens x vocab_size
|
||||
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
|
||||
# print("sampled_token_dict", sampled_token_dict)
|
||||
return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
if bos_hidden_vec is None: # start of generation
|
||||
if target is None:
|
||||
bos_hidden_vec = input_seq_pos
|
||||
else:
|
||||
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
|
||||
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
|
||||
|
||||
else:
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# input_seq_pos = input_seq
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
|
||||
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# indicate the position of the mask token,1 means that the token hsa been masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
# add attribute control here
|
||||
stored_logits_dict = {}
|
||||
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
|
||||
if condition_step is not None:
|
||||
# print("shape of condition_step", condition_step.shape)
|
||||
condition_step = condition_step.reshape((b*t, l))
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
token = condition_step[i][j]
|
||||
if condition_step[i][j] != self.MASK_idx:
|
||||
|
||||
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
|
||||
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
|
||||
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
# with torch.profiler.profile(
|
||||
# activities=[
|
||||
# torch.profiler.ProfilerActivity.CPU,
|
||||
# torch.profiler.ProfilerActivity.CUDA],
|
||||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
# with_stack=True
|
||||
# ) as prof:
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
|
||||
force_decode=Force_decode,
|
||||
step=step)
|
||||
|
||||
# print("step", step)
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
# print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user