1021 add flexible attr control
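This commit threads a new `condition_step` argument through every sub-decoder `forward` and, in `DiffusionDecoder`, uses it to pin chosen sub-token attributes so the denoising loop only fills in the rest. A minimal caller-side sketch, assuming `condition_step` is a `(B, T, num_sub_tokens)` tensor of token ids in which `MASK_idx` means "leave this attribute to the model" (the mask id, feature names, and shape below are illustrative assumptions, not values taken from this repository):

import torch

MASK_idx = 0                                                   # assumed mask token id
prediction_order = ['type', 'pitch', 'duration', 'velocity']   # hypothetical sub-token order
B, T = 2, 8

# Start fully unconstrained, then pin a single attribute everywhere.
condition_step = torch.full((B, T, len(prediction_order)), MASK_idx, dtype=torch.long)
condition_step[:, :, prediction_order.index('velocity')] = 42  # force velocity to token id 42

# Passed as sub_decoder.forward(input_dict, ..., condition_step=condition_step);
# positions still equal to MASK_idx are filled in by the denoising loop.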
@@ -1,3 +1,4 @@
from re import T
from selectors import EpollSelector
from turtle import st
from numpy import indices
@@ -6,7 +7,7 @@ import torch
import torch.profiler
import torch.nn as nn

from x_transformers import Decoder
from .custom_x_transformers import Decoder

from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
from .sub_decoder_utils import *
@@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
})

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
@@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
@@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
return memory_tensor

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub_tokens
@@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
memory_tensor = hidden_vec_reshape + feature_tensor
return memory_tensor

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub_tokens
@@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
memory_list.append(BOS_emb[-1:, :, :])
return memory_list

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target']
@@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
):
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
hidden_vec = input_dict['hidden_vec']

# ---- Training ---- #
@@ -838,7 +839,7 @@ class DiffusionDecoder(SubDecoderClass):

This function is designed to precompute the number of tokens that need to be transitioned at each step.
'''
mask_num = mask_index.sum(dim=1, keepdim=True)
mask_num = mask_index.sum(dim=1,keepdim=True)
base = mask_num // steps
remainder = mask_num % steps
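For intuition, a self-contained sketch of how `base` and `remainder` can be turned into per-step transfer counts. Only `mask_num`, `base`, and `remainder` appear in the hunk above, so the tail that spreads the remainder over the first steps is an assumption rather than the repository's exact code:

import torch

def get_num_transfer_tokens(mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    # mask_index: (B, L) bool, True where a sub-token is still masked
    mask_num = mask_index.sum(dim=1, keepdim=True)  # tokens left to unmask
    base = mask_num // steps                        # even share per denoising step
    remainder = mask_num % steps                    # leftover tokens
    counts = base.repeat(1, steps)                  # (B, steps)
    for b in range(counts.shape[0]):
        counts[b, :int(remainder[b])] += 1          # assumed: front-load the remainder
    return counts

# e.g. 7 masked sub-tokens over 3 steps -> tensor([[3, 2, 2]])
print(get_num_transfer_tokens(torch.ones(1, 7, dtype=torch.bool), 3))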
@@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
indices = torch.tensor([[step]], device=hidden_vec.device)
return indices


def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model


# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = input_seq
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model

# indicate the position of the mask token, 1 means that the token has been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
stored_logits_dict = {}
stored_probs_dict = {}
for step in range(self.denoising_steps):
# normalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
# input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)

# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
# breakpoint()
# update the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
# breakpoint()
# print("stored_probs_dict", stored_probs_dict)
# print("sampled_token_dict", sampled_token_dict)
return stored_logits_dict, sampled_token_dict

# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm

extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask all, turn into parallel
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)

def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
@@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):

# indicate the position of the mask token, 1 means that the token has been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
# add attribute control here
stored_logits_dict = {}
stored_probs_dict = {}
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# normalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}
candidate_token_embeddings = {}
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
# print(idx,feature,sampled_token,probs)
sampled_token_dict[feature] = sampled_token
candidate_token_probs[feature] = probs
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape

stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens x vocab_size
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))

# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)

if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=logit.device)
# update the masked history
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
if j in indices[i]:
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
return stored_logits_dict, sampled_token_dict

# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm

extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask all, turn into parallel
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# all is embedding
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
# inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
# inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)

def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:

# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model

if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)

else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model

# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")

# indicate the position of the mask token, 1 means that the token has been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
stored_logits_dict = {}
stored_probs_dict = {}
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
@@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# normalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
@@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}

sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)

# print("step", step)
# print("tokens", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
print("stacked_logits_probs", stacked_logits_probs.clone())

if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
@@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# skip if all tokens are unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict

# ---- Training ---- #

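Read together, the `condition_step` lines added across the hunks above (the reshape to `(b*t, l)`, the `MASK_idx` check, and the `get_emb_by_key` write) amount to roughly the standalone sketch below. The function signature and loop bounds are assumptions for illustration; `get_emb_by_key`, `MASK_idx`, and the pinned-embedding writes come from the diff:

import torch

def apply_condition_step(condition_step, masked_history, memory_tensor,
                         stored_token_embeddings, emb_layer, prediction_order, mask_idx):
    # condition_step: (B*T, num_sub_tokens) token ids; mask_idx marks unconstrained positions
    bt, l = masked_history.shape
    condition_step = condition_step.reshape((bt, l))
    for i in range(bt):
        for j in range(l):
            token = condition_step[i][j]
            if token != mask_idx:
                masked_history[i][j] = False  # treat the pinned attribute as already denoised
                memory_tensor[i][j][:] = emb_layer.get_emb_by_key(prediction_order[j], token)
                stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
    return masked_history, memory_tensor, stored_token_embeddings

Because pinned positions are set to False in `masked_history`, the later `torch.where(masked_history, stacked_logits_probs, -torch.inf)` excludes them from top-k selection, and `torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)` keeps their embeddings fixed across denoising steps.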