1021 add flexable attr control
This commit is contained in:
@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
|
||||
|
||||
return last_ckpt_fn, config_path, metadata_path, vocab_path
|
||||
|
||||
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
|
||||
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
|
||||
# if config is a path, load it
|
||||
if isinstance(config, (str, Path)):
|
||||
from omegaconf import OmegaConf
|
||||
config = OmegaConf.load(config)
|
||||
config = wandb_style_config_to_omega_config(config)
|
||||
|
||||
nn_params = config.nn_params
|
||||
for_evaluation = True
|
||||
if condition_dataset is not None:
|
||||
print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
|
||||
config.dataset = condition_dataset
|
||||
for_evaluation = False
|
||||
dataset_name = config.dataset
|
||||
vocab_path = Path(vocab_path)
|
||||
|
||||
@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
|
||||
input_length=config.train_params.input_length,
|
||||
first_pred_feature=config.data_params.first_pred_feature,
|
||||
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
|
||||
for_evaluation=True,
|
||||
for_evaluation=for_evaluation
|
||||
)
|
||||
|
||||
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
|
||||
@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
|
||||
split_ratio = config.data_params.split_ratio
|
||||
# test_set = []
|
||||
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
|
||||
|
||||
# get proper prediction order according to the encoding scheme and target feature in the config
|
||||
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
|
||||
|
||||
@ -480,6 +490,28 @@ class Evaluator:
|
||||
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
|
||||
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
|
||||
|
||||
def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0,generation_length=3072, attr_list=None):
|
||||
encoding_scheme = self.config.nn_params.encoding_scheme
|
||||
|
||||
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
|
||||
try:
|
||||
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
|
||||
except KeyError:
|
||||
in_beat_resolution = 4 # Default resolution if dataset is not found
|
||||
|
||||
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
|
||||
decoder_name = midi_decoder_dict[encoding_scheme]
|
||||
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
|
||||
|
||||
tuneidx = tuneidx.cuda()
|
||||
generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
|
||||
if encoding_scheme == 'nb':
|
||||
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
|
||||
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
|
||||
|
||||
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
|
||||
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
|
||||
|
||||
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
|
||||
encoding_scheme = self.config.nn_params.encoding_scheme
|
||||
|
||||
|
||||
@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
'''
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
# self.prediction_order = net.prediction_order
|
||||
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
|
||||
self.attribute2idx_after = {'pitch': 0,
|
||||
'duration': 1,
|
||||
'velocity': 2,
|
||||
'type': 3,
|
||||
'beat': 4,
|
||||
'chord': 5,
|
||||
'tempo': 6,
|
||||
'instrument': 7}
|
||||
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
|
||||
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
|
||||
return self.net(input_seq, target, context=context)
|
||||
|
||||
@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
|
||||
return total_out
|
||||
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None,condition_step=None):
|
||||
'''
|
||||
Runs one step of autoregressive generation by taking the input sequence, embedding it,
|
||||
passing it through the main decoder, and generating logits and a sampled token.
|
||||
@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
|
||||
|
||||
# Generate the next token
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
|
||||
return logits, sampled_token, intermidiates, hidden_vec
|
||||
|
||||
def _update_total_out(self, total_out, sampled_token):
|
||||
@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
return total_out, sampled_token
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
|
||||
'''
|
||||
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
|
||||
until the desired maximum sequence length is reached or the end token is encountered.
|
||||
@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
- total_out: The generated sequence of tokens as a tensor.
|
||||
'''
|
||||
# Prepare the starting sequence for inference
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
|
||||
|
||||
# If a condition is provided, run one initial step
|
||||
if condition is not None:
|
||||
_, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
|
||||
if attr_list is None:
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
|
||||
else:
|
||||
cache = LayerIntermediates()
|
||||
|
||||
# Continue generating tokens until the maximum sequence length is reached
|
||||
total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
|
||||
# for attribute-controlled generation, only keep the specified attributes in condition, others set to 126336
|
||||
condition_filtered = condition.clone().unsqueeze(0)
|
||||
# print(self.attribute2idx)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
if attr not in attr_list:
|
||||
condition_filtered[:, :, idx] = 126336
|
||||
# rearange condition_filtered to match prediction order
|
||||
|
||||
cache = LayerIntermediates()
|
||||
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
|
||||
bos_hidden_vec = None
|
||||
hidden_vec_list = []
|
||||
@ -261,7 +275,21 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
input_tensor = total_out[:, -self.net.input_length:]
|
||||
# Generate the next token and update the cache
|
||||
time_start = time.time()
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
# if attr_list is not None, get one token in condition_filtered each time step
|
||||
if attr_list is not None:
|
||||
condition_filtered = condition_filtered.to(self.net.device)
|
||||
# print(condition_filtered[:,:20,:])
|
||||
# print(condition_filtered.shape)
|
||||
condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
|
||||
# rearange order, 0 to 5, 1 to 6, 2 to 7, 3 to 0, 4 to 1, 5 to 2, 6 to 3, 7 to 4
|
||||
condition_step_rearranged = torch.zeros_like(condition_step)
|
||||
for attr, idx in self.attribute2idx.items():
|
||||
new_idx = self.attribute2idx_after[attr]
|
||||
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
|
||||
# print("condition_step shape:", condition_step.shape)
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
|
||||
else:
|
||||
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
|
||||
time_end = time.time()
|
||||
token_time_list.append(time_end - time_start)
|
||||
if bos_hidden_vec is None:
|
||||
@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
|
||||
return self.decoder(input_seq, target, context=context)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None,attr_list=None):
|
||||
if batch_size == 1:
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
|
||||
else:
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
|
||||
|
||||
class AmadeusModel4Encodec(AmadeusModel):
|
||||
def __init__(
|
||||
|
||||
@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
|
||||
scores = logits.masked_fill(indices_to_remove, float("-inf"))
|
||||
return scores
|
||||
|
||||
def min_p_sampling(logits, alpha=0.05):
|
||||
"""
|
||||
logits: Tensor of shape [B, L, V]
|
||||
alpha: float, relative probability threshold (e.g., 0.05)
|
||||
"""
|
||||
# 计算 softmax 概率
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
|
||||
# 找到每个位置的最大概率
|
||||
max_probs, _ = probs.max(dim=-1, keepdim=True) # [B, L, 1]
|
||||
|
||||
# 保留概率 >= alpha * max_prob 的 token
|
||||
mask = probs < (alpha * max_probs) # True 表示要屏蔽
|
||||
masked_logits = logits.masked_fill(mask, float('-inf'))
|
||||
return masked_logits
|
||||
|
||||
def add_gumbel_noise(logits, temperature):
|
||||
'''
|
||||
The Gumbel max is a method for sampling categorical distributions.
|
||||
@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
|
||||
modified_logits = typical_sampling(logits, thres=threshold)
|
||||
elif sampling_method == "eta":
|
||||
modified_logits = eta_sampling(logits, epsilon=threshold)
|
||||
elif sampling_method == "min_p":
|
||||
modified_logits = min_p_sampling(logits, alpha=threshold)
|
||||
else:
|
||||
modified_logits = logits # 其他情况直接使用原始logits
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from re import T
|
||||
from selectors import EpollSelector
|
||||
from turtle import st
|
||||
from numpy import indices
|
||||
@ -6,7 +7,7 @@ import torch
|
||||
import torch.profiler
|
||||
import torch.nn as nn
|
||||
|
||||
from x_transformers import Decoder
|
||||
from .custom_x_transformers import Decoder
|
||||
|
||||
from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
|
||||
from .sub_decoder_utils import *
|
||||
@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
|
||||
f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
|
||||
})
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
target = input_dict['target']
|
||||
@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
|
||||
'''
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
target = input_dict['target']
|
||||
@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
|
||||
memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
|
||||
return memory_tensor
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] # B x T x num_sub_tokens
|
||||
@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
|
||||
memory_tensor = hidden_vec_reshape + feature_tensor
|
||||
return memory_tensor
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] # B x T x num_sub-tokens
|
||||
@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
|
||||
memory_list.append(BOS_emb[-1:, :, :])
|
||||
return memory_list
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target']
|
||||
@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
|
||||
):
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
|
||||
# ---- Training ---- #
|
||||
@ -838,7 +839,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
|
||||
This function is designed to precompute the number of tokens that need to be transitioned at each step.
|
||||
'''
|
||||
mask_num = mask_index.sum(dim=1, keepdim=True)
|
||||
mask_num = mask_index.sum(dim=1,keepdim=True)
|
||||
base = mask_num // steps
|
||||
remainder = mask_num % steps
|
||||
|
||||
@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
indices = torch.tensor([[step]], device=hidden_vec.device)
|
||||
return indices
|
||||
|
||||
|
||||
def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = input_seq
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# indicate the position of the mask token,1 means that the token hsa been masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
for step in range(self.denoising_steps):
|
||||
# nomalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
# input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
# indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
|
||||
indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
|
||||
# breakpoint()
|
||||
# undate the masked history
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
# breakpoint()
|
||||
# print("stored_probs_dict", stored_probs_dict)
|
||||
# print("sampled_token_dict", sampled_token_dict)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask all ,turn into parallel
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
|
||||
# indicate the position of the mask token,1 means that the token hsa been masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
# add attribute control here
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# nomalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
candidate_token_embeddings = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
# print(idx,feature,sampled_token,probs)
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
candidate_token_probs[feature] = probs
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
|
||||
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
|
||||
candidate_token_embeddings[feature] = feature_emb_reshape
|
||||
|
||||
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens x vocab_size
|
||||
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))
|
||||
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
elif self.method == 'random':
|
||||
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
|
||||
elif self.method == 'auto-regressive':
|
||||
indices = torch.tensor([[step]], device=logit.device)
|
||||
# undate the masked history
|
||||
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
|
||||
if condition_step is not None:
|
||||
# print("shape of condition_step", condition_step.shape)
|
||||
condition_step = condition_step.reshape((b*t, l))
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask all ,turn into parallel
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# all is embedding
|
||||
# memory_tensor = self.layer_norm(memory_tensor)
|
||||
# apply feature enricher to memory
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# implement sub decoder cross attention
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
# inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
|
||||
# inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||||
token = condition_step[i][j]
|
||||
if condition_step[i][j] != self.MASK_idx:
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
if bos_hidden_vec is None: # start of generation
|
||||
if target is None:
|
||||
bos_hidden_vec = input_seq_pos
|
||||
else:
|
||||
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
|
||||
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
|
||||
|
||||
else:
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# input_seq_pos = input_seq
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
|
||||
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||||
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
|
||||
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
|
||||
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
|
||||
# indicate the position of the mask token,1 means that the token hsa been masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
# with torch.profiler.profile(
|
||||
# activities=[
|
||||
# torch.profiler.ProfilerActivity.CPU,
|
||||
@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
# ) as prof:
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# nomalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
|
||||
force_decode=Force_decode,
|
||||
step=step)
|
||||
|
||||
# print("step", step)
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
|
||||
# skip if all tokens are unmasked
|
||||
if not expand_masked_history.any():
|
||||
# print("All tokens have been unmasked. Ending denoising process.")
|
||||
break
|
||||
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
|
||||
# get final sampled tokens by embedding the unmasked tokens
|
||||
sampled_token_dict = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
# print("Final sampled tokens:")
|
||||
# print(sampled_token_dict)
|
||||
# print(condition_step)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
|
||||
@ -510,6 +510,7 @@ class Melody(SymbolicMusicDataset):
|
||||
shuffled_tune_names = list(self.tune_in_idx.keys())
|
||||
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
|
||||
song_dict = {}
|
||||
ratio = 0.8
|
||||
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
|
||||
if song not in song_dict:
|
||||
song_dict[song] = []
|
||||
|
||||
@ -2,8 +2,8 @@ defaults:
|
||||
# - nn_params: nb8_embSum_NMT
|
||||
# - nn_params: remi8
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
- nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
# - nn_params: nb8_embSum_subPararell
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M
|
||||
|
||||
@ -15,7 +15,7 @@ defaults:
|
||||
# - nn_params: remi8_main12_head_16_dim512
|
||||
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
|
||||
|
||||
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
captions_path: dataset/midicaps/train_set.json
|
||||
|
||||
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
|
||||
@ -44,7 +44,7 @@ train_params:
|
||||
focal_gamma: 0
|
||||
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
|
||||
scheduler : cosinelr
|
||||
initial_lr: 0.0004
|
||||
initial_lr: 0.0003
|
||||
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
|
||||
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
|
||||
warmup_steps: 2000 #number of warmup steps
|
||||
@ -59,7 +59,7 @@ inference_params:
|
||||
data_params:
|
||||
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
|
||||
split_ratio: 0.998 # train-validation-test split ratio
|
||||
aug_type: pitch # random, null | pitch and chord augmentation type
|
||||
aug_type: null # random, null | pitch and chord augmentation type
|
||||
general:
|
||||
debug: False
|
||||
make_log: True # True, False | update the log file in wandb online to your designated project and entity
|
||||
|
||||
@ -74,7 +74,8 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
self.model = model
|
||||
|
||||
@ -111,6 +111,23 @@ class MultiEmbedding(nn.Module):
|
||||
def get_emb_by_key(self, key, token):
|
||||
layer_idx = self.feature_list.index(key)
|
||||
return self.layers[layer_idx](token)
|
||||
|
||||
def get_token_by_emb(self, key, token_emb):
|
||||
'''
|
||||
token_emb: B x emb_size
|
||||
'''
|
||||
layer_idx = self.feature_list.index(key)
|
||||
embedding_layer = self.layers[layer_idx] # nn.Embedding
|
||||
# compute cosine similarity between token_emb and embedding weights
|
||||
emb_weights = embedding_layer.weight # vocab_size x emb_size
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
token_emb.unsqueeze(1), # B x 1 x emb_size
|
||||
emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
|
||||
dim=-1
|
||||
) # B x vocab_size
|
||||
# get the index of the most similar embedding
|
||||
token_idx = torch.argmax(cos_sim, dim=-1) # B
|
||||
return token_idx
|
||||
|
||||
class SummationEmbedder(MultiEmbedding):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user