1021 add flexable attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
return last_ckpt_fn, config_path, metadata_path, vocab_path
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
# if config is a path, load it
if isinstance(config, (str, Path)):
from omegaconf import OmegaConf
config = OmegaConf.load(config)
config = wandb_style_config_to_omega_config(config)
nn_params = config.nn_params
for_evaluation = True
if condition_dataset is not None:
print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
config.dataset = condition_dataset
for_evaluation = False
dataset_name = config.dataset
vocab_path = Path(vocab_path)
@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
input_length=config.train_params.input_length,
first_pred_feature=config.data_params.first_pred_feature,
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
for_evaluation=True,
for_evaluation=for_evaluation
)
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
split_ratio = config.data_params.split_ratio
# test_set = []
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
# get proper prediction order according to the encoding scheme and target feature in the config
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
@ -480,6 +490,28 @@ class Evaluator:
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0,generation_length=3072, attr_list=None):
encoding_scheme = self.config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
tuneidx = tuneidx.cuda()
generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme

View File

@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
'''
super().__init__()
self.net = net
# self.prediction_order = net.prediction_order
# self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
self.attribute2idx_after = {'pitch': 0,
'duration': 1,
'velocity': 2,
'type': 3,
'beat': 4,
'chord': 5,
'tempo': 6,
'instrument': 7}
self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
return self.net(input_seq, target, context=context)
@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
return total_out
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None,condition_step=None):
'''
Runs one step of autoregressive generation by taking the input sequence, embedding it,
passing it through the main decoder, and generating logits and a sampled token.
@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
# Generate the next token
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
return logits, sampled_token, intermidiates, hidden_vec
def _update_total_out(self, total_out, sampled_token):
@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
return total_out, sampled_token
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
'''
Autoregressively generates a sequence of tokens by repeatedly sampling the next token
until the desired maximum sequence length is reached or the end token is encountered.
@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
- total_out: The generated sequence of tokens as a tensor.
'''
# Prepare the starting sequence for inference
if attr_list is None:
total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
# If a condition is provided, run one initial step
if condition is not None:
_, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
else:
cache = LayerIntermediates()
total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
# for attribute-controlled generation, only keep the specified attributes in condition, others set to 126336
condition_filtered = condition.clone().unsqueeze(0)
# print(self.attribute2idx)
for attr, idx in self.attribute2idx.items():
if attr not in attr_list:
condition_filtered[:, :, idx] = 126336
# rearange condition_filtered to match prediction order
# Continue generating tokens until the maximum sequence length is reached
cache = LayerIntermediates()
pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
bos_hidden_vec = None
hidden_vec_list = []
@ -261,6 +275,20 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
input_tensor = total_out[:, -self.net.input_length:]
# Generate the next token and update the cache
time_start = time.time()
# if attr_list is not None, get one token in condition_filtered each time step
if attr_list is not None:
condition_filtered = condition_filtered.to(self.net.device)
# print(condition_filtered[:,:20,:])
# print(condition_filtered.shape)
condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
# rearange order, 0 to 5, 1 to 6, 2 to 7, 3 to 0, 4 to 1, 5 to 2, 6 to 3, 7 to 4
condition_step_rearranged = torch.zeros_like(condition_step)
for attr, idx in self.attribute2idx.items():
new_idx = self.attribute2idx_after[attr]
condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
# print("condition_step shape:", condition_step.shape)
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
else:
_, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
time_end = time.time()
token_time_list.append(time_end - time_start)
@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
return self.decoder(input_seq, target, context=context)
@torch.inference_mode()
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None,attr_list=None):
if batch_size == 1:
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
else:
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
class AmadeusModel4Encodec(AmadeusModel):
def __init__(

View File

@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
scores = logits.masked_fill(indices_to_remove, float("-inf"))
return scores
def min_p_sampling(logits, alpha=0.05):
"""
logits: Tensor of shape [B, L, V]
alpha: float, relative probability threshold (e.g., 0.05)
"""
# 计算 softmax 概率
probs = F.softmax(logits, dim=-1)
# 找到每个位置的最大概率
max_probs, _ = probs.max(dim=-1, keepdim=True) # [B, L, 1]
# 保留概率 >= alpha * max_prob 的 token
mask = probs < (alpha * max_probs) # True 表示要屏蔽
masked_logits = logits.masked_fill(mask, float('-inf'))
return masked_logits
def add_gumbel_noise(logits, temperature):
'''
The Gumbel max is a method for sampling categorical distributions.
@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
modified_logits = typical_sampling(logits, thres=threshold)
elif sampling_method == "eta":
modified_logits = eta_sampling(logits, epsilon=threshold)
elif sampling_method == "min_p":
modified_logits = min_p_sampling(logits, alpha=threshold)
else:
modified_logits = logits # 其他情况直接使用原始logits

View File

@ -1,3 +1,4 @@
from re import T
from selectors import EpollSelector
from turtle import st
from numpy import indices
@ -6,7 +7,7 @@ import torch
import torch.profiler
import torch.nn as nn
from x_transformers import Decoder
from .custom_x_transformers import Decoder
from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
from .sub_decoder_utils import *
@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
})
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
'''
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec']
target = input_dict['target']
@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
return memory_tensor
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub_tokens
@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
memory_tensor = hidden_vec_reshape + feature_tensor
return memory_tensor
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] # B x T x num_sub-tokens
@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
memory_list.append(BOS_emb[-1:, :, :])
return memory_list
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target']
@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
):
super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
hidden_vec = input_dict['hidden_vec']
# ---- Training ---- #
@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
indices = torch.tensor([[step]], device=hidden_vec.device)
return indices
def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = input_seq
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# indicate the position of the mask token,1 means that the token hsa been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
stored_logits_dict = {}
stored_probs_dict = {}
for step in range(self.denoising_steps):
# nomalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
# input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
# indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
# breakpoint()
# undate the masked history
for i in range(b*t):
for j in range(l):
if j in indices[i]:
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
# breakpoint()
# print("stored_probs_dict", stored_probs_dict)
# print("sampled_token_dict", sampled_token_dict)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask all ,turn into parallel
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):
# indicate the position of the mask token,1 means that the token hsa been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
# add attribute control here
stored_logits_dict = {}
stored_probs_dict = {}
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# nomalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}
candidate_token_embeddings = {}
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
# print(idx,feature,sampled_token,probs)
sampled_token_dict[feature] = sampled_token
candidate_token_probs[feature] = probs
feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
candidate_token_embeddings[feature] = feature_emb_reshape
stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens x vocab_size
stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
elif self.method == 'random':
indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
elif self.method == 'auto-regressive':
indices = torch.tensor([[step]], device=logit.device)
# undate the masked history
stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
if condition_step is not None:
# print("shape of condition_step", condition_step.shape)
condition_step = condition_step.reshape((b*t, l))
for i in range(b*t):
for j in range(l):
if j in indices[i]:
token = condition_step[i][j]
if condition_step[i][j] != self.MASK_idx:
# print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
return stored_logits_dict, sampled_token_dict
memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
# print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
# ---- Training ---- #
_, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
# apply layer norm
extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
if worst_case: # mask all ,turn into parallel
extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# all is embedding
# memory_tensor = self.layer_norm(memory_tensor)
# apply feature enricher to memory
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# implement sub decoder cross attention
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
# inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
# inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
# input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# get prob
for idx, feature in enumerate(self.prediction_order):
feature_pos = self.feature_order_in_output[feature]
logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
logits_dict[feature] = logit
return logits_dict, (masked_indices, p_mask)
def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
logits_dict = {}
hidden_vec = input_dict['hidden_vec'] # B x T x d_model
target = input_dict['target'] #B x T x d_model
bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
# apply window on hidden_vec for enricher
if self.sub_decoder_enricher_use:
window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
if bos_hidden_vec is None: # start of generation
if target is None:
bos_hidden_vec = input_seq_pos
else:
bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
else:
bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = input_seq
input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
# input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
# prepare memory
memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
# ---- Generate(Inference) ---- #
if target is None:
sampled_token_dict = {}
b,t,d = hidden_vec.shape # B x T x d_model
l = len(self.prediction_order) # num_sub_tokens
memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
# indicate the position of the mask token,1 means that the token hsa been masked
masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
# denoising c
stored_logits_dict = {}
stored_probs_dict = {}
# with torch.profiler.profile(
# activities=[
# torch.profiler.ProfilerActivity.CPU,
@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
# ) as prof:
for step in range(self.denoising_steps):
memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
# nomalize the memory tensor
# memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
if self.sub_decoder_enricher_use:
input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
input_dict = self.feature_enricher_layers(input_dict)
@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
input_dict = self.sub_decoder_layers(input_dict)
attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
candidate_token_probs = {}
sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
force_decode=Force_decode,
step=step)
# print("step", step)
# print("toknes", sampled_token_dict)
# set prob of the changed tokens to -inf
stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
print("stacked_logits_probs", stacked_logits_probs.clone())
if self.method == 'low-confidence':
_, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
for i in range(b*t):
for j in range(l):
if j in indices[i]:
# print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
masked_history[i][j] = False
stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
# skip if all tokens are unmasked
if not expand_masked_history.any():
# print("All tokens have been unmasked. Ending denoising process.")
break
# print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
# get final sampled tokens by embedding the unmasked tokens
sampled_token_dict = {}
for idx, feature in enumerate(self.prediction_order):
sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
sampled_token_dict[feature] = sampled_token
# print("Final sampled tokens:")
# print(sampled_token_dict)
# print(condition_step)
return stored_logits_dict, sampled_token_dict
# ---- Training ---- #

