1021 add flexible attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
     return last_ckpt_fn, config_path, metadata_path, vocab_path

-def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path: str, vocab_path: str):
+def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path: str, vocab_path: str, condition_dataset: str = None):
+    # if config is a path, load it
+    if isinstance(config, (str, Path)):
+        from omegaconf import OmegaConf
+        config = OmegaConf.load(config)
+        config = wandb_style_config_to_omega_config(config)
     nn_params = config.nn_params
+    for_evaluation = True
+    if condition_dataset is not None:
+        print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
+        config.dataset = condition_dataset
+        for_evaluation = False
     dataset_name = config.dataset
     vocab_path = Path(vocab_path)
@@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
         input_length=config.train_params.input_length,
         first_pred_feature=config.data_params.first_pred_feature,
         caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
-        for_evaluation=True,
+        for_evaluation=for_evaluation
     )
     vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
@@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
     split_ratio = config.data_params.split_ratio
     # test_set = []
     train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
     # get proper prediction order according to the encoding scheme and target feature in the config
     prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
@@ -480,6 +490,28 @@ class Evaluator:
         prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
         decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

+    def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0, generation_length=3072, attr_list=None):
+        encoding_scheme = self.config.nn_params.encoding_scheme
+        in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
+        try:
+            in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
+        except KeyError:
+            in_beat_resolution = 4  # Default resolution if dataset is not found
+        midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
+        decoder_name = midi_decoder_dict[encoding_scheme]
+        decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
+        tuneidx = tuneidx.cuda()
+        generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
+        if encoding_scheme == 'nb':
+            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
+        decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
+        prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
+        decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
+
     def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
         encoding_scheme = self.config.nn_params.encoding_scheme
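
For reference (annotation, not part of the commit diff): with the condition_dataset hook above, the loader can be pointed at a saved run directly. A minimal sketch, assuming the helper is importable from this repo's evaluation utilities; the run directory and dataset name are placeholders:

from pathlib import Path

run_dir = Path("wandb/run-XXXX/files")   # hypothetical run directory
model, test_set, vocab = prepare_model_and_dataset_from_config(
    run_dir / "config.yaml",             # a plain config path is now accepted
    str(run_dir / "metadata.json"),
    str(run_dir / "vocab.json"),
    condition_dataset="SOD",             # overrides config.dataset and turns off for_evaluation
)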

View File

@@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
         '''
         super().__init__()
         self.net = net
+        # self.prediction_order = net.prediction_order
+        # self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
+        self.attribute2idx_after = {'pitch': 0,
+                                    'duration': 1,
+                                    'velocity': 2,
+                                    'type': 3,
+                                    'beat': 4,
+                                    'chord': 5,
+                                    'tempo': 6,
+                                    'instrument': 7}
+        self.attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3, 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}

     def forward(self, input_seq: torch.Tensor, target: torch.Tensor, context=None):
         return self.net(input_seq, target, context=context)
@@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
         total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
         return total_out

-    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None, context=None):
+    def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None, context=None, condition_step=None):
         '''
         Runs one step of autoregressive generation by taking the input sequence, embedding it,
         passing it through the main decoder, and generating logits and a sampled token.
@@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
         input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}

         # Generate the next token
-        logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
+        logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
         return logits, sampled_token, intermidiates, hidden_vec

     def _update_total_out(self, total_out, sampled_token):
@@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
         return total_out, sampled_token

     @torch.inference_mode()
-    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
+    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
         '''
         Autoregressively generates a sequence of tokens by repeatedly sampling the next token
         until the desired maximum sequence length is reached or the end token is encountered.
@@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
        - total_out: The generated sequence of tokens as a tensor.
        '''
         # Prepare the starting sequence for inference
-        total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
-        # If a condition is provided, run one initial step
-        if condition is not None:
-            _, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
-        else:
-            cache = LayerIntermediates()
-        # Continue generating tokens until the maximum sequence length is reached
+        if attr_list is None:
+            total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
+        else:
+            total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
+            # for attribute-controlled generation, keep only the specified attributes in condition; others are set to 126336
+            condition_filtered = condition.clone().unsqueeze(0)
+            # print(self.attribute2idx)
+            for attr, idx in self.attribute2idx.items():
+                if attr not in attr_list:
+                    condition_filtered[:, :, idx] = 126336
+            # rearrange condition_filtered to match prediction order
+        cache = LayerIntermediates()
         pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
         bos_hidden_vec = None
         hidden_vec_list = []
@@ -261,7 +275,21 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
             input_tensor = total_out[:, -self.net.input_length:]
             # Generate the next token and update the cache
             time_start = time.time()
-            _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
+            # if attr_list is not None, take one token from condition_filtered at each time step
+            if attr_list is not None:
+                condition_filtered = condition_filtered.to(self.net.device)
+                # print(condition_filtered[:,:20,:])
+                # print(condition_filtered.shape)
+                condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
+                # rearrange order: 0 to 5, 1 to 6, 2 to 7, 3 to 0, 4 to 1, 5 to 2, 6 to 3, 7 to 4
+                condition_step_rearranged = torch.zeros_like(condition_step)
+                for attr, idx in self.attribute2idx.items():
+                    new_idx = self.attribute2idx_after[attr]
+                    condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
+                # print("condition_step shape:", condition_step.shape)
+                _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
+            else:
+                _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature, bos_hidden_vec=bos_hidden_vec, context=context)
             time_end = time.time()
             token_time_list.append(time_end - time_start)
             if bos_hidden_vec is None:
@@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
         return self.decoder(input_seq, target, context=context)

     @torch.inference_mode()
-    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
+    def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
         if batch_size == 1:
-            return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
+            return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
         else:
-            return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
+            return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)

 class AmadeusModel4Encodec(AmadeusModel):
     def __init__(
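
For reference (annotation, not part of the commit diff): a self-contained sketch of the attribute filter introduced above. The index map mirrors attribute2idx, 126336 is the mask id used by this commit, and the tensor sizes are toy values:

import torch

attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3,
                 'instrument': 4, 'pitch': 5, 'duration': 6, 'velocity': 7}
MASK_ID = 126336  # sentinel meaning "not conditioned; the model fills this slot"

def filter_condition(condition, attr_list):
    # condition: B x T x 8 grid of sub-tokens; keep controlled attributes, mask the rest
    filtered = condition.clone()
    for attr, idx in attribute2idx.items():
        if attr not in attr_list:
            filtered[:, :, idx] = MASK_ID
    return filtered

cond = torch.randint(0, 100, (1, 4, 8))
print(filter_condition(cond, ['beat', 'duration']))  # only columns 1 and 6 keep their values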

View File

@@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
     scores = logits.masked_fill(indices_to_remove, float("-inf"))
     return scores

+def min_p_sampling(logits, alpha=0.05):
+    """
+    logits: Tensor of shape [B, L, V]
+    alpha: float, relative probability threshold (e.g., 0.05)
+    """
+    # compute the softmax probabilities
+    probs = F.softmax(logits, dim=-1)
+    # find the maximum probability at each position
+    max_probs, _ = probs.max(dim=-1, keepdim=True)  # [B, L, 1]
+    # keep only tokens with probability >= alpha * max_prob
+    mask = probs < (alpha * max_probs)  # True marks tokens to mask out
+    masked_logits = logits.masked_fill(mask, float('-inf'))
+    return masked_logits
+
 def add_gumbel_noise(logits, temperature):
     '''
     The Gumbel max is a method for sampling categorical distributions.
@@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
         modified_logits = typical_sampling(logits, thres=threshold)
     elif sampling_method == "eta":
         modified_logits = eta_sampling(logits, epsilon=threshold)
+    elif sampling_method == "min_p":
+        modified_logits = min_p_sampling(logits, alpha=threshold)
     else:
         modified_logits = logits  # otherwise, fall back to the raw logits
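
For reference (annotation, not part of the commit diff): a hedged usage sketch of the new min-p filter, assuming min_p_sampling is imported from the module above; the multinomial draw stands in for the repo's own sample_with_prob path and the sizes are toy values:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 1, 128)                          # B x L x V (toy sizes)
filtered = min_p_sampling(logits, alpha=0.1)             # drop tokens with p < 0.1 * max p
probs = F.softmax(filtered, dim=-1)
next_tokens = torch.multinomial(probs.view(-1, 128), 1)  # one draw per position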

View File

@@ -1,3 +1,4 @@
+from re import T
 from selectors import EpollSelector
 from turtle import st
 from numpy import indices
@@ -6,7 +7,7 @@ import torch
 import torch.profiler
 import torch.nn as nn
-from x_transformers import Decoder
+from .custom_x_transformers import Decoder
 from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
 from .sub_decoder_utils import *
@@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
             f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
         })

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']
         target = input_dict['target']
@@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
         '''
         super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']
         target = input_dict['target']
@@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
         memory_tensor = torch.cat(input_seq_list, dim=1)  # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
         return memory_tensor

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']  # B x T x d_model
         target = input_dict['target']  # B x T x num_sub_tokens
@@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
         memory_tensor = hidden_vec_reshape + feature_tensor
         return memory_tensor

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']  # B x T x d_model
         target = input_dict['target']  # B x T x num_sub-tokens
@@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
         memory_list.append(BOS_emb[-1:, :, :])
         return memory_list

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']  # B x T x d_model
         target = input_dict['target']
@@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
         ):
         super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)

-    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
         hidden_vec = input_dict['hidden_vec']

         # ---- Training ---- #
@@ -838,7 +839,7 @@ class DiffusionDecoder(SubDecoderClass):
         This function is designed to precompute the number of tokens that need to be transitioned at each step.
         '''
-        mask_num = mask_index.sum(dim=1, keepdim=True)
+        mask_num = mask_index.sum(dim=1,keepdim=True)
         base = mask_num // steps
         remainder = mask_num % steps
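
For reference (annotation, not part of the commit diff): the schedule computed here splits the masked-token count evenly across denoising steps, spreading the remainder over the earliest steps. A self-contained sketch under that reading:

import torch

def num_transfer_tokens_sketch(mask_index, steps):
    # mask_index: N x L boolean grid of still-masked sub-tokens
    mask_num = mask_index.sum(dim=1, keepdim=True)  # masked tokens per row
    base = mask_num // steps                        # unmasked at every step
    remainder = mask_num % steps                    # one extra for the first few steps
    schedule = base.repeat(1, steps)
    for i in range(schedule.shape[0]):
        schedule[i, : int(remainder[i])] += 1
    return schedule

print(num_transfer_tokens_sketch(torch.ones(1, 8, dtype=torch.bool), steps=3))  # tensor([[3, 3, 2]])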
@@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
             indices = torch.tensor([[step]], device=hidden_vec.device)
         return indices

-    def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
-        logits_dict = {}
-        hidden_vec = input_dict['hidden_vec']  # B x T x d_model
-        target = input_dict['target']  # B x T x d_model
-        # apply window on hidden_vec for enricher
-        if self.sub_decoder_enricher_use:
-            window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec)  # (B*T) x window_size x d_model
-        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
-        input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model
-        input_seq_pos = input_seq
-        # input_seq_pos = self._apply_pos_enc(input_seq)  # (B*T) x num_sub_tokens x d_model
-        # prepare memory
-        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
-        # ---- Generate(Inference) ---- #
-        if target is None:
-            sampled_token_dict = {}
-            b,t,d = hidden_vec.shape  # B x T x d_model
-            l = len(self.prediction_order)  # num_sub_tokens
-            memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
-            all_noise_tensor = memory_tensor.clone()  # (B*T) x num_sub_tokens x d_model
-            # indicate the position of the mask token,1 means that the token hsa been masked
-            masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
-            num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
-            # denoising c
-            stored_logits_dict = {}
-            stored_probs_dict = {}
-            for step in range(self.denoising_steps):
-                # nomalize the memory tensor
-                # memory_tensor = self.layer_norm(memory_tensor)  # (B*T) x num_sub_tokens x d_model
-                if self.sub_decoder_enricher_use:
-                    input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-                    input_dict = self.feature_enricher_layers(input_dict)
-                    memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-                input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-                # input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-                input_dict = self.sub_decoder_layers(input_dict)
-                attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-                candidate_token_probs = {}
-                sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
-                # set prob of the changed tokens to -inf
-                stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-                # indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
-                indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
-                # breakpoint()
-                # undate the masked history
-                for i in range(b*t):
-                    for j in range(l):
-                        if j in indices[i]:
-                            masked_history[i][j] = False
-                            stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
-                            stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
-                expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x num_sub_tokens x d_model
-                memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
-                # breakpoint()
-            # print("stored_probs_dict", stored_probs_dict)
-            # print("sampled_token_dict", sampled_token_dict)
-            return stored_logits_dict, sampled_token_dict
-        # ---- Training ---- #
-        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx)  # (B*T) x (num_sub_tokens) x d_model
-        memory_tensor = self._prepare_embedding(memory_list, target)  # (B*T) x (num_sub_tokens) x d_model
-        # apply layer norm
-        extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x (num_sub_tokens) x d_model
-        if worst_case:  # mask all ,turn into parallel
-            extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
-        memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
-        if self.sub_decoder_enricher_use:
-            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-            input_dict = self.feature_enricher_layers(input_dict)
-            memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-        input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-        input_dict = self.sub_decoder_layers(input_dict)
-        attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-        # get prob
-        for idx, feature in enumerate(self.prediction_order):
-            feature_pos = self.feature_order_in_output[feature]
-            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
-            logits_dict[feature] = logit
-        return logits_dict, (masked_indices, p_mask)
-
-    def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
         logits_dict = {}
         hidden_vec = input_dict['hidden_vec']  # B x T x d_model
         target = input_dict['target']  # B x T x d_model
@@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):
            # indicate the position of the mask token, 1 means that the token has been masked
            masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
-           num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
-           # denoising c
-           stored_logits_dict = {}
-           stored_probs_dict = {}
-           for step in range(self.denoising_steps):
-               memory_tensor = self._apply_pos_enc(memory_tensor)  # (B*T) x num_sub_tokens x d_model
-               # nomalize the memory tensor
-               # memory_tensor = self.layer_norm(memory_tensor)  # (B*T) x num_sub_tokens x d_model
-               if self.sub_decoder_enricher_use:
-                   input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-                   input_dict = self.feature_enricher_layers(input_dict)
-                   memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-               # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-               input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-               input_dict = self.sub_decoder_layers(input_dict)
-               attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-               candidate_token_probs = {}
-               candidate_token_embeddings = {}
-               for idx, feature in enumerate(self.prediction_order):
-                   feature_pos = self.feature_order_in_output[feature]
-                   logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-                   logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
-                   logits_dict[feature] = logit
-                   sampled_token, probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
-                   # print(idx,feature,sampled_token,probs)
-                   sampled_token_dict[feature] = sampled_token
-                   candidate_token_probs[feature] = probs
-                   feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
-                   feature_emb_reshape = feature_emb.reshape((1, 1, -1))  # (B*T) x 1 x emb_size
-                   candidate_token_embeddings[feature] = feature_emb_reshape
-               stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l))  # (B*T) x num_sub_tokens x vocab_size
-               stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))
-               # set prob of the changed tokens to -inf
-               stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-               if self.method == 'low-confidence':
-                   _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
-               elif self.method == 'random':
-                   indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
-               elif self.method == 'auto-regressive':
-                   indices = torch.tensor([[step]], device=logit.device)
-               # undate the masked history
-               for i in range(b*t):
-                   for j in range(l):
-                       if j in indices[i]:
-                           masked_history[i][j] = False
-                           stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
-                           stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
-               expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x num_sub_tokens x d_model
-               memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
-           return stored_logits_dict, sampled_token_dict
-       # ---- Training ---- #
-       _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx)  # (B*T) x (num_sub_tokens) x d_model
-       memory_tensor = self._prepare_embedding(memory_list, target)  # (B*T) x (num_sub_tokens) x d_model
-       # apply layer norm
-       extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x (num_sub_tokens) x d_model
-       if worst_case:  # mask all, turn into parallel
-           extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
-       memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
-       memory_tensor = self._apply_pos_enc(memory_tensor)  # (B*T) x num_sub_tokens x d_model
-       # all is embedding
-       # memory_tensor = self.layer_norm(memory_tensor)
-       # apply feature enricher to memory
-       if self.sub_decoder_enricher_use:
-           input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-           input_dict = self.feature_enricher_layers(input_dict)
-           memory_tensor = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-       # implement sub decoder cross attention
-       # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-       # inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
-       # inter_input = input_seq_pos + memory_tensor  # (B*T) x num_sub_tokens x d_model
-       # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-       input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-       input_dict = self.sub_decoder_layers(input_dict)
-       attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-       # get prob
-       for idx, feature in enumerate(self.prediction_order):
-           feature_pos = self.feature_order_in_output[feature]
-           logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-           logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1))  # B x T x vocab_size
-           logits_dict[feature] = logit
-       return logits_dict, (masked_indices, p_mask)
-
-   def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
-       logits_dict = {}
-       hidden_vec = input_dict['hidden_vec']  # B x T x d_model
-       target = input_dict['target']  # B x T x d_model
-       bos_hidden_vec = input_dict['bos_token_hidden']  # B x 1 x d_model, used for the first token in the sub-decoder
-       # apply window on hidden_vec for enricher
-       if self.sub_decoder_enricher_use:
-           window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec)  # (B*T) x window_size x d_model
-       hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))  # (B*T) x 1 x d_model
-       input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model
-       input_seq_pos = self._apply_pos_enc(input_seq)  # (B*T) x num_sub_tokens x d_model
-       if bos_hidden_vec is None:  # start of generation
-           if target is None:
-               bos_hidden_vec = input_seq_pos
-           else:
-               bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1)  # B x T x d_model
-               bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
-               bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
-       else:
-           bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)  # (B*T) x num_sub_tokens x d_model
-       # input_seq_pos = input_seq
-       input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
-       boosted_input_dict = self.feature_boost_layers(input_dict)  # (B*T) x num_sub_tokens x d_model
-       input_seq_pos = boosted_input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-       # input_seq_pos = self.input_norm(input_seq_pos)  # (B*T) x num_sub_tokens x d_model
-       # input_seq_pos = self._apply_pos_enc(input_seq)  # (B*T) x num_sub_tokens x d_model
-       # prepare memory
-       memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
-       # ---- Generate(Inference) ---- #
-       if target is None:
-           sampled_token_dict = {}
-           b,t,d = hidden_vec.shape  # B x T x d_model
-           l = len(self.prediction_order)  # num_sub_tokens
-           memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
-           all_noise_tensor = memory_tensor.clone()  # (B*T) x num_sub_tokens x d_model
-           # indicate the position of the mask token,1 means that the token hsa been masked
-           masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
+           # add attribute control here
+           stored_logits_dict = {}
+           stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
+           if condition_step is not None:
+               # print("shape of condition_step", condition_step.shape)
+               condition_step = condition_step.reshape((b*t, l))
+               for i in range(b*t):
+                   for j in range(l):
+                       token = condition_step[i][j]
+                       if condition_step[i][j] != self.MASK_idx:
+                           # print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
+                           masked_history[i][j] = False
+                           memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
+                           stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
+                           # print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
+           num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
+           # denoising c
           # with torch.profiler.profile(
           #     activities=[
           #         torch.profiler.ProfilerActivity.CPU,
@@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
           # ) as prof:
           for step in range(self.denoising_steps):
               memory_tensor = self._apply_pos_enc(memory_tensor)  # (B*T) x num_sub_tokens x d_model
-              # nomalize the memory tensor
-              # memory_tensor = self.layer_norm(memory_tensor)  # (B*T) x num_sub_tokens x d_model
               if self.sub_decoder_enricher_use:
                   input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
                   input_dict = self.feature_enricher_layers(input_dict)
@@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
               input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
               input_dict = self.sub_decoder_layers(input_dict)
               attn_output = input_dict['input_seq']  # (B*T) x num_sub_tokens x d_model
-              candidate_token_probs = {}
               sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
                                                                                                                                               force_decode=Force_decode,
                                                                                                                                               step=step)
+              # print("step", step)
+              # print("tokens", sampled_token_dict)
               # set prob of the changed tokens to -inf
               stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-              print("stacked_logits_probs", stacked_logits_probs.clone())
               if self.method == 'low-confidence':
                   _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
@@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
               for i in range(b*t):
                   for j in range(l):
                       if j in indices[i]:
+                          # print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
                           masked_history[i][j] = False
                           stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
+                          stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
               expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1])  # (B*T) x num_sub_tokens x d_model
-              memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
+              memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
+              # stop early if all tokens are unmasked
+              if not expand_masked_history.any():
+                  # print("All tokens have been unmasked. Ending denoising process.")
+                  break
           # print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
+          # get the final sampled tokens by inverting the embeddings of the unmasked tokens
+          sampled_token_dict = {}
+          for idx, feature in enumerate(self.prediction_order):
+              sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
+              sampled_token_dict[feature] = sampled_token
+          # print("Final sampled tokens:")
           # print(sampled_token_dict)
+          # print(condition_step)
           return stored_logits_dict, sampled_token_dict

       # ---- Training ---- #
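
For reference (annotation, not part of the commit diff): the conditioning loop above pre-unmasks every sub-token the caller fixed, so denoising only has to fill the remaining slots. A rough standalone sketch with toy shapes, assuming the same 126336 mask sentinel:

import torch

MASK_idx = 126336
n, l = 2, 8                                          # (B*T) and num_sub_tokens (toy)
masked_history = torch.ones(n, l, dtype=torch.bool)  # True = still masked
condition_step = torch.full((n, l), MASK_idx)
condition_step[:, 1] = 42                            # caller fixes one attribute slot

fixed = condition_step != MASK_idx                   # positions supplied by the caller
masked_history[fixed] = False                        # these are never resampled
print(masked_history[0])                             # slot 1 starts out unmasked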

View File

@@ -510,6 +510,7 @@ class Melody(SymbolicMusicDataset):
         shuffled_tune_names = list(self.tune_in_idx.keys())
         song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
         song_dict = {}
+        ratio = 0.8
         for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
             if song not in song_dict:
                 song_dict[song] = []

View File

@@ -2,8 +2,8 @@ defaults:
  # - nn_params: nb8_embSum_NMT
  # - nn_params: remi8
  # - nn_params: nb8_embSum_diff_t2m_150M_finetunning
-  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
-  - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
+  - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
+  # - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
  # - nn_params: nb8_embSum_subPararell
  # - nn_params: nb8_embSum_diff_t2m_150M
@@ -15,7 +15,7 @@ defaults:
  # - nn_params: remi8_main12_head_16_dim512
  # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

-dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
+dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
 captions_path: dataset/midicaps/train_set.json
 # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -44,7 +44,7 @@ train_params:
  focal_gamma: 0
  # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
  scheduler: cosinelr
-  initial_lr: 0.0004
+  initial_lr: 0.0003
  decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
  num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
  warmup_steps: 2000 # number of warmup steps
@@ -59,7 +59,7 @@ inference_params:
data_params:
  first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
  split_ratio: 0.998 # train-validation-test split ratio
-  aug_type: pitch # random, null | pitch and chord augmentation type
+  aug_type: null # random, null | pitch and chord augmentation type
general:
  debug: False
  make_log: True # True, False | update the log file in wandb online to your designated project and entity

View File

@@ -74,7 +74,8 @@ class LanguageModelTrainer:
        sampling_threshold: float,  # Threshold for sampling decisions
        sampling_temperature: float,  # Temperature for controlling sampling randomness
        config,  # Configuration parameters (contains general, training, and inference settings)
-       model_checkpoint: Union[str, None] = None,  # Path to a pre-trainmodl checkpoint (optional)
+       model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt",  # Path to a pre-trained model checkpoint (optional)
+       # model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
    ):
        # Save model, optimizer, and other configurations
        self.model = model

View File

@@ -112,6 +112,23 @@ class MultiEmbedding(nn.Module):
         layer_idx = self.feature_list.index(key)
         return self.layers[layer_idx](token)

+    def get_token_by_emb(self, key, token_emb):
+        '''
+        token_emb: B x emb_size
+        '''
+        layer_idx = self.feature_list.index(key)
+        embedding_layer = self.layers[layer_idx]  # nn.Embedding
+        # compute cosine similarity between token_emb and the embedding weights
+        emb_weights = embedding_layer.weight  # vocab_size x emb_size
+        cos_sim = torch.nn.functional.cosine_similarity(
+            token_emb.unsqueeze(1),    # B x 1 x emb_size
+            emb_weights.unsqueeze(0),  # 1 x vocab_size x emb_size
+            dim=-1
+        )  # B x vocab_size
+        # get the index of the most similar embedding
+        token_idx = torch.argmax(cos_sim, dim=-1)  # B
+        return token_idx
+
 class SummationEmbedder(MultiEmbedding):
     def __init__(
         self,
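
For reference (annotation, not part of the commit diff): get_token_by_emb is a nearest-neighbour lookup in embedding space. A standalone sketch with toy sizes showing that exact embeddings round-trip through the cosine-similarity argmax:

import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Embedding(16, 8)                               # vocab_size x emb_size (toy)
tokens = torch.tensor([3, 7, 11])
embs = layer(tokens)                                      # B x emb_size
cos_sim = F.cosine_similarity(embs.unsqueeze(1),          # B x 1 x emb_size
                              layer.weight.unsqueeze(0),  # 1 x vocab_size x emb_size
                              dim=-1)                     # B x vocab_size
assert torch.equal(cos_sim.argmax(dim=-1), tokens)        # exact embeddings invert cleanly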

View File

@@ -168,7 +168,7 @@ class CorpusMaker():
         print("length of midi list: ", len(self.midi_list))
         # Use set for faster lookup (O(1) per check)
         processed_files_set = set(processed_files)
-        # self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
+        self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
         # reverse the list to process the latest files first
         self.midi_list.reverse()
         print(f"length of midi list after filtering: ", len(self.midi_list))

View File

@@ -61,7 +61,7 @@ class Corpus2Event():
         # remove the corpus files that are already in the out_dir
         # Use set for faster existence checks
         existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
-        # corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
+        corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
         for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
             if event is None:
                 broken_count += 1

View File

@@ -1,3 +1,4 @@
+from ast import arg
 import sys
 import os
 from pathlib import Path
@@ -25,14 +26,25 @@ def get_argument_parser():
     parser.add_argument(
         "-generation_type",
         type=str,
-        choices=('conditioned', 'unconditioned', 'text-conditioned'),
+        choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
         default='unconditioned',
         help="generation type",
     )
+    parser.add_argument(
+        "-attr_list",
+        type=str,
+        default="beat,duration",
+        help="attribute list for attribute-controlled generation",
+    )
+    parser.add_argument(
+        "-dataset",
+        type=str,
+        help="dataset name, only for conditioned generation",
+    )
     parser.add_argument(
         "-sampling_method",
         type=str,
-        choices=('top_p', 'top_k'),
+        choices=('top_p', 'top_k', 'min_p'),
         default='top_p',
         help="sampling method",
     )
@@ -74,7 +86,7 @@ def get_argument_parser():
     parser.add_argument(
         "-num_processes",
         type=int,
-        default=4,
+        default=1,
         help="number of processes to use",
     )
     parser.add_argument(
@@ -97,7 +109,7 @@ def get_argument_parser():
     )
     return parser

-def load_resources(wandb_exp_dir, device):
+def load_resources(wandb_exp_dir, condition_dataset, device):
     """Load model and dataset resources for a process"""
     wandb_dir = Path('wandb')
     ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
@@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, condition_dataset, device):
     # Load checkpoint to specified device
     print("Loading checkpoint from:", ckpt_path)
     ckpt = torch.load(ckpt_path, map_location=device)
-    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
+    print(config)
+    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
     model.load_state_dict(ckpt['model'], strict=False)
     model.to(device)
     model.eval()
@@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, condition_dataset, device):
     return config, model, dataset_for_prompt, vocab

-def conditioned_worker(process_idx, gpu_id, args, data_slice):
+def conditioned_worker(process_idx, gpu_id, args):
     """Worker process for conditioned generation"""
     torch.cuda.set_device(gpu_id)
     device = torch.device(f'cuda:{gpu_id}')

     # Load resources with proper device
-    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
+    # print(test_set)
+    if args.choose_selected_tunes and test_set.dataset == 'SOD':
+        selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
+                          "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
+    else:
+        selected_tunes = [name for _, name in test_set][:args.num_samples]
+    # Split selected data across processes
+    selected_data = [d for d in test_set if d[1] in selected_tunes]
+    chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
+    start_idx = 1
+    end_idx = min(chunk_size, len(selected_data))
+    data_slice = selected_data[start_idx:end_idx]

     # Create output directory with process index
     base_path = Path('wandb') / args.wandb_exp_dir / \
         f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
     base_path.mkdir(parents=True, exist_ok=True)

-    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
+    evaluator = Evaluator(config, model, data_slice, vocab, device=device)

     # Process assigned data slice
     for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
@@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
             generation_length=args.generate_length
         )

+def attr_conditioned_worker(process_idx, gpu_id, args):
+    """Worker process for attribute-conditioned generation"""
+    torch.cuda.set_device(gpu_id)
+    device = torch.device(f'cuda:{gpu_id}')
+    # attr_list = "position,duration"
+    attr_list = args.attr_list.split(',')
+
+    # Load resources with proper device
+    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
+    # print(test_set)
+    if args.choose_selected_tunes and test_set.dataset == 'SOD':
+        selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
+                          "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
+    else:
+        selected_tunes = [name for _, name in test_set][:args.num_samples]
+    # Split selected data across processes
+    selected_data = [d for d in test_set if d[1] in selected_tunes]
+    # chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
+    # start_idx = 1
+    # end_idx = min(chunk_size, len(selected_data))
+    # data_slice = selected_data[start_idx:end_idx]
+    data_slice = selected_data
+
+    # Create output directory with process index
+    base_path = Path('wandb') / args.wandb_exp_dir / \
+        f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
+    base_path.mkdir(parents=True, exist_ok=True)
+
+    evaluator = Evaluator(config, model, data_slice, vocab, device=device)
+
+    # Process assigned data slice
+    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
+        batch_dir = base_path
+        batch_dir.mkdir(parents=True, exist_ok=True)
+        evaluator.generate_samples_with_attrCtl(
+            batch_dir,
+            args.num_target_measure,
+            tune_in_idx,
+            tune_name,
+            config.data_params.first_pred_feature,
+            args.sampling_method,
+            args.threshold,
+            args.temperature,
+            generation_length=args.generate_length,
+            attr_list=attr_list
+        )
+
 def unconditioned_worker(process_idx, gpu_id, args, num_samples):
     """Worker process for unconditioned generation"""
     torch.cuda.set_device(gpu_id)
     device = torch.device(f'cuda:{gpu_id}')

     # Load resources with proper device
-    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

     # Create output directory with process index
     base_path = Path('wandb') / args.wandb_exp_dir / \
@@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
     device = torch.device(f'cuda:{gpu_id}')

     # Load resources with proper device
-    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

     # Create output directory with process index
     base_path = Path('wandb') / args.wandb_exp_dir / \
@@ -237,36 +312,29 @@ def main():
         if not wandb_dir.exists():
             raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")

-        # Load test set to get selected tunes (dummy load to get dataset info)
-        dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        _, test_set, _ = prepare_model_and_dataset_from_config(
-            wandb_dir / "files" / "config.yaml",
-            wandb_dir / "files" / "metadata.json",
-            wandb_dir / "files" / "vocab.json"
-        )
-        if args.choose_selected_tunes and test_set.dataset == 'SOD':
-            selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
-                              "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
-        else:
-            selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
-        # Split selected data across processes
-        selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
-        chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
         for i in range(args.num_processes):
-            start_idx = i * chunk_size
-            end_idx = min((i+1)*chunk_size, len(selected_data))
-            data_slice = selected_data[start_idx:end_idx]
-            if not data_slice:
-                continue
             gpu_id = gpu_ids[i % len(gpu_ids)]
             p = Process(
                 target=conditioned_worker,
-                args=(i, gpu_id, args, data_slice)
+                args=(i, gpu_id, args)
             )
             processes.append(p)
             p.start()
+    elif args.generation_type == 'attr-conditioned':
+        # Prepare selected tunes
+        wandb_dir = Path('wandb') / args.wandb_exp_dir
+        if not wandb_dir.exists():
+            raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
+        for i in range(args.num_processes):
+            gpu_id = gpu_ids[i % len(gpu_ids)]
+            p = Process(
+                target=attr_conditioned_worker,
+                args=(i, gpu_id, args)
+            )
+            processes.append(p)
+            p.start()
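
For reference (annotation, not part of the commit diff): a hedged example of how the new mode might be launched; the script name and run directory are placeholders, only the flags come from the parser above:

# python generate.py -wandb_exp_dir run-XXXX -generation_type attr-conditioned \
#     -attr_list beat,duration -sampling_method min_p -threshold 0.1 -num_processes 1
# Inside attr_conditioned_worker the comma-separated flag becomes a Python list:
attr_list = "beat,duration".split(',')  # ['beat', 'duration']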

Binary file not shown (new image, 19 KiB)

Binary file not shown (new image, 19 KiB)

midi_sim.py (new file, 134 lines)
View File

@ -0,0 +1,134 @@
import os
from math import ceil
#CUDA_VISIBLE_DEVICES= "0"
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (0., 1.5)):
    """Modified Hausdorff distance between two (2, n) arrays of (onset, pitch) rows."""
    if not a.shape[1] or not b.shape[1]:
        return np.inf
    a_onset, a_pitch = a
    b_onset, b_pitch = b
    a_onset = a_onset.astype(np.float32)
    b_onset = b_onset.astype(np.float32)
    a_pitch = a_pitch.astype(np.int16)
    b_pitch = b_pitch.astype(np.int16)
    # Match every note to the temporally closest note on the other side
    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
    a2b_idx = onset_dist_matrix.argmin(1)
    b2a_idx = onset_dist_matrix.argmin(0)
    a_pitch -= (np.median(a_pitch) - np.median(b_pitch)).astype(np.int16)  # normalize register
    a_pitch = a_pitch + np.arange(-7, 7).reshape(-1, 1)  # search transpositions for invariance
    interval_diff = np.concatenate([
        a_pitch[:, a2b_idx] - b_pitch,
        b_pitch[b2a_idx] - a_pitch], axis=1)
    # Degree penalty within the 12-semitone octave, plus one degree per octave crossed
    pitch_dist = np.abs(semitone2degree[interval_diff % 12] + np.abs(interval_diff) // 12 * np.sign(interval_diff)).mean(1).min()
    onset_dist = np.abs(np.concatenate([
        a_onset[a2b_idx] - b_onset,
        b_onset[b2a_idx] - a_onset], axis=0)).mean()
    return (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)
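A quick sanity check of the metric on toy data (assumes the definitions above): a phrase compared against its transposition up a perfect fifth should score 0, since the median register alignment and the transposition search absorb the shift.

a = np.array([[0.0, 1.0, 2.0, 3.0], [60, 62, 64, 65]])  # rows: onsets (beats), pitches
b = np.array([[0.0, 1.0, 2.0, 3.0], [67, 69, 71, 72]])  # same rhythm, up a fifth
print(hausdorff_dist(a, b))  # -> 0.0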
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 8., hop_size: float = 4.):
    """Slice a (time, pitch) list into overlapping windows of `window_size` beats every `hop_size` beats."""
    x = sorted(x)
    trim_offset = (x[0][0] // hop_size) * hop_size
    end_time = x[-1][0]
    num_segment = max(ceil((end_time - window_size - trim_offset) / hop_size) + 1, 1)  # at least one window
    times = np.fromiter((time for time, _ in x), dtype=float) - trim_offset
    time_matrix = times.reshape(1, -1).repeat(num_segment, axis=0)
    seg_time_starts = np.arange(num_segment).reshape(-1, 1) * hop_size
    # 0 where a note falls inside a window, 1 outside; pad + diff marks the window boundaries
    time_compare_matrix = np.where((time_matrix >= seg_time_starts) & (time_matrix <= seg_time_starts + window_size), 0, 1)
    time_compare_matrix = np.diff(np.pad(time_compare_matrix, ((0, 0), (1, 1)), constant_values=1))
    # np.where returns indices in row-major order, so column indices are already sorted by segment
    start_idxs = np.where(time_compare_matrix == -1)[1].tolist()
    end_idxs = np.where(time_compare_matrix == 1)[1].tolist()
    segments = [x[start:end] for start, end in zip(start_idxs, end_idxs)]
    return segments
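Example of the windowing on toy data (assumes the function above): twelve beats with an 8-beat window and a 4-beat hop yield two overlapping segments.

notes = [(float(t), 60) for t in range(12)]  # one note per beat at t = 0..11
segments = midi_time_sliding_window(notes, window_size=8., hop_size=4.)
print([len(s) for s in segments])  # -> [9, 8]: beats 0-8 and beats 4-11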
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
    """Minimum windowed Hausdorff distance between two (time, pitch) note lists."""
    a = midi_time_sliding_window(a, window_size=window_size, hop_size=hop_size)
    b = midi_time_sliding_window(b, window_size=window_size, hop_size=hop_size)
    dist = np.inf
    for seg_a in a:
        for seg_b in b:
            cur_dist = hausdorff_dist(np.array(seg_a, dtype=np.float32).T, np.array(seg_b, dtype=np.float32).T)
            if cur_dist < dist:
                dist = cur_dist
    return float(dist)
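And an end-to-end toy check (assumes the functions above): comparing a phrase with itself should give distance 0 in the best-aligned window.

mel = [(float(t), 60 + t % 5) for t in range(16)]  # 16-beat toy melody
print(midi_dist(mel, mel))  # -> 0.0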
def extract_notes(filepath: str):
    """Read a MIDI file and return a list of (time, pitch) tuples (first track only)."""
    try:
        s = Score(filepath).to("quarter")  # times in quarter-note units
        notes = [(n.time, n.pitch) for n in s.tracks[0].notes]  # use only the first track
        return notes
    except Exception as e:
        print(f"Failed to read {filepath}: {e}")
        return []
def compare_pair(file_a: str, file_b: str):
    try:
        notes_a = extract_notes(file_a)
        notes_b = extract_notes(file_b)
        if not notes_a or not notes_b:
            return (file_a, file_b, np.inf)
        dist = midi_dist(notes_a, notes_b)
        return (file_a, file_b, dist)
    except Exception:
        import traceback
        print(f"compare_pair failed: {file_a} vs {file_b}")
        traceback.print_exc()
        return (file_a, file_b, np.inf)
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
    files_a = files_a[:100]  # compare only the first 100 files to save time
    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
    results = []
    pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
        for fut in as_completed(futures):
            pbar.update(1)
            try:
                results.append(fut.result())
            except Exception as e:
                print(f"Error comparing pair: {e}")
    # sort by ascending distance and save
    results = sorted(results, key=lambda x: x[2])
    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
    df.to_csv(out_csv, index=False)
    print(f"Saved results to {out_csv}")
if __name__ == "__main__":
dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
dir_b = "dataset/Melody"
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
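A hedged follow-up for inspecting the output (the CSV name matches the call above; the 1.0 cutoff is an arbitrary example, not a calibrated threshold):

import pandas as pd
df = pd.read_csv("midi_similarity_v2.csv")
print(df.nsmallest(10, "distance"))   # the ten closest (generated, reference) pairs
print((df["distance"] < 1.0).mean())  # fraction of pairs under the example cutoff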

(deleted file, 105 lines removed)

@ -1,105 +0,0 @@
import os
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed

semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])

def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
    if(not a.shape[1] or not b.shape[1]):
        return np.inf
    a_onset, a_pitch = a
    b_onset, b_pitch = b
    a_onset = a_onset.astype(np.float32)
    b_onset = b_onset.astype(np.float32)
    a_pitch = a_pitch.astype(np.uint8)
    b_pitch = b_pitch.astype(np.uint8)
    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
    if(oti):
        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
        dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
        a2b = dist_matrix.min(2)
        b2a = dist_matrix.min(1)
        dist = np.concatenate([a2b, b2a], axis=1)
        return dist.sum(axis=1).min() / len(dist)
    else:
        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
        dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
        a2b = dist_matrix.min(1)
        b2a = dist_matrix.min(0)
        return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))

def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
    x = sorted(x)
    end_time = x[-1][0]
    out = [[] for _ in range(int(end_time // hop_size))]
    for i in sorted(x):
        segment = min(int(i[0] // hop_size), len(out) - 1)
        while(i[0] >= segment * hop_size):
            out[segment].append(i)
            segment -= 1
            if(segment < 0):
                break
    return out

def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
    a = midi_time_sliding_window(a)
    b = midi_time_sliding_window(b)
    dist = np.inf
    for i in a:
        for j in b:
            cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
            if(cur_dist < dist):
                dist = cur_dist
    return dist

def extract_notes(filepath: str):
    """Read a MIDI file and return a list of (time, pitch) tuples."""
    try:
        s = Score(filepath).to("quarter")
        notes = []
        for t in s.tracks:
            notes.extend([(n.time, n.pitch) for n in t.notes])
        return notes
    except Exception as e:
        print(f"Failed to read {filepath}: {e}")
        return []

def compare_pair(file_a: str, file_b: str):
    notes_a = extract_notes(file_a)
    notes_b = extract_notes(file_b)
    if not notes_a or not notes_b:
        return (file_a, file_b, np.inf)
    dist = midi_dist(notes_a, notes_b)
    return (file_a, file_b, dist)

def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
    results = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
        for fut in as_completed(futures):
            results.append(fut.result())
    # sort by ascending distance
    results = sorted(results, key=lambda x: x[2])
    # save
    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
    df.to_csv(out_csv, index=False)
    print(f"Saved results to {out_csv}")

if __name__ == "__main__":
    dir_a = "folder_a"
    dir_b = "folder_b"
    batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)