1021 add flexible attr control
@@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
    return last_ckpt_fn, config_path, metadata_path, vocab_path

def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
    # if config is a path, load it
    if isinstance(config, (str, Path)):
        from omegaconf import OmegaConf
        config = OmegaConf.load(config)
        config = wandb_style_config_to_omega_config(config)

    nn_params = config.nn_params
    for_evaluation = True
    if condition_dataset is not None:
        print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
        config.dataset = condition_dataset
        for_evaluation = False
    dataset_name = config.dataset
    vocab_path = Path(vocab_path)

@@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
        input_length=config.train_params.input_length,
        first_pred_feature=config.data_params.first_pred_feature,
        caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
        for_evaluation=True,
        for_evaluation=for_evaluation
    )

    vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
@@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
    split_ratio = config.data_params.split_ratio
    # test_set = []
    train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)

    # get proper prediction order according to the encoding scheme and target feature in the config
    prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
@@ -480,6 +490,28 @@ class Evaluator:
        prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

    def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0, generation_length=3072, attr_list=None):
        encoding_scheme = self.config.nn_params.encoding_scheme

        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
        try:
            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
        except KeyError:
            in_beat_resolution = 4  # Default resolution if dataset is not found

        midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
        decoder_name = midi_decoder_dict[encoding_scheme]
        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)

        tuneidx = tuneidx.cuda()
        generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))

        prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

    def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
        encoding_scheme = self.config.nn_params.encoding_scheme
@@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
        '''
        super().__init__()
        self.net = net

        # self.prediction_order = net.prediction_order
        # self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
        self.attribute2idx_after = {'pitch': 0,
                                    'duration': 1,
                                    'velocity': 2,
                                    'type': 3,
                                    'beat': 4,
                                    'chord': 5,
                                    'tempo': 6,
                                    'instrument': 7}
        self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}

    def forward(self, input_seq:torch.Tensor, target:torch.Tensor, context=None):
        return self.net(input_seq, target, context=context)
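The two dictionaries above hold the same eight note attributes in two orders: attribute2idx follows the order of the sub-tokens in the token tensors, while attribute2idx_after appears to follow the order used on the sub-decoder side. A standalone sketch of the remapping this implies (not repository code; dummy token values), mirroring the loop used later in generate():

import torch

attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
attribute2idx_after = {'pitch': 0, 'duration': 1, 'velocity': 2, 'type': 3,
                       'beat': 4, 'chord': 5, 'tempo': 6, 'instrument': 7}

condition_step = torch.randint(0, 100, (1, 1, 8))   # B x 1 x num_sub_tokens, dummy token ids
rearranged = torch.zeros_like(condition_step)
for attr, src in attribute2idx.items():
    rearranged[:, :, attribute2idx_after[attr]] = condition_step[:, :, src]
# e.g. 'pitch' moves from slot 5 to slot 0 and 'type' from slot 0 to slot 3,
# matching the rearrangement applied to condition_step in generate() below.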
@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
|
||||
return total_out
|
||||
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
|
||||
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None,condition_step=None):
|
||||
'''
|
||||
Runs one step of autoregressive generation by taking the input sequence, embedding it,
|
||||
passing it through the main decoder, and generating logits and a sampled token.
|
||||
@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
|
||||
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
|
||||
|
||||
# Generate the next token
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
|
||||
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
|
||||
return logits, sampled_token, intermidiates, hidden_vec
|
||||
|
||||
def _update_total_out(self, total_out, sampled_token):
|
||||
@@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
return total_out, sampled_token

@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
@@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
if attr_list is None:
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)

# If a condition is provided, run one initial step
if condition is not None:
_, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
# for attribute-controlled generation, only keep the specified attributes in condition, others set to 126336
condition_filtered = condition.clone().unsqueeze(0)
# print(self.attribute2idx)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
condition_filtered[:, :, idx] = 126336
# rearrange condition_filtered to match prediction order
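To make the masking above concrete, a small illustrative sketch (not repository code) of what the attr_list filter does to a reference condition tensor; 126336 appears to play the role of the MASK_idx checked later in the diffusion sub-decoder:

import torch

attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
MASK = 126336                      # same sentinel value as above
attr_list = ['beat', 'duration']   # attributes to keep from the reference tune

condition = torch.randint(0, 100, (1, 16, 8))   # B x T x num_sub_tokens, dummy reference tokens
condition_filtered = condition.clone()
for attr, idx in attribute2idx.items():
    if attr not in attr_list:
        condition_filtered[:, :, idx] = MASK
# only columns 1 (beat) and 6 (duration) keep their original tokens; every other
# attribute is masked and will be re-sampled freely during generation.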
# Continue generating tokens until the maximum sequence length is reached
cache = LayerIntermediates()
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
bos_hidden_vec = None
hidden_vec_list = []
@@ -261,6 +275,20 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
time_start = time.time()
# if attr_list is not None, take one token from condition_filtered at each time step
if attr_list is not None:
condition_filtered = condition_filtered.to(self.net.device)
# print(condition_filtered[:,:20,:])
# print(condition_filtered.shape)
condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
# rearrange order: 0 to 5, 1 to 6, 2 to 7, 3 to 0, 4 to 1, 5 to 2, 6 to 3, 7 to 4
condition_step_rearranged = torch.zeros_like(condition_step)
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
# print("condition_step shape:", condition_step.shape)
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
token_time_list.append(time_end - time_start)
@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
|
||||
return self.decoder(input_seq, target, context=context)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
|
||||
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None,attr_list=None):
|
||||
if batch_size == 1:
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
|
||||
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
|
||||
else:
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
|
||||
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
|
||||
|
||||
class AmadeusModel4Encodec(AmadeusModel):
|
||||
def __init__(
|
||||
|
||||
@@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
    scores = logits.masked_fill(indices_to_remove, float("-inf"))
    return scores

def min_p_sampling(logits, alpha=0.05):
    """
    logits: Tensor of shape [B, L, V]
    alpha: float, relative probability threshold (e.g., 0.05)
    """
    # compute softmax probabilities
    probs = F.softmax(logits, dim=-1)

    # find the maximum probability at each position
    max_probs, _ = probs.max(dim=-1, keepdim=True)  # [B, L, 1]

    # keep tokens with probability >= alpha * max_prob
    mask = probs < (alpha * max_probs)  # True marks tokens to be masked out
    masked_logits = logits.masked_fill(mask, float('-inf'))
    return masked_logits
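For orientation, a quick toy check of the filter above (not from the repository; it only shows which logits survive a given alpha):

import torch
import torch.nn.functional as F

logits = torch.tensor([[[2.0, 1.8, 0.0, -4.0]]])    # B x L x V toy logits
kept = min_p_sampling(logits, alpha=0.05)            # entries with p < 0.05 * p_max -> -inf
print(kept)                                          # only the last logit gets masked here
probs = F.softmax(kept, dim=-1)
next_token = torch.multinomial(probs.view(-1, probs.size(-1)), 1)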

def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max is a method for sampling categorical distributions.
@@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
        modified_logits = typical_sampling(logits, thres=threshold)
    elif sampling_method == "eta":
        modified_logits = eta_sampling(logits, epsilon=threshold)
    elif sampling_method == "min_p":
        modified_logits = min_p_sampling(logits, alpha=threshold)
    else:
        modified_logits = logits  # otherwise use the original logits unchanged
@ -1,3 +1,4 @@
|
||||
from re import T
|
||||
from selectors import EpollSelector
|
||||
from turtle import st
|
||||
from numpy import indices
|
||||
@ -6,7 +7,7 @@ import torch
|
||||
import torch.profiler
|
||||
import torch.nn as nn
|
||||
|
||||
from x_transformers import Decoder
|
||||
from .custom_x_transformers import Decoder
|
||||
|
||||
from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
|
||||
from .sub_decoder_utils import *
|
||||
@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
|
||||
f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
|
||||
})
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
target = input_dict['target']
|
||||
@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
|
||||
'''
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
target = input_dict['target']
|
||||
@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
|
||||
memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
|
||||
return memory_tensor
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] # B x T x num_sub_tokens
|
||||
@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
|
||||
memory_tensor = hidden_vec_reshape + feature_tensor
|
||||
return memory_tensor
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] # B x T x num_sub-tokens
|
||||
@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
|
||||
memory_list.append(BOS_emb[-1:, :, :])
|
||||
return memory_list
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target']
|
||||
@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
|
||||
):
|
||||
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
|
||||
hidden_vec = input_dict['hidden_vec']
|
||||
|
||||
# ---- Training ---- #
|
||||
@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
indices = torch.tensor([[step]], device=hidden_vec.device)
|
||||
return indices
|
||||
|
||||
|
||||
def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = input_seq
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# indicates the positions of the mask tokens; 1 means the token is still masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
for step in range(self.denoising_steps):
|
||||
# normalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
# input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
# indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
|
||||
indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
|
||||
# breakpoint()
|
||||
# update the masked history
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
# breakpoint()
|
||||
# print("stored_probs_dict", stored_probs_dict)
|
||||
# print("sampled_token_dict", sampled_token_dict)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask everything, turning this into the parallel decoder
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
|
||||
# indicates the positions of the mask tokens; 1 means the token is still masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
# add attribute control here
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# normalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
candidate_token_embeddings = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
|
||||
# print(idx,feature,sampled_token,probs)
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
candidate_token_probs[feature] = probs
|
||||
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
|
||||
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
|
||||
candidate_token_embeddings[feature] = feature_emb_reshape
|
||||
|
||||
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens x vocab_size
|
||||
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))
|
||||
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
elif self.method == 'random':
|
||||
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
|
||||
elif self.method == 'auto-regressive':
|
||||
indices = torch.tensor([[step]], device=logit.device)
|
||||
# update the masked history
|
||||
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
|
||||
if condition_step is not None:
|
||||
# print("shape of condition_step", condition_step.shape)
|
||||
condition_step = condition_step.reshape((b*t, l))
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
token = condition_step[i][j]
|
||||
if condition_step[i][j] != self.MASK_idx:
|
||||
|
||||
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
|
||||
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
|
||||
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
|
||||
# ---- Training ---- #
|
||||
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
|
||||
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
|
||||
# apply layer norm
|
||||
|
||||
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
|
||||
if worst_case: # mask everything, turning this into the parallel decoder
|
||||
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
|
||||
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# everything in memory_tensor is an embedding at this point
|
||||
# memory_tensor = self.layer_norm(memory_tensor)
|
||||
# apply feature enricher to memory
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# implement sub decoder cross attention
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
# inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
|
||||
# inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
|
||||
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# get prob
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
feature_pos = self.feature_order_in_output[feature]
|
||||
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
|
||||
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
|
||||
logits_dict[feature] = logit
|
||||
return logits_dict, (masked_indices, p_mask)
|
||||
|
||||
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
|
||||
logits_dict = {}
|
||||
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
|
||||
target = input_dict['target'] #B x T x d_model
|
||||
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
|
||||
|
||||
# apply window on hidden_vec for enricher
|
||||
if self.sub_decoder_enricher_use:
|
||||
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
|
||||
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
|
||||
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
if bos_hidden_vec is None: # start of generation
|
||||
if target is None:
|
||||
bos_hidden_vec = input_seq_pos
|
||||
else:
|
||||
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
|
||||
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
|
||||
|
||||
else:
|
||||
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# input_seq_pos = input_seq
|
||||
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
|
||||
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
|
||||
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
|
||||
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
|
||||
# prepare memory
|
||||
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
|
||||
# ---- Generate(Inference) ---- #
|
||||
if target is None:
|
||||
sampled_token_dict = {}
|
||||
b,t,d = hidden_vec.shape # B x T x d_model
|
||||
l = len(self.prediction_order) # num_sub_tokens
|
||||
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
|
||||
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
|
||||
|
||||
# indicates the positions of the mask tokens; 1 means the token is still masked
|
||||
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
|
||||
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
|
||||
# denoising c
|
||||
stored_logits_dict = {}
|
||||
stored_probs_dict = {}
|
||||
# with torch.profiler.profile(
|
||||
# activities=[
|
||||
# torch.profiler.ProfilerActivity.CPU,
|
||||
@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
# ) as prof:
|
||||
for step in range(self.denoising_steps):
|
||||
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
# normalize the memory tensor
|
||||
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
|
||||
if self.sub_decoder_enricher_use:
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
|
||||
input_dict = self.feature_enricher_layers(input_dict)
|
||||
@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
|
||||
input_dict = self.sub_decoder_layers(input_dict)
|
||||
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
|
||||
candidate_token_probs = {}
|
||||
|
||||
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
|
||||
force_decode=Force_decode,
|
||||
step=step)
|
||||
|
||||
# print("step", step)
|
||||
# print("toknes", sampled_token_dict)
|
||||
# set prob of the changed tokens to -inf
|
||||
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
|
||||
print("stacked_logits_probs", stacked_logits_probs.clone())
|
||||
|
||||
if self.method == 'low-confidence':
|
||||
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
|
||||
@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
|
||||
for i in range(b*t):
|
||||
for j in range(l):
|
||||
if j in indices[i]:
|
||||
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
|
||||
masked_history[i][j] = False
|
||||
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
|
||||
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
|
||||
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
|
||||
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
|
||||
# skip if all tokens are unmasked
|
||||
if not expand_masked_history.any():
|
||||
# print("All tokens have been unmasked. Ending denoising process.")
|
||||
break
|
||||
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
|
||||
# recover the final sampled tokens from the unmasked token embeddings
|
||||
sampled_token_dict = {}
|
||||
for idx, feature in enumerate(self.prediction_order):
|
||||
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
|
||||
sampled_token_dict[feature] = sampled_token
|
||||
# print("Final sampled tokens:")
|
||||
# print(sampled_token_dict)
|
||||
# print(condition_step)
|
||||
return stored_logits_dict, sampled_token_dict
|
||||
|
||||
# ---- Training ---- #
|
||||
|
||||
@ -510,6 +510,7 @@ class Melody(SymbolicMusicDataset):
|
||||
shuffled_tune_names = list(self.tune_in_idx.keys())
|
||||
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
|
||||
song_dict = {}
|
||||
ratio = 0.8
|
||||
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
|
||||
if song not in song_dict:
|
||||
song_dict[song] = []
|
||||
|
||||
@ -2,8 +2,8 @@ defaults:
|
||||
# - nn_params: nb8_embSum_NMT
|
||||
# - nn_params: remi8
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
- nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
|
||||
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
|
||||
# - nn_params: nb8_embSum_subPararell
|
||||
# - nn_params: nb8_embSum_diff_t2m_150M
|
||||
|
||||
@ -15,7 +15,7 @@ defaults:
|
||||
# - nn_params: remi8_main12_head_16_dim512
|
||||
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
|
||||
|
||||
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
|
||||
captions_path: dataset/midicaps/train_set.json
|
||||
|
||||
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
|
||||
@ -44,7 +44,7 @@ train_params:
|
||||
focal_gamma: 0
|
||||
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
|
||||
scheduler : cosinelr
|
||||
initial_lr: 0.0004
|
||||
initial_lr: 0.0003
|
||||
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
|
||||
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
|
||||
warmup_steps: 2000 #number of warmup steps
|
||||
@ -59,7 +59,7 @@ inference_params:
|
||||
data_params:
|
||||
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
|
||||
split_ratio: 0.998 # train-validation-test split ratio
|
||||
aug_type: pitch # random, null | pitch and chord augmentation type
|
||||
aug_type: null # random, null | pitch and chord augmentation type
|
||||
general:
|
||||
debug: False
|
||||
make_log: True # True, False | update the log file in wandb online to your designated project and entity
|
||||
|
||||
@ -74,7 +74,8 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
|
||||
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
self.model = model
|
||||
|
||||
@@ -112,6 +112,23 @@ class MultiEmbedding(nn.Module):
        layer_idx = self.feature_list.index(key)
        return self.layers[layer_idx](token)

    def get_token_by_emb(self, key, token_emb):
        '''
        token_emb: B x emb_size
        '''
        layer_idx = self.feature_list.index(key)
        embedding_layer = self.layers[layer_idx]  # nn.Embedding
        # compute cosine similarity between token_emb and embedding weights
        emb_weights = embedding_layer.weight  # vocab_size x emb_size
        cos_sim = torch.nn.functional.cosine_similarity(
            token_emb.unsqueeze(1),   # B x 1 x emb_size
            emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
            dim=-1
        )  # B x vocab_size
        # get the index of the most similar embedding
        token_idx = torch.argmax(cos_sim, dim=-1)  # B
        return token_idx
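For intuition, a standalone round-trip sketch of this nearest-embedding lookup (toy sizes, a plain nn.Embedding standing in for one feature's table; not repository code):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=16, embedding_dim=8)   # one feature's embedding table
token = torch.tensor([5, 9])                              # B
token_emb = emb(token)                                    # B x emb_size
cos_sim = torch.nn.functional.cosine_similarity(
    token_emb.unsqueeze(1), emb.weight.unsqueeze(0), dim=-1)
recovered = cos_sim.argmax(dim=-1)                        # equals token when rows are distinct
print(recovered)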

class SummationEmbedder(MultiEmbedding):
    def __init__(
        self,
@ -168,7 +168,7 @@ class CorpusMaker():
|
||||
print("length of midi list: ", len(self.midi_list))
|
||||
# Use set for faster lookup (O(1) per check)
|
||||
processed_files_set = set(processed_files)
|
||||
# self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
|
||||
self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
|
||||
# reverse the list to process the latest files first
|
||||
self.midi_list.reverse()
|
||||
print(f"length of midi list after filtering: ", len(self.midi_list))
|
||||
|
||||
@ -61,7 +61,7 @@ class Corpus2Event():
|
||||
# remove the corpus files that are already in the out_dir
|
||||
# Use set for faster existence checks
|
||||
existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
|
||||
# corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
|
||||
corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
|
||||
for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
|
||||
if event is None:
|
||||
broken_count += 1
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from ast import arg
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -25,14 +26,25 @@ def get_argument_parser():
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned'),
        choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-attr_list",
        type=str,
        default="beat,duration",
        help="attribute list for attribute-controlled generation",
    )
    parser.add_argument(
        "-dataset",
        type=str,
        help="dataset name, only for conditioned generation",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        choices=('top_p', 'top_k', 'min_p'),
        default='top_p',
        help="sampling method",
    )
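With these additions, an attribute-conditioned run would be parsed roughly as follows (a sketch only: flag names follow the parser above and the args.* fields used by the workers, the experiment directory is a placeholder, and it assumes no other argument is required):

parser = get_argument_parser()
args = parser.parse_args([
    "-wandb_exp_dir", "run-YYYYMMDD_HHMMSS-xxxxxxxx",   # placeholder experiment directory
    "-generation_type", "attr-conditioned",
    "-attr_list", "beat,duration",
    "-sampling_method", "min_p",
    "-dataset", "msmidi",
])
attr_list = args.attr_list.split(',')   # ['beat', 'duration'], as consumed by attr_conditioned_worker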
@ -74,7 +86,7 @@ def get_argument_parser():
|
||||
parser.add_argument(
|
||||
"-num_processes",
|
||||
type=int,
|
||||
default=4,
|
||||
default=1,
|
||||
help="number of processes to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -97,7 +109,7 @@ def get_argument_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def load_resources(wandb_exp_dir, device):
|
||||
def load_resources(wandb_exp_dir, condition_dataset, device):
|
||||
"""Load model and dataset resources for a process"""
|
||||
wandb_dir = Path('wandb')
|
||||
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
|
||||
@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, device):
|
||||
# Load checkpoint to specified device
|
||||
print("Loading checkpoint from:", ckpt_path)
|
||||
ckpt = torch.load(ckpt_path, map_location=device)
|
||||
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
|
||||
print(config)
|
||||
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
|
||||
model.load_state_dict(ckpt['model'], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, device):
|
||||
|
||||
return config, model, dataset_for_prompt, vocab
|
||||
|
||||
def conditioned_worker(process_idx, gpu_id, args, data_slice):
|
||||
def conditioned_worker(process_idx, gpu_id, args):
|
||||
"""Worker process for conditioned generation"""
|
||||
torch.cuda.set_device(gpu_id)
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
# print(test_set)
|
||||
if args.choose_selected_tunes and test_set.dataset == 'SOD':
|
||||
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
|
||||
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
|
||||
else:
|
||||
selected_tunes = [name for _, name in test_set][:args.num_samples]
|
||||
|
||||
# Split selected data across processes
|
||||
selected_data = [d for d in test_set if d[1] in selected_tunes]
|
||||
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
|
||||
start_idx = 1
|
||||
end_idx = min(chunk_size, len(selected_data))
|
||||
data_slice = selected_data[start_idx:end_idx]
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
|
||||
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
|
||||
|
||||
# Process assigned data slice
|
||||
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
|
||||
@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
|
||||
generation_length=args.generate_length
|
||||
)
|
||||
|
||||
def attr_conditioned_worker(process_idx, gpu_id, args):
    """Worker process for attribute-conditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')
    # attr_list = "position,duration"
    attr_list = args.attr_list.split(',')

    # Load resources with proper device
    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
    # print(test_set)
    if args.choose_selected_tunes and test_set.dataset == 'SOD':
        selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
                          "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
    else:
        selected_tunes = [name for _, name in test_set][:args.num_samples]

    # Split selected data across processes
    selected_data = [d for d in test_set if d[1] in selected_tunes]
    # chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
    # start_idx = 1
    # end_idx = min(chunk_size, len(selected_data))
    # data_slice = selected_data[start_idx:end_idx]
    data_slice = selected_data

    # Create output directory with process index
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, data_slice, vocab, device=device)

    # Process assigned data slice
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        batch_dir = base_path
        batch_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_attrCtl(
            batch_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
            attr_list=attr_list
        )

def unconditioned_worker(process_idx, gpu_id, args, num_samples):
|
||||
"""Worker process for unconditioned generation"""
|
||||
torch.cuda.set_device(gpu_id)
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
@ -237,36 +312,29 @@ def main():
|
||||
if not wandb_dir.exists():
|
||||
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
|
||||
|
||||
# Load test set to get selected tunes (dummy load to get dataset info)
|
||||
dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
_, test_set, _ = prepare_model_and_dataset_from_config(
|
||||
wandb_dir / "files" / "config.yaml",
|
||||
wandb_dir / "files" / "metadata.json",
|
||||
wandb_dir / "files" / "vocab.json"
|
||||
)
|
||||
|
||||
if args.choose_selected_tunes and test_set.dataset == 'SOD':
|
||||
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
|
||||
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
|
||||
else:
|
||||
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
|
||||
|
||||
# Split selected data across processes
|
||||
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
|
||||
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
|
||||
|
||||
for i in range(args.num_processes):
|
||||
start_idx = i * chunk_size
|
||||
end_idx = min((i+1)*chunk_size, len(selected_data))
|
||||
data_slice = selected_data[start_idx:end_idx]
|
||||
|
||||
if not data_slice:
|
||||
continue
|
||||
|
||||
gpu_id = gpu_ids[i % len(gpu_ids)]
|
||||
p = Process(
|
||||
target=conditioned_worker,
|
||||
args=(i, gpu_id, args, data_slice)
|
||||
args=(i, gpu_id, args)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
elif args.generation_type == 'attr-conditioned':
|
||||
# Prepare selected tunes
|
||||
wandb_dir = Path('wandb') / args.wandb_exp_dir
|
||||
if not wandb_dir.exists():
|
||||
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
|
||||
|
||||
|
||||
for i in range(args.num_processes):
|
||||
|
||||
gpu_id = gpu_ids[i % len(gpu_ids)]
|
||||
p = Process(
|
||||
target=attr_conditioned_worker,
|
||||
args=(i, gpu_id, args)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
BIN len_tunes/Melody/len_nb8.png (new binary file, 19 KiB)
BIN len_tunes/msmidi/len_nb8.png (new binary file, 19 KiB)

midi_sim.py (new file, 134 lines)
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
from math import ceil
|
||||
#CUDA_VISIBLE_DEVICES= "0"
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from symusic import Score
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
|
||||
|
||||
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (0., 1.5)):
|
||||
if(not a.shape[1] or not b.shape[1]):
|
||||
return np.inf
|
||||
a_onset, a_pitch = a
|
||||
b_onset, b_pitch = b
|
||||
a_onset = a_onset.astype(np.float32)
|
||||
b_onset = b_onset.astype(np.float32)
|
||||
a_pitch = a_pitch.astype(np.int16)
|
||||
b_pitch = b_pitch.astype(np.int16)
|
||||
|
||||
onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
|
||||
a2b_idx = onset_dist_matrix.argmin(1)
|
||||
b2a_idx = onset_dist_matrix.argmin(0)
|
||||
|
||||
a_pitch -= (np.median(a_pitch) - np.median(b_pitch)).astype(np.int16) # Normalize pitch
|
||||
a_pitch = a_pitch + np.arange(-7, 7).reshape(-1, 1) # Transpose invariant
|
||||
|
||||
interval_diff = np.concatenate([
|
||||
a_pitch[:, a2b_idx] - b_pitch,
|
||||
b_pitch[b2a_idx] - a_pitch], axis=1)
|
||||
pitch_dist = np.abs(semitone2degree[interval_diff % 8] + np.abs(interval_diff) // 8 * np.sign(interval_diff)).mean(1).min()
|
||||
onset_dist = np.abs(np.concatenate([
|
||||
a_onset[a2b_idx] - b_onset,
|
||||
b_onset[b2a_idx] - a_onset], axis=0)).mean()
|
||||
|
||||
return (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)
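As a quick sanity check of the distance above (not part of the committed file; toy arrays, where the second line is the first transposed up a fifth):

import numpy as np

a = np.array([[0.0, 1.0, 2.0, 3.0],    # onsets in quarter notes
              [60,  62,  64,  65]])    # MIDI pitches
b = np.array([[0.0, 1.0, 2.0, 3.0],
              [67,  69,  71,  72]])
print(hausdorff_dist(a, b))            # ~0.0: the pitch term is transposition-tolerant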
|
||||
|
||||
|
||||
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 8., hop_size: float = 4.):
|
||||
x = sorted(x)
|
||||
trim_offset = (x[0][0] // hop_size) * hop_size
|
||||
end_time = x[-1][0]
|
||||
num_segment = ceil((end_time - window_size - trim_offset) / hop_size) + 1
|
||||
|
||||
time_matrix = (np.fromiter((time for time, _ in x), dtype=float) - trim_offset).reshape(1, -1).repeat(num_segment, axis=0)
|
||||
seg_time_starts = np.arange(num_segment).reshape(-1, 1) * hop_size
|
||||
|
||||
time_compare_matrix = np.where((time_matrix >= seg_time_starts) & (time_matrix <= seg_time_starts + window_size), 0, 1)
|
||||
time_compare_matrix = np.diff(np.pad(time_compare_matrix, ((0, 0), (1, 1)), constant_values=1))
|
||||
start_idxs = sorted(np.where(time_compare_matrix == -1), key=lambda x: x[0])[1].tolist()
|
||||
end_idxs = sorted(np.where(time_compare_matrix == 1), key=lambda x: x[0])[1].tolist()
|
||||
|
||||
segments = [x[start:end] for start, end in zip(start_idxs, end_idxs)]
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
|
||||
a = midi_time_sliding_window(a, window_size=window_size, hop_size=hop_size)
|
||||
b = midi_time_sliding_window(b, window_size=window_size, hop_size=hop_size)
|
||||
dist = np.inf
|
||||
for x,i in enumerate(a):
|
||||
for y,j in enumerate(b):
|
||||
cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
|
||||
if cur_dist == 0:
|
||||
print(x, y)
|
||||
if(cur_dist < dist):
|
||||
dist = cur_dist
|
||||
return float(dist)
|
||||
|
||||
|
||||
def extract_notes(filepath: str):
|
||||
"""读取MIDI并返回 (time, pitch) 列表"""
|
||||
try:
|
||||
s = Score(filepath).to("quarter")
|
||||
notes = []
|
||||
# for t in s.tracks:
|
||||
# notes.extend([(n.time, n.pitch) for n in t.notes])
|
||||
notes = [(n.time, n.pitch) for n in s.tracks[0].notes] # only use the first track
|
||||
return notes
|
||||
except Exception as e:
|
||||
print(f"读取 {filepath} 出错: {e}")
|
||||
return []
|
||||
|
||||
def compare_pair(file_a: str, file_b: str):
|
||||
try:
|
||||
notes_a = extract_notes(file_a)
|
||||
notes_b = extract_notes(file_b)
|
||||
if not notes_a or not notes_b:
|
||||
return (file_a, file_b, np.inf)
|
||||
dist = midi_dist(notes_a, notes_b)
|
||||
return (file_a, file_b, dist)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"⚠️ compare_pair 出错: {file_a} vs {file_b}")
|
||||
traceback.print_exc()
|
||||
return (file_a, file_b, np.inf)
|
||||
|
||||
|
||||
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
|
||||
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
|
||||
files_a = files_a[:100] # only compare the first 100 files to save time
|
||||
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
|
||||
|
||||
results = []
|
||||
pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
|
||||
for fut in as_completed(futures):
|
||||
pbar.update(1)
|
||||
try:
|
||||
results.append(fut.result())
|
||||
except Exception as e:
|
||||
print(fut.result())
|
||||
print(f"Error comparing pair: {e}")
|
||||
# print(f"Compared: {results[-1][0]} vs {results[-1][1]}, Distance: {results[-1][2]:.4f}")
|
||||
# with tqdm(total=len(files_a) * len(files_b)) as pbar:
|
||||
# for fa in files_a:
|
||||
# for fb in files_b:
|
||||
# results.append(compare_pair(fa, fb))
|
||||
# pbar.update(1)
|
||||
# # sort by distance
|
||||
results = sorted(results, key=lambda x: x[2])
|
||||
|
||||
# save results
|
||||
df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
|
||||
df.to_csv(out_csv, index=False)
|
||||
print(f"已保存结果到 {out_csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
|
||||
dir_b = "dataset/Melody"
|
||||
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
|
||||
,idi_sim.py (deleted file, 105 lines)
@@ -1,105 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from symusic import Score
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
|
||||
|
||||
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
|
||||
if(not a.shape[1] or not b.shape[1]):
|
||||
return np.inf
|
||||
a_onset, a_pitch = a
|
||||
b_onset, b_pitch = b
|
||||
a_onset = a_onset.astype(np.float32)
|
||||
b_onset = b_onset.astype(np.float32)
|
||||
a_pitch = a_pitch.astype(np.uint8)
|
||||
b_pitch = b_pitch.astype(np.uint8)
|
||||
|
||||
onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
|
||||
if(oti):
|
||||
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
|
||||
dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
|
||||
a2b = dist_matrix.min(2)
|
||||
b2a = dist_matrix.min(1)
|
||||
dist = np.concatenate([a2b, b2a], axis=1)
|
||||
return dist.sum(axis=1).min() / len(dist)
|
||||
else:
|
||||
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
|
||||
dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
|
||||
a2b = dist_matrix.min(1)
|
||||
b2a = dist_matrix.min(0)
|
||||
return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))
|
||||
|
||||
|
||||
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
|
||||
x = sorted(x)
|
||||
end_time = x[-1][0]
|
||||
out = [[] for _ in range(int(end_time // hop_size))]
|
||||
for i in sorted(x):
|
||||
segment = min(int(i[0] // hop_size), len(out) - 1)
|
||||
while(i[0] >= segment * hop_size):
|
||||
out[segment].append(i)
|
||||
segment -= 1
|
||||
if(segment < 0):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
|
||||
a = midi_time_sliding_window(a)
|
||||
b = midi_time_sliding_window(b)
|
||||
dist = np.inf
|
||||
for i in a:
|
||||
for j in b:
|
||||
cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
|
||||
if(cur_dist < dist):
|
||||
dist = cur_dist
|
||||
return dist
|
||||
|
||||
|
||||
def extract_notes(filepath: str):
|
||||
"""读取MIDI并返回 (time, pitch) 列表"""
|
||||
try:
|
||||
s = Score(filepath).to("quarter")
|
||||
notes = []
|
||||
for t in s.tracks:
|
||||
notes.extend([(n.time, n.pitch) for n in t.notes])
|
||||
return notes
|
||||
except Exception as e:
|
||||
print(f"读取 {filepath} 出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def compare_pair(file_a: str, file_b: str):
|
||||
notes_a = extract_notes(file_a)
|
||||
notes_b = extract_notes(file_b)
|
||||
if not notes_a or not notes_b:
|
||||
return (file_a, file_b, np.inf)
|
||||
dist = midi_dist(notes_a, notes_b)
|
||||
return (file_a, file_b, dist)
|
||||
|
||||
|
||||
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
|
||||
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
|
||||
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
|
||||
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
|
||||
for fut in as_completed(futures):
|
||||
results.append(fut.result())
|
||||
|
||||
# sort by distance
|
||||
results = sorted(results, key=lambda x: x[2])
|
||||
|
||||
# save results
|
||||
df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
|
||||
df.to_csv(out_csv, index=False)
|
||||
print(f"已保存结果到 {out_csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dir_a = "folder_a"
|
||||
dir_b = "folder_b"
|
||||
batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)
|
||||