View File

@ -510,6 +510,7 @@ class Melody(SymbolicMusicDataset):
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
ratio = 0.8
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []

View File

@ -2,8 +2,8 @@ defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
- nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
- nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
@ -15,7 +15,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -44,7 +44,7 @@ train_params:
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.0004
initial_lr: 0.0003
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps
@ -59,7 +59,7 @@ inference_params:
data_params:
first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
split_ratio: 0.998 # train-validation-test split ratio
aug_type: pitch # random, null | pitch and chord augmentation type
aug_type: null # random, null | pitch and chord augmentation type
general:
debug: False
make_log: True # True, False | update the log file in wandb online to your designated project and entity

View File

@ -74,7 +74,8 @@ class LanguageModelTrainer:
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
):
# Save model, optimizer, and other configurations
self.model = model

View File

@ -112,6 +112,23 @@ class MultiEmbedding(nn.Module):
layer_idx = self.feature_list.index(key)
return self.layers[layer_idx](token)
def get_token_by_emb(self, key, token_emb):
'''
token_emb: B x emb_size
'''
layer_idx = self.feature_list.index(key)
embedding_layer = self.layers[layer_idx] # nn.Embedding
# compute cosine similarity between token_emb and embedding weights
emb_weights = embedding_layer.weight # vocab_size x emb_size
cos_sim = torch.nn.functional.cosine_similarity(
token_emb.unsqueeze(1), # B x 1 x emb_size
emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
dim=-1
) # B x vocab_size
# get the index of the most similar embedding
token_idx = torch.argmax(cos_sim, dim=-1) # B
return token_idx
class SummationEmbedder(MultiEmbedding):
def __init__(
self,

View File

@ -168,7 +168,7 @@ class CorpusMaker():
print("length of midi list: ", len(self.midi_list))
# Use set for faster lookup (O(1) per check)
processed_files_set = set(processed_files)
# self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
# reverse the list to process the latest files first
self.midi_list.reverse()
print(f"length of midi list after filtering: ", len(self.midi_list))

View File

@ -61,7 +61,7 @@ class Corpus2Event():
# remove the corpus files that are already in the out_dir
# Use set for faster existence checks
existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
# corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
if event is None:
broken_count += 1

View File

@ -1,3 +1,4 @@
from ast import arg
import sys
import os
from pathlib import Path
@ -25,14 +26,25 @@ def get_argument_parser():
parser.add_argument(
"-generation_type",
type=str,
choices=('conditioned', 'unconditioned', 'text-conditioned'),
choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
default='unconditioned',
help="generation type",
)
parser.add_argument(
"-attr_list",
type=str,
default="beat,duration",
help="attribute list for attribute-controlled generation",
)
parser.add_argument(
"-dataset",
type=str,
help="dataset name, only for conditioned generation",
)
parser.add_argument(
"-sampling_method",
type=str,
choices=('top_p', 'top_k'),
choices=('top_p', 'top_k', 'min_p'),
default='top_p',
help="sampling method",
)
@ -74,7 +86,7 @@ def get_argument_parser():
parser.add_argument(
"-num_processes",
type=int,
default=4,
default=1,
help="number of processes to use",
)
parser.add_argument(
@ -97,7 +109,7 @@ def get_argument_parser():
)
return parser
def load_resources(wandb_exp_dir, device):
def load_resources(wandb_exp_dir, condition_dataset, device):
"""Load model and dataset resources for a process"""
wandb_dir = Path('wandb')
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, device):
# Load checkpoint to specified device
print("Loading checkpoint from:", ckpt_path)
ckpt = torch.load(ckpt_path, map_location=device)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
print(config)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
model.load_state_dict(ckpt['model'], strict=False)
model.to(device)
model.eval()
@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, device):
return config, model, dataset_for_prompt, vocab
def conditioned_worker(process_idx, gpu_id, args, data_slice):
def conditioned_worker(process_idx, gpu_id, args):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
start_idx = 1
end_idx = min(chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
generation_length=args.generate_length
)
def attr_conditioned_worker(process_idx, gpu_id, args):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# attr_list = "position,duration"
attr_list = args.attr_list.split(',')
# Load resources with proper device
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
# chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
# start_idx = 1
# end_idx = min(chunk_size, len(selected_data))
# data_slice = selected_data[start_idx:end_idx]
data_slice = selected_data
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_attrCtl(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
attr_list=attr_list
)
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
@ -237,36 +312,29 @@ def main():
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
# Load test set to get selected tunes (dummy load to get dataset info)
dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_, test_set, _ = prepare_model_and_dataset_from_config(
wandb_dir / "files" / "config.yaml",
wandb_dir / "files" / "metadata.json",
wandb_dir / "files" / "vocab.json"
)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
for i in range(args.num_processes):
start_idx = i * chunk_size
end_idx = min((i+1)*chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
if not data_slice:
continue
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=conditioned_worker,
args=(i, gpu_id, args, data_slice)
args=(i, gpu_id, args)
)
processes.append(p)
p.start()
elif args.generation_type == 'attr-conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=attr_conditioned_worker,
args=(i, gpu_id, args)
)
processes.append(p)
p.start()

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

134
midi_sim.py Normal file
View File

@ -0,0 +1,134 @@
import os
from math import ceil
#CUDA_VISIBLE_DEVICES= "0"
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (0., 1.5)):
if(not a.shape[1] or not b.shape[1]):
return np.inf
a_onset, a_pitch = a
b_onset, b_pitch = b
a_onset = a_onset.astype(np.float32)
b_onset = b_onset.astype(np.float32)
a_pitch = a_pitch.astype(np.int16)
b_pitch = b_pitch.astype(np.int16)
onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
a2b_idx = onset_dist_matrix.argmin(1)
b2a_idx = onset_dist_matrix.argmin(0)
a_pitch -= (np.median(a_pitch) - np.median(b_pitch)).astype(np.int16) # Normalize pitch
a_pitch = a_pitch + np.arange(-7, 7).reshape(-1, 1) # Transpose invarient
interval_diff = np.concatenate([
a_pitch[:, a2b_idx] - b_pitch,
b_pitch[b2a_idx] - a_pitch], axis=1)
pitch_dist = np.abs(semitone2degree[interval_diff % 8] + np.abs(interval_diff) // 8 * np.sign(interval_diff)).mean(1).min()
onset_dist = np.abs(np.concatenate([
a_onset[a2b_idx] - b_onset,
b_onset[b2a_idx] - a_onset], axis=0)).mean()
return (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 8., hop_size: float = 4.):
x = sorted(x)
trim_offset = (x[0][0] // hop_size) * hop_size
end_time = x[-1][0]
num_segment = ceil((end_time - window_size - trim_offset) / hop_size) + 1
time_matrix = (np.fromiter((time for time, _ in x), dtype=float) - trim_offset).reshape(1, -1).repeat(num_segment, axis=0)
seg_time_starts = np.arange(num_segment).reshape(-1, 1) * hop_size
time_compare_matrix = np.where((time_matrix >= seg_time_starts) & (time_matrix <= seg_time_starts + window_size), 0, 1)
time_compare_matrix = np.diff(np.pad(time_compare_matrix, ((0, 0), (1, 1)), constant_values=1))
start_idxs = sorted(np.where(time_compare_matrix == -1), key=lambda x: x[0])[1].tolist()
end_idxs = sorted(np.where(time_compare_matrix == 1), key=lambda x: x[0])[1].tolist()
segments = [x[start:end] for start, end in zip(start_idxs, end_idxs)]
return segments
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
a = midi_time_sliding_window(a, window_size=window_size, hop_size=hop_size)
b = midi_time_sliding_window(b, window_size=window_size, hop_size=hop_size)
dist = np.inf
for x,i in enumerate(a):
for y,j in enumerate(b):
cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
if cur_dist == 0:
print(x, y)
if(cur_dist < dist):
dist = cur_dist
return float(dist)
def extract_notes(filepath: str):
"""读取MIDI并返回 (time, pitch) 列表"""
try:
s = Score(filepath).to("quarter")
notes = []
# for t in s.tracks:
# notes.extend([(n.time, n.pitch) for n in t.notes])
notes = [(n.time, n.pitch) for n in s.tracks[0].notes] # 仅使用第一个track
return notes
except Exception as e:
print(f"读取 {filepath} 出错: {e}")
return []
def compare_pair(file_a: str, file_b: str):
try:
notes_a = extract_notes(file_a)
notes_b = extract_notes(file_b)
if not notes_a or not notes_b:
return (file_a, file_b, np.inf)
dist = midi_dist(notes_a, notes_b)
return (file_a, file_b, dist)
except Exception as e:
import traceback
print(f"⚠️ compare_pair 出错: {file_a} vs {file_b}")
traceback.print_exc()
return (file_a, file_b, np.inf)
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
files_a = files_a[:100] # 仅比较前100个文件以节省时间
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
results = []
pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
for fut in as_completed(futures):
pbar.update(1)
try:
results.append(fut.result())
except Exception as e:
print(fut.result())
print(f"Error comparing pair: {e}")
# print(f"Compared: {results[-1][0]} vs {results[-1][1]}, Distance: {results[-1][2]:.4f}")
# with tqdm(total=len(files_a) * len(files_b)) as pbar:
# for fa in files_a:
# for fb in files_b:
# results.append(compare_pair(fa, fb))
# pbar.update(1)
# # 排序
results = sorted(results, key=lambda x: x[2])
# 保存
df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
df.to_csv(out_csv, index=False)
print(f"已保存结果到 {out_csv}")
if __name__ == "__main__":
dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
dir_b = "dataset/Melody"
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)

View File

@ -1,105 +0,0 @@
import os
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
if(not a.shape[1] or not b.shape[1]):
return np.inf
a_onset, a_pitch = a
b_onset, b_pitch = b
a_onset = a_onset.astype(np.float32)
b_onset = b_onset.astype(np.float32)
a_pitch = a_pitch.astype(np.uint8)
b_pitch = b_pitch.astype(np.uint8)
onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
if(oti):
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
a2b = dist_matrix.min(2)
b2a = dist_matrix.min(1)
dist = np.concatenate([a2b, b2a], axis=1)
return dist.sum(axis=1).min() / len(dist)
else:
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
a2b = dist_matrix.min(1)
b2a = dist_matrix.min(0)
return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
x = sorted(x)
end_time = x[-1][0]
out = [[] for _ in range(int(end_time // hop_size))]
for i in sorted(x):
segment = min(int(i[0] // hop_size), len(out) - 1)
while(i[0] >= segment * hop_size):
out[segment].append(i)
segment -= 1
if(segment < 0):
break
return out
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
a = midi_time_sliding_window(a)
b = midi_time_sliding_window(b)
dist = np.inf
for i in a:
for j in b:
cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
if(cur_dist < dist):
dist = cur_dist
return dist
def extract_notes(filepath: str):
"""读取MIDI并返回 (time, pitch) 列表"""
try:
s = Score(filepath).to("quarter")
notes = []
for t in s.tracks:
notes.extend([(n.time, n.pitch) for n in t.notes])
return notes
except Exception as e:
print(f"读取 {filepath} 出错: {e}")
return []
def compare_pair(file_a: str, file_b: str):
notes_a = extract_notes(file_a)
notes_b = extract_notes(file_b)
if not notes_a or not notes_b:
return (file_a, file_b, np.inf)
dist = midi_dist(notes_a, notes_b)
return (file_a, file_b, dist)
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
results = []
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
for fut in as_completed(futures):
results.append(fut.result())
# 排序
results = sorted(results, key=lambda x: x[2])
# 保存
df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
df.to_csv(out_csv, index=False)
print(f"已保存结果到 {out_csv}")
if __name__ == "__main__":
dir_a = "folder_a"
dir_b = "folder_b"
batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)