diff --git a/Amadeus/evaluation_utils.py b/Amadeus/evaluation_utils.py
index 00aa52b..37f9917 100644
--- a/Amadeus/evaluation_utils.py
+++ b/Amadeus/evaluation_utils.py
@@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
 
   return last_ckpt_fn, config_path, metadata_path, vocab_path
 
-def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
+def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
+  # if config is a path, load it
+  if isinstance(config, (str, Path)):
+    from omegaconf import OmegaConf
+    config = OmegaConf.load(config)
+    config = wandb_style_config_to_omega_config(config)
+  nn_params = config.nn_params
+  for_evaluation = True
+  if condition_dataset is not None:
+    print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
+    config.dataset = condition_dataset
+    for_evaluation = False
   dataset_name = config.dataset
   vocab_path = Path(vocab_path)
@@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
       input_length=config.train_params.input_length,
       first_pred_feature=config.data_params.first_pred_feature,
       caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
-      for_evaluation=True,
+      for_evaluation=for_evaluation
   )
   vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
@@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
   split_ratio = config.data_params.split_ratio
   # test_set = []
   train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
-
   # get proper prediction order according to the encoding scheme and target feature in the config
   prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
@@ -480,6 +490,28 @@ class Evaluator:
     prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
     decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
 
+  def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0, generation_length=3072, attr_list=None):
+    encoding_scheme = self.config.nn_params.encoding_scheme
+
+    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
+    try:
+      in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
+    except KeyError:
+      in_beat_resolution = 4  # Default resolution if dataset is not found
+
+    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
+    decoder_name = midi_decoder_dict[encoding_scheme]
+    decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
+
+    tuneidx = tuneidx.cuda()
+    generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
+    if encoding_scheme == 'nb':
+      generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
+    decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
+
+    prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
+    decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
+
   def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
     encoding_scheme = self.config.nn_params.encoding_scheme
diff --git a/Amadeus/model_zoo.py b/Amadeus/model_zoo.py
index 492c8c9..7b2788d 100644
--- a/Amadeus/model_zoo.py
+++ b/Amadeus/model_zoo.py
@@ -102,7 +102,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
     '''
     super().__init__()
     self.net = net
-
+    # self.prediction_order = net.prediction_order
+    # self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
+    self.attribute2idx_after = {'pitch': 0,
+                                'duration': 1,
+                                'velocity': 2,
+                                'type': 3,
+                                'beat': 4,
+                                'chord': 5,
+                                'tempo': 6,
+                                'instrument': 7}
+    self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
 
   def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
     return self.net(input_seq, target, context=context)
@@ -164,7 +174,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
     total_out = torch.LongTensor(total_out).unsqueeze(0).to(self.net.device)
     return total_out
 
-  def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None):
+  def _run_one_step(self, input_seq, cache=None, sampling_method=None, threshold=None, temperature=1, bos_hidden_vec=None,context=None,condition_step=None):
     '''
     Runs one step of autoregressive generation by taking the input sequence, embedding it,
     passing it through the main decoder, and generating logits and a sampled token.
@@ -192,7 +202,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
     input_dict = {'hidden_vec': hidden_vec, 'input_seq': input_seq, 'target': None, 'bos_token_hidden': bos_hidden_vec}
     # Generate the next token
-    logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature)
+    logits, sampled_token = self.net.sub_decoder(input_dict, sampling_method, threshold, temperature, condition_step=condition_step)
     return logits, sampled_token, intermidiates, hidden_vec
 
   def _update_total_out(self, total_out, sampled_token):
@@ -225,7 +235,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
     return total_out, sampled_token
 
   @torch.inference_mode()
-  def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None):
+  def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1, batch_size=1, context=None, attr_list=None):
     '''
     Autoregressively generates a sequence of tokens by repeatedly sampling the next token
     until the desired maximum sequence length is reached or the end token is encountered.
@@ -243,15 +253,19 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
     - total_out: The generated sequence of tokens as a tensor.
     '''
     # Prepare the starting sequence for inference
-    total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
-
-    # If a condition is provided, run one initial step
-    if condition is not None:
-      _, _, cache = self._run_one_step(total_out[:, -self.net.input_length:], cache=LayerIntermediates(), sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
+    if attr_list is None:
+      total_out = self._prepare_inference(self.net.start_token, manual_seed, condition, num_target_measures)
     else:
-      cache = LayerIntermediates()
-
-    # Continue generating tokens until the maximum sequence length is reached
+      total_out = self._prepare_inference(self.net.start_token, manual_seed, None, num_target_measures)
+      # for attribute-controlled generation, only keep the specified attributes in condition, others set to 126336
+      condition_filtered = condition.clone().unsqueeze(0)
+      # print(self.attribute2idx)
+      for attr, idx in self.attribute2idx.items():
+        if attr not in attr_list:
+          condition_filtered[:, :, idx] = 126336
+      # rearrange condition_filtered to match prediction order
+
+    cache = LayerIntermediates()
     pbar = tqdm(total=max_seq_len, desc="Generating tokens", unit="token")
     bos_hidden_vec = None
     hidden_vec_list = []
@@ -261,7 +275,21 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
       input_tensor = total_out[:, -self.net.input_length:]
       # Generate the next token and update the cache
       time_start = time.time()
-      _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
+      # if attr_list is not None, get one token in condition_filtered each time step
+      if attr_list is not None:
+        condition_filtered = condition_filtered.to(self.net.device)
+        # print(condition_filtered[:,:20,:])
+        # print(condition_filtered.shape)
+        condition_step = condition_filtered[:, total_out.shape[1]-1:total_out.shape[1], :]
+        # rearrange order (new_idx <- old_idx): 0<-5, 1<-6, 2<-7, 3<-0, 4<-1, 5<-2, 6<-3, 7<-4
+        condition_step_rearranged = torch.zeros_like(condition_step)
+        for attr, idx in self.attribute2idx.items():
+          new_idx = self.attribute2idx_after[attr]
+          condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
+        # print("condition_step shape:", condition_step.shape)
+        _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
+      else:
+        _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
       time_end = time.time()
       token_time_list.append(time_end - time_start)
       if bos_hidden_vec is None:
@@ -416,11 +444,11 @@ class AmadeusModel(nn.Module):
     return self.decoder(input_seq, target, context=context)
 
   @torch.inference_mode()
-  def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None):
+  def generate(self, manual_seed, max_seq_len, condition=None, num_target_measures=4, sampling_method=None, threshold=None, temperature=1,batch_size=1,context=None,attr_list=None):
     if batch_size == 1:
-      return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context)
+      return self.decoder.generate(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, context=context, attr_list=attr_list)
     else:
-      return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context)
+      return self.decoder.generate_batch(manual_seed, max_seq_len, condition, num_target_measures, sampling_method, threshold, temperature, batch_size, context=context, attr_list=attr_list)
 
 class AmadeusModel4Encodec(AmadeusModel):
   def __init__(
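For reference, the per-step filtering and re-indexing that generate() performs above reduces to the following standalone sketch (the two index dicts and the 126336 mask id are taken from this diff; shapes and values are toy assumptions):

    import torch

    MASK_ID = 126336
    attribute2idx = {'type': 0, 'beat': 1, 'chord': 2, 'tempo': 3, 'instrument': 4,
                     'pitch': 5, 'duration': 6, 'velocity': 7}
    attribute2idx_after = {'pitch': 0, 'duration': 1, 'velocity': 2, 'type': 3,
                           'beat': 4, 'chord': 5, 'tempo': 6, 'instrument': 7}

    condition = torch.randint(0, 100, (1, 16, 8))  # B x T x num_sub_tokens (toy)
    attr_list = ['beat', 'duration']

    # keep only the controlled attributes; mask everything else
    filtered = condition.clone()
    for attr, idx in attribute2idx.items():
        if attr not in attr_list:
            filtered[:, :, idx] = MASK_ID

    # permute one time step from storage order into prediction order
    step = filtered[:, 0:1, :]
    rearranged = torch.zeros_like(step)
    for attr, idx in attribute2idx.items():
        rearranged[:, :, attribute2idx_after[attr]] = step[:, :, idx]

The second loop is the reason attribute2idx_after exists: the condition tensor stores sub-tokens in type/beat/.../velocity order, while the sub-decoder consumes them in prediction order.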
diff --git a/Amadeus/sampling_utils.py b/Amadeus/sampling_utils.py
index c5742ca..d64d392 100644
--- a/Amadeus/sampling_utils.py
+++ b/Amadeus/sampling_utils.py
@@ -43,6 +43,22 @@ def typical_sampling(logits, thres=0.99):
     scores = logits.masked_fill(indices_to_remove, float("-inf"))
     return scores
 
+def min_p_sampling(logits, alpha=0.05):
+    """
+    logits: Tensor of shape [B, L, V]
+    alpha: float, relative probability threshold (e.g., 0.05)
+    """
+    # compute softmax probabilities
+    probs = F.softmax(logits, dim=-1)
+
+    # find the maximum probability at each position
+    max_probs, _ = probs.max(dim=-1, keepdim=True)  # [B, L, 1]
+
+    # keep only tokens with probability >= alpha * max_prob
+    mask = probs < (alpha * max_probs)  # True marks tokens to mask out
+    masked_logits = logits.masked_fill(mask, float('-inf'))
+    return masked_logits
+
 def add_gumbel_noise(logits, temperature):
     '''
     The Gumbel max is a method for sampling categorical distributions.
@@ -91,6 +107,8 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
         modified_logits = typical_sampling(logits, thres=threshold)
     elif sampling_method == "eta":
         modified_logits = eta_sampling(logits, epsilon=threshold)
+    elif sampling_method == "min_p":
+        modified_logits = min_p_sampling(logits, alpha=threshold)
     else:
         modified_logits = logits  # otherwise use the raw logits
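As a sanity check, min_p_sampling keeps only tokens whose probability is at least alpha times the modal probability; a minimal standalone sketch (condensed from the function above):

    import torch
    import torch.nn.functional as F

    def min_p_sampling(logits, alpha=0.05):
        probs = F.softmax(logits, dim=-1)
        max_probs, _ = probs.max(dim=-1, keepdim=True)
        return logits.masked_fill(probs < alpha * max_probs, float('-inf'))

    logits = torch.tensor([[[2.0, 1.9, 0.0, -3.0]]])   # B=1, L=1, V=4
    filtered = min_p_sampling(logits, alpha=0.5)       # last two entries -> -inf
    next_token = torch.multinomial(F.softmax(filtered, dim=-1).view(-1), 1)

With alpha=0.5 the two low-probability tokens fall below half the modal probability and are masked before the multinomial draw; unlike top-p, the kept set adapts to how peaked the distribution is.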
diff --git a/Amadeus/sub_decoder_zoo.py b/Amadeus/sub_decoder_zoo.py
index e0994d6..9179211 100644
--- a/Amadeus/sub_decoder_zoo.py
+++ b/Amadeus/sub_decoder_zoo.py
@@ -1,3 +1,4 @@
+from re import T
 from selectors import EpollSelector
 from turtle import st
 from numpy import indices
@@ -6,7 +7,7 @@
 import torch
 import torch.profiler
 import torch.nn as nn
 
-from x_transformers import Decoder
+from .custom_x_transformers import Decoder
 from .transformer_utils import MultiEmbedding, RVQMultiEmbedding
 from .sub_decoder_utils import *
@@ -146,7 +147,7 @@ class FeedForward(SubDecoderClass):
       f"layer_{key}": nn.Linear(dim+dim, dim) for key, _ in vocab_sizes.items()
     })
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec']
     target = input_dict['target']
@@ -204,7 +205,7 @@ class Parallel(SubDecoderClass):
     '''
     super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec']
     target = input_dict['target']
@@ -414,7 +415,7 @@ class SelfAttention(SubDecoderClass):
     memory_tensor = torch.cat(input_seq_list, dim=1) # (B*T) x (window_size + BOS + num_sub_tokens-1) x d_model
     return memory_tensor
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec'] # B x T x d_model
     target = input_dict['target'] # B x T x num_sub_tokens
@@ -490,7 +491,7 @@ class SelfAttentionUniAudio(SelfAttention):
     memory_tensor = hidden_vec_reshape + feature_tensor
     return memory_tensor
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec'] # B x T x d_model
     target = input_dict['target'] # B x T x num_sub-tokens
@@ -604,7 +605,7 @@ class CrossAttention(SubDecoderClass):
       memory_list.append(BOS_emb[-1:, :, :])
     return memory_list
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec'] # B x T x d_model
     target = input_dict['target']
@@ -677,7 +678,7 @@ class Flatten4Encodec(SubDecoderClass):
   ):
     super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
 
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, condition_step=None):
     hidden_vec = input_dict['hidden_vec']
 
     # ---- Training ---- #
@@ -838,7 +839,7 @@ class DiffusionDecoder(SubDecoderClass):
     This function is designed to precompute the number of tokens that need to be transitioned at each step.
     '''
-    mask_num = mask_index.sum(dim=1, keepdim=True)
+    mask_num = mask_index.sum(dim=1,keepdim=True)
 
     base = mask_num // steps
     remainder = mask_num % steps
@@ -941,94 +942,7 @@ class DiffusionDecoder(SubDecoderClass):
       indices = torch.tensor([[step]], device=hidden_vec.device)
     return indices
-
-  def forward_(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
-    logits_dict = {}
-    hidden_vec = input_dict['hidden_vec'] # B x T x d_model
-    target = input_dict['target'] #B x T x d_model
-
-
-    # apply window on hidden_vec for enricher
-    if self.sub_decoder_enricher_use:
-      window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
-    hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
-    input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
-    input_seq_pos = input_seq
-    # input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
-    # prepare memory
-    memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
-    # ---- Generate(Inference) ---- #
-    if target is None:
-      sampled_token_dict = {}
-      b,t,d = hidden_vec.shape # B x T x d_model
-      l = len(self.prediction_order) # num_sub_tokens
-      memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
-      all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
-
-      # indicate the position of the mask token,1 means that the token hsa been masked
-      masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
-      num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
-      # denoising c
-      stored_logits_dict = {}
-      stored_probs_dict = {}
-      for step in range(self.denoising_steps):
-        # nomalize the memory tensor
-        # memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
-        if self.sub_decoder_enricher_use:
-          input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-          input_dict = self.feature_enricher_layers(input_dict)
-          memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-        input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-        # input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-        input_dict = self.sub_decoder_layers(input_dict)
-        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-        candidate_token_probs = {}
-        sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
-
-        # set prob of the changed tokens to -inf
-        stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-        # indices = self.choose_tokens(hidden_vec,step, "auto-regressive", stacked_logits_probs, num_transfer_tokens)
-        indices = self.choose_tokens(hidden_vec, step, self.method, stacked_logits_probs, num_transfer_tokens)
-        # breakpoint()
-        # undate the masked history
-        for i in range(b*t):
-          for j in range(l):
-            if j in indices[i]:
-              masked_history[i][j] = False
-              stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
-              stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
-        expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
-        memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
-        # breakpoint()
-      # print("stored_probs_dict", stored_probs_dict)
-      # print("sampled_token_dict", sampled_token_dict)
-      return stored_logits_dict, sampled_token_dict
-
-    # ---- Training ---- #
-    _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
-    memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
-    # apply layer norm
-
-    extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
-    if worst_case: # mask all ,turn into parallel
-      extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
-    memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
-    if self.sub_decoder_enricher_use:
-      input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-      input_dict = self.feature_enricher_layers(input_dict)
-      memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-    input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-    input_dict = self.sub_decoder_layers(input_dict)
-    attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-    # get prob
-    for idx, feature in enumerate(self.prediction_order):
-      feature_pos = self.feature_order_in_output[feature]
-      logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-      logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
-      logits_dict[feature] = logit
-    return logits_dict, (masked_indices, p_mask)
-
-  def forward_old(self, input_dict, sampling_method=None, threshold=None, temperature=None, worst_case=False, validation=False):
+  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
     logits_dict = {}
     hidden_vec = input_dict['hidden_vec'] # B x T x d_model
     target = input_dict['target'] #B x T x d_model
@@ -1070,139 +984,25 @@ class DiffusionDecoder(SubDecoderClass):
       # indicate the position of the mask token; 1 means the token has been masked
       masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
-      num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
-      # denoising c
+      # add attribute control here
       stored_logits_dict = {}
-      stored_probs_dict = {}
-      for step in range(self.denoising_steps):
-        memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
-        # nomalize the memory tensor
-        # memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
-        if self.sub_decoder_enricher_use:
-          input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-          input_dict = self.feature_enricher_layers(input_dict)
-          memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-        # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-        input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-        input_dict = self.sub_decoder_layers(input_dict)
-        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-        candidate_token_probs = {}
-        candidate_token_embeddings = {}
-        for idx, feature in enumerate(self.prediction_order):
-          feature_pos = self.feature_order_in_output[feature]
-          logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-          logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
-          logits_dict[feature] = logit
-          sampled_token,probs = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
-          # print(idx,feature,sampled_token,probs)
-          sampled_token_dict[feature] = sampled_token
-          candidate_token_probs[feature] = probs
-          feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
-          feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
-          candidate_token_embeddings[feature] = feature_emb_reshape
-
-        stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, l)) # (B*T) x num_sub_tokens x vocab_size
-        stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, l, d))
-
-        # set prob of the changed tokens to -inf
-        stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-
-        if self.method == 'low-confidence':
-          _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
-        elif self.method == 'random':
-          indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
-        elif self.method == 'auto-regressive':
-          indices = torch.tensor([[step]], device=logit.device)
-        # undate the masked history
+      stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
+      if condition_step is not None:
+        # print("shape of condition_step", condition_step.shape)
+        condition_step = condition_step.reshape((b*t, l))
        for i in range(b*t):
          for j in range(l):
-            if j in indices[i]:
-              masked_history[i][j] = False
-              stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
-              stored_probs_dict[self.prediction_order[j]] = candidate_token_probs[self.prediction_order[j]].clone()
-        expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
-        memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
-      return stored_logits_dict, sampled_token_dict
-
-    # ---- Training ---- #
-    _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
-    memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
-    # apply layer norm
-
-    extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
-    if worst_case: # mask all ,turn into parallel
-      extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
-    memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
-    memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
-    # all is embedding
-    # memory_tensor = self.layer_norm(memory_tensor)
-    # apply feature enricher to memory
-    if self.sub_decoder_enricher_use:
-      input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
-      input_dict = self.feature_enricher_layers(input_dict)
-      memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-    # implement sub decoder cross attention
-    # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-    # inter_input = torch.cat([input_seq_pos, memory_tensor], dim=1)
-    # inter_input = input_seq_pos + memory_tensor # (B*T) x num_sub_tokens x d_model
-    # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
-    input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
-    input_dict = self.sub_decoder_layers(input_dict)
-    attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-    # get prob
-    for idx, feature in enumerate(self.prediction_order):
-      feature_pos = self.feature_order_in_output[feature]
-      logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
-      logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
-      logits_dict[feature] = logit
-    return logits_dict, (masked_indices, p_mask)
-
-  def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, validation=False):
-    logits_dict = {}
-    hidden_vec = input_dict['hidden_vec'] # B x T x d_model
-    target = input_dict['target'] #B x T x d_model
-    bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
+            token = condition_step[i][j]
+            if condition_step[i][j] != self.MASK_idx:
-    # apply window on hidden_vec for enricher
-    if self.sub_decoder_enricher_use:
-      window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
-    hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
-    input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
-    input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
-
-    if bos_hidden_vec is None: # start of generation
-      if target is None:
-        bos_hidden_vec = input_seq_pos
-      else:
-        bos_hidden_vec =hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
-        bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
-        bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
-
-    else:
-      bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
-
-    # input_seq_pos = input_seq
-    input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
-    boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
-    input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-    # input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
-    # input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
-    # prepare memory
-    memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
-    # ---- Generate(Inference) ---- #
-    if target is None:
-      sampled_token_dict = {}
-      b,t,d = hidden_vec.shape # B x T x d_model
-      l = len(self.prediction_order) # num_sub_tokens
-      memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
-      all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
+              # print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
+              masked_history[i][j] = False
+              memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
+              stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
+              # print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
-      # indicate the position of the mask token,1 means that the token hsa been masked
-      masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
      num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
      # denoising loop
-      stored_logits_dict = {}
-      stored_probs_dict = {}
      # with torch.profiler.profile(
      #   activities=[
      #     torch.profiler.ProfilerActivity.CPU,
      #     torch.profiler.ProfilerActivity.CUDA,
      #   ],
      #   record_shapes=True,
      # ) as prof:
@@ -1213,8 +1013,6 @@ class DiffusionDecoder(SubDecoderClass):
      for step in range(self.denoising_steps):
        memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
-        # nomalize the memory tensor
-        # memory_tensor = self.layer_norm(memory_tensor) # (B*T) x num_sub_tokens x d_model
        if self.sub_decoder_enricher_use:
          input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
          input_dict = self.feature_enricher_layers(input_dict)
@@ -1223,14 +1021,15 @@ class DiffusionDecoder(SubDecoderClass):
        input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
        input_dict = self.sub_decoder_layers(input_dict)
        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
-        candidate_token_probs = {}
-        sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature, force_decode=Force_decode, step=step)
+        # print("step", step)
+        # print("tokens", sampled_token_dict)
 
        # set prob of the changed tokens to -inf
        stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
+        # print("stacked_logits_probs", stacked_logits_probs)
 
        if self.method == 'low-confidence':
          _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
@@ -1242,12 +1041,25 @@ class DiffusionDecoder(SubDecoderClass):
        for i in range(b*t):
          for j in range(l):
            if j in indices[i]:
+              # print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
              masked_history[i][j] = False
              stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
+              stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
        expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
-        memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stacked_token_embeddings)
+        memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
+        # skip remaining steps once all tokens are unmasked
+        if not expand_masked_history.any():
+          # print("All tokens have been unmasked. Ending denoising process.")
+          break
      # print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
+      # recover the final sampled tokens by mapping the unmasked embeddings back to token ids
+      sampled_token_dict = {}
+      for idx, feature in enumerate(self.prediction_order):
+        sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
+        sampled_token_dict[feature] = sampled_token
+      # print("Final sampled tokens:")
      # print(sampled_token_dict)
+      # print(condition_step)
      return stored_logits_dict, sampled_token_dict
 
    # ---- Training ---- #
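Conceptually, the conditioning added to DiffusionDecoder.forward seeds the denoising loop: any sub-token whose value is already given (anything other than the mask id) is marked as decoded and its embedding is written into the memory tensor before the first step. A vectorized standalone sketch with toy shapes and a stand-in embedding table (the real code loops over positions and uses self.emb_layer.get_emb_by_key per feature):

    import torch

    B_T, L, D, MASK_ID = 2, 8, 16, 126336
    emb = torch.nn.Embedding(128, D)                       # stand-in per-feature table
    condition_step = torch.full((B_T, L), MASK_ID)         # everything masked by default
    condition_step[:, 4] = 17                              # e.g. one attribute is given
    masked_history = torch.ones(B_T, L, dtype=torch.bool)  # True = still masked
    memory_tensor = torch.randn(B_T, L, D)                 # noise initialization

    given = condition_step != MASK_ID
    masked_history[given] = False                          # skip denoising these slots
    memory_tensor[given] = emb(condition_step[given])      # seed with known embeddings

Because seeded slots are excluded from masked_history, _get_num_transfer_tokens budgets the denoising steps only over the sub-tokens that actually remain to be predicted.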
diff --git a/Amadeus/symbolic_encoding/data_utils.py b/Amadeus/symbolic_encoding/data_utils.py
index 23e4583..26231f2 100644
--- a/Amadeus/symbolic_encoding/data_utils.py
+++ b/Amadeus/symbolic_encoding/data_utils.py
@@ -510,6 +510,7 @@ class Melody(SymbolicMusicDataset):
     shuffled_tune_names = list(self.tune_in_idx.keys())
     song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
     song_dict = {}
+    ratio = 0.8
     for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
       if song not in song_dict:
         song_dict[song] = []
diff --git a/Amadeus/symbolic_yamls/config-accelerate.yaml b/Amadeus/symbolic_yamls/config-accelerate.yaml
index 5dc78b8..d472105 100644
--- a/Amadeus/symbolic_yamls/config-accelerate.yaml
+++ b/Amadeus/symbolic_yamls/config-accelerate.yaml
@@ -2,8 +2,8 @@ defaults:
   # - nn_params: nb8_embSum_NMT
   # - nn_params: remi8
   # - nn_params: nb8_embSum_diff_t2m_150M_finetunning
-  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
-  - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
+  - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
+  # - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
   # - nn_params: nb8_embSum_subPararell
   # - nn_params: nb8_embSum_diff_t2m_150M
@@ -15,7 +15,7 @@ defaults:
   # - nn_params: remi8_main12_head_16_dim512
   # - nn_params: nb5_embSum_diff_main12head16dim768_sub3
 
-dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
+dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
 captions_path: dataset/midicaps/train_set.json
 # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -44,7 +44,7 @@ train_params:
   focal_gamma: 0
   # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
   scheduler : cosinelr
-  initial_lr: 0.0004
+  initial_lr: 0.0003
   decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
   num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
   warmup_steps: 2000 # number of warmup steps
@@ -59,7 +59,7 @@ inference_params:
 data_params:
   first_pred_feature: pitch # compound shifting for NB only, choose the target sub-token (remi and cp are not influenced by this argument)
   split_ratio: 0.998 # train-validation-test split ratio
-  aug_type: pitch # random, null | pitch and chord augmentation type
+  aug_type: null # random, null | pitch and chord augmentation type
 general:
   debug: False
   make_log: True # True, False | update the log file in wandb online to your designated project and entity
diff --git a/Amadeus/trainer_accelerate.py b/Amadeus/trainer_accelerate.py
index 3af0b6b..2e4224a 100644
--- a/Amadeus/trainer_accelerate.py
+++ b/Amadeus/trainer_accelerate.py
@@ -74,7 +74,8 @@ class LanguageModelTrainer:
     sampling_threshold: float, # Threshold for sampling decisions
     sampling_temperature: float, # Temperature for controlling sampling randomness
     config, # Configuration parameters (contains general, training, and inference settings)
-    model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
+    model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
+    # model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
   ):
     # Save model, optimizer, and other configurations
     self.model = model
diff --git a/Amadeus/transformer_utils.py b/Amadeus/transformer_utils.py
index f47ee17..8b970c3 100644
--- a/Amadeus/transformer_utils.py
+++ b/Amadeus/transformer_utils.py
@@ -111,6 +111,23 @@ class MultiEmbedding(nn.Module):
   def get_emb_by_key(self, key, token):
     layer_idx = self.feature_list.index(key)
     return self.layers[layer_idx](token)
+
+  def get_token_by_emb(self, key, token_emb):
+    '''
+    token_emb: B x emb_size
+    '''
+    layer_idx = self.feature_list.index(key)
+    embedding_layer = self.layers[layer_idx] # nn.Embedding
+    # compute cosine similarity between token_emb and embedding weights
+    emb_weights = embedding_layer.weight # vocab_size x emb_size
+    cos_sim = torch.nn.functional.cosine_similarity(
+      token_emb.unsqueeze(1),   # B x 1 x emb_size
+      emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
+      dim=-1
+    ) # B x vocab_size
+    # get the index of the most similar embedding
+    token_idx = torch.argmax(cos_sim, dim=-1) # B
+    return token_idx
 
 class SummationEmbedder(MultiEmbedding):
   def __init__(
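The new get_token_by_emb inverts an embedding by nearest-neighbor search in cosine space; a standalone sketch of the same computation:

    import torch
    import torch.nn.functional as F

    emb = torch.nn.Embedding(100, 32)          # vocab_size x emb_size
    token_emb = emb(torch.tensor([5, 42]))     # B x emb_size, pretend decoder output
    cos = F.cosine_similarity(
        token_emb.unsqueeze(1),                # B x 1 x emb_size
        emb.weight.unsqueeze(0),               # 1 x vocab_size x emb_size
        dim=-1)                                # B x vocab_size
    recovered = cos.argmax(dim=-1)             # tensor([5, 42])

Exact embeddings recover their indices exactly; for the denoised memory vectors produced in DiffusionDecoder.forward, this amounts to snapping each output embedding to its closest vocabulary entry.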
diff --git a/data_representation/step1_midi2corpus_fined.py b/data_representation/step1_midi2corpus_fined.py
index 49c8f25..713caaf 100644
--- a/data_representation/step1_midi2corpus_fined.py
+++ b/data_representation/step1_midi2corpus_fined.py
@@ -168,7 +168,7 @@ class CorpusMaker():
     print("length of midi list: ", len(self.midi_list))
     # Use set for faster lookup (O(1) per check)
     processed_files_set = set(processed_files)
-    # self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
+    self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
     # reverse the list to process the latest files first
     self.midi_list.reverse()
     print(f"length of midi list after filtering: ", len(self.midi_list))
diff --git a/data_representation/step2_corpus2event.py b/data_representation/step2_corpus2event.py
index 74b2b55..6322516 100644
--- a/data_representation/step2_corpus2event.py
+++ b/data_representation/step2_corpus2event.py
@@ -61,7 +61,7 @@ class Corpus2Event():
     # remove the corpus files that are already in the out_dir
     # Use set for faster existence checks
     existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
-    # corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
+    corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
     for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
       if event is None:
         broken_count += 1
diff --git a/generate-batch.py b/generate-batch.py
index f371b3f..607f547 100644
--- a/generate-batch.py
+++ b/generate-batch.py
@@ -1,3 +1,4 @@
+from ast import arg
 import sys
 import os
 from pathlib import Path
@@ -25,14 +26,25 @@ def get_argument_parser():
   parser.add_argument(
     "-generation_type",
     type=str,
-    choices=('conditioned', 'unconditioned', 'text-conditioned'),
+    choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
     default='unconditioned',
     help="generation type",
   )
+  parser.add_argument(
+    "-attr_list",
+    type=str,
+    default="beat,duration",
+    help="attribute list for attribute-controlled generation",
+  )
+  parser.add_argument(
+    "-dataset",
+    type=str,
+    help="dataset name, only for conditioned generation",
+  )
   parser.add_argument(
     "-sampling_method",
     type=str,
-    choices=('top_p', 'top_k'),
+    choices=('top_p', 'top_k', 'min_p'),
     default='top_p',
     help="sampling method",
   )
@@ -74,7 +86,7 @@ def get_argument_parser():
   parser.add_argument(
     "-num_processes",
     type=int,
-    default=4,
+    default=1,
    help="number of processes to use",
   )
   parser.add_argument(
@@ -97,7 +109,7 @@ def get_argument_parser():
   )
   return parser
 
-def load_resources(wandb_exp_dir, device):
+def load_resources(wandb_exp_dir, condition_dataset, device):
   """Load model and dataset resources for a process"""
   wandb_dir = Path('wandb')
   ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
@@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, device):
   # Load checkpoint to specified device
   print("Loading checkpoint from:", ckpt_path)
   ckpt = torch.load(ckpt_path, map_location=device)
-  model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
+  print(config)
+  model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
   model.load_state_dict(ckpt['model'], strict=False)
   model.to(device)
   model.eval()
@@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, device):
   return config, model, dataset_for_prompt, vocab
 
-def conditioned_worker(process_idx, gpu_id, args, data_slice):
+def conditioned_worker(process_idx, gpu_id, args):
   """Worker process for conditioned generation"""
   torch.cuda.set_device(gpu_id)
   device = torch.device(f'cuda:{gpu_id}')
 
   # Load resources with proper device
-  config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+  config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
+  # print(test_set)
+  if args.choose_selected_tunes and test_set.dataset == 'SOD':
+    selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
+                      "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
+  else:
+    selected_tunes = [name for _, name in test_set][:args.num_samples]
+
+  # Split selected data across processes
+  selected_data = [d for d in test_set if d[1] in selected_tunes]
+  chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
+  start_idx = 1
+  end_idx = min(chunk_size, len(selected_data))
+  data_slice = selected_data[start_idx:end_idx]
 
   # Create output directory with process index
   base_path = Path('wandb') / args.wandb_exp_dir / \
       f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
   base_path.mkdir(parents=True, exist_ok=True)
 
-  evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
+  evaluator = Evaluator(config, model, data_slice, vocab, device=device)
 
   # Process assigned data slice
   for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
@@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
       generation_length=args.generate_length
     )
 
+def attr_conditioned_worker(process_idx, gpu_id, args):
+  """Worker process for attribute-conditioned generation"""
+  torch.cuda.set_device(gpu_id)
+  device = torch.device(f'cuda:{gpu_id}')
+  # attr_list = "position,duration"
+  attr_list = args.attr_list.split(',')
+
+  # Load resources with proper device
+  config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
+  # print(test_set)
+  if args.choose_selected_tunes and test_set.dataset == 'SOD':
+    selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
+                      "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
+  else:
+    selected_tunes = [name for _, name in test_set][:args.num_samples]
+
+  # Split selected data across processes
+  selected_data = [d for d in test_set if d[1] in selected_tunes]
+  # chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
+  # start_idx = 1
+  # end_idx = min(chunk_size, len(selected_data))
+  # data_slice = selected_data[start_idx:end_idx]
+  data_slice = selected_data
+
+  # Create output directory with process index
+  base_path = Path('wandb') / args.wandb_exp_dir / \
+      f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
+  base_path.mkdir(parents=True, exist_ok=True)
+
+  evaluator = Evaluator(config, model, data_slice, vocab, device=device)
+
+  # Process assigned data slice
+  for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
+    batch_dir = base_path
+    batch_dir.mkdir(parents=True, exist_ok=True)
+    evaluator.generate_samples_with_attrCtl(
+      batch_dir,
+      args.num_target_measure,
+      tune_in_idx,
+      tune_name,
+      config.data_params.first_pred_feature,
+      args.sampling_method,
+      args.threshold,
+      args.temperature,
+      generation_length=args.generate_length,
+      attr_list=attr_list
+    )
+
+
 def unconditioned_worker(process_idx, gpu_id, args, num_samples):
   """Worker process for unconditioned generation"""
   torch.cuda.set_device(gpu_id)
   device = torch.device(f'cuda:{gpu_id}')
 
   # Load resources with proper device
-  config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+  config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
 
   # Create output directory with process index
   base_path = Path('wandb') / args.wandb_exp_dir / \
@@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
   device = torch.device(f'cuda:{gpu_id}')
 
   # Load resources with proper device
-  config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
+  config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
 
   # Create output directory with process index
   base_path = Path('wandb') / args.wandb_exp_dir / \
@@ -237,40 +312,33 @@ def main():
     if not wandb_dir.exists():
       raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
 
-    # Load test set to get selected tunes (dummy load to get dataset info)
-    dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    _, test_set, _ = prepare_model_and_dataset_from_config(
-      wandb_dir / "files" / "config.yaml",
-      wandb_dir / "files" / "metadata.json",
-      wandb_dir / "files" / "vocab.json"
-    )
-
-    if args.choose_selected_tunes and test_set.dataset == 'SOD':
-      selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
-                        "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
-    else:
-      selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
-
-    # Split selected data across processes
-    selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
-    chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
     for i in range(args.num_processes):
-      start_idx = i * chunk_size
-      end_idx = min((i+1)*chunk_size, len(selected_data))
-      data_slice = selected_data[start_idx:end_idx]
-
-      if not data_slice:
-        continue
       gpu_id = gpu_ids[i % len(gpu_ids)]
       p = Process(
         target=conditioned_worker,
-        args=(i, gpu_id, args, data_slice)
+        args=(i, gpu_id, args)
       )
       processes.append(p)
       p.start()
-
+  elif args.generation_type == 'attr-conditioned':
+    # Prepare selected tunes
+    wandb_dir = Path('wandb') / args.wandb_exp_dir
+    if not wandb_dir.exists():
+      raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
+
+
+    for i in range(args.num_processes):
+
+      gpu_id = gpu_ids[i % len(gpu_ids)]
+      p = Process(
+        target=attr_conditioned_worker,
+        args=(i, gpu_id, args)
+      )
+      processes.append(p)
+      p.start()
+
   elif args.generation_type == 'unconditioned':
     samples_per_proc = args.num_samples // args.num_processes
     remainder = args.num_samples % args.num_processes
diff --git a/len_tunes/Melody/len_nb8.png b/len_tunes/Melody/len_nb8.png
new file mode 100644
index 0000000..01ef6ae
Binary files /dev/null and b/len_tunes/Melody/len_nb8.png differ
diff --git a/len_tunes/msmidi/len_nb8.png b/len_tunes/msmidi/len_nb8.png
new file mode 100644
index 0000000..1e58ef8
Binary files /dev/null and b/len_tunes/msmidi/len_nb8.png differ
diff --git a/midi_sim.py b/midi_sim.py
new file mode 100644
index 0000000..44c5ba5
--- /dev/null
+++ b/midi_sim.py
@@ -0,0 +1,134 @@
+import os
+from math import ceil
+#CUDA_VISIBLE_DEVICES= "0"
+import numpy as np
+import pandas as pd
+from symusic import Score
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from tqdm import tqdm
+semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
+
+def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (0., 1.5)):
+    if(not a.shape[1] or not b.shape[1]):
+        return np.inf
+    a_onset, a_pitch = a
+    b_onset, b_pitch = b
+    a_onset = a_onset.astype(np.float32)
+    b_onset = b_onset.astype(np.float32)
+    a_pitch = a_pitch.astype(np.int16)
+    b_pitch = b_pitch.astype(np.int16)
+
+    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
+    a2b_idx = onset_dist_matrix.argmin(1)
+    b2a_idx = onset_dist_matrix.argmin(0)
+
+    a_pitch -= (np.median(a_pitch) - np.median(b_pitch)).astype(np.int16)  # Normalize pitch
+    a_pitch = a_pitch + np.arange(-7, 7).reshape(-1, 1)  # Transpose invariant
+
+    interval_diff = np.concatenate([
+        a_pitch[:, a2b_idx] - b_pitch,
+        b_pitch[b2a_idx] - a_pitch], axis=1)
+    pitch_dist = np.abs(semitone2degree[interval_diff % 12] + np.abs(interval_diff) // 12 * np.sign(interval_diff)).mean(1).min()
+    onset_dist = np.abs(np.concatenate([
+        a_onset[a2b_idx] - b_onset,
+        b_onset[b2a_idx] - a_onset], axis=0)).mean()
+
+    return (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)
+
+
+def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 8., hop_size: float = 4.):
+    x = sorted(x)
+    trim_offset = (x[0][0] // hop_size) * hop_size
+    end_time = x[-1][0]
+    num_segment = ceil((end_time - window_size - trim_offset) / hop_size) + 1
+
+    time_matrix = (np.fromiter((time for time, _ in x), dtype=float) - trim_offset).reshape(1, -1).repeat(num_segment, axis=0)
+    seg_time_starts = np.arange(num_segment).reshape(-1, 1) * hop_size
+
+    time_compare_matrix = np.where((time_matrix >= seg_time_starts) & (time_matrix <= seg_time_starts + window_size), 0, 1)
+    time_compare_matrix = np.diff(np.pad(time_compare_matrix, ((0, 0), (1, 1)), constant_values=1))
+    start_idxs = sorted(np.where(time_compare_matrix == -1), key=lambda x: x[0])[1].tolist()
+    end_idxs = sorted(np.where(time_compare_matrix == 1), key=lambda x: x[0])[1].tolist()
+
+    segments = [x[start:end] for start, end in zip(start_idxs, end_idxs)]
+
+    return segments
+
+
+def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
+    a = midi_time_sliding_window(a, window_size=window_size, hop_size=hop_size)
+    b = midi_time_sliding_window(b, window_size=window_size, hop_size=hop_size)
+    dist = np.inf
+    for x, i in enumerate(a):
+        for y, j in enumerate(b):
+            cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
+            if cur_dist == 0:
+                print(x, y)
+            if(cur_dist < dist):
+                dist = cur_dist
+    return float(dist)
+
+
+def extract_notes(filepath: str):
+    """Read a MIDI file and return a list of (time, pitch) tuples."""
+    try:
+        s = Score(filepath).to("quarter")
+        notes = []
+        # for t in s.tracks:
+        #     notes.extend([(n.time, n.pitch) for n in t.notes])
+        notes = [(n.time, n.pitch) for n in s.tracks[0].notes]  # use only the first track
+        return notes
+    except Exception as e:
+        print(f"Error reading {filepath}: {e}")
+        return []
+
+def compare_pair(file_a: str, file_b: str):
+    try:
+        notes_a = extract_notes(file_a)
+        notes_b = extract_notes(file_b)
+        if not notes_a or not notes_b:
+            return (file_a, file_b, np.inf)
+        dist = midi_dist(notes_a, notes_b)
+        return (file_a, file_b, dist)
+    except Exception as e:
+        import traceback
+        print(f"compare_pair failed: {file_a} vs {file_b}")
+        traceback.print_exc()
+        return (file_a, file_b, np.inf)
+
+
+def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
+    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
+    files_a = files_a[:100]  # compare only the first 100 files to save time
+    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
+
+    results = []
+    pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
+        for fut in as_completed(futures):
+            pbar.update(1)
+            try:
+                results.append(fut.result())
+            except Exception as e:
+                print(f"Error comparing pair: {e}")
+            # print(f"Compared: {results[-1][0]} vs {results[-1][1]}, Distance: {results[-1][2]:.4f}")
+    # with tqdm(total=len(files_a) * len(files_b)) as pbar:
+    #     for fa in files_a:
+    #         for fb in files_b:
+    #             results.append(compare_pair(fa, fb))
+    #             pbar.update(1)
+    # sort by distance
+    results = sorted(results, key=lambda x: x[2])
+
+    # save results
+    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
+    df.to_csv(out_csv, index=False)
+    print(f"Saved results to {out_csv}")
+
+
+if __name__ == "__main__":
+    dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
+    dir_b = "dataset/Melody"
+    batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
\ No newline at end of file
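A quick intuition check for the pipeline above (assumes midi_sim.py is importable from the working directory; note that midi_dist yields no segments, and thus returns inf, for melodies shorter than the 16-beat window, hence the 20-beat toys):

    from midi_sim import midi_dist

    a = [(float(t), 60 + t % 5) for t in range(20)]   # 20 beats of a 5-note pattern
    b = [(float(t), 72 + t % 5) for t in range(20)]   # same pattern, one octave up
    print(midi_dist(a, b))  # expected ~0.0: the pitch term is transposition-invariant

The median normalization plus the -7..7 transposition search in hausdorff_dist absorbs the octave shift, so only genuinely different contours accumulate distance.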
diff --git a/,idi_sim.py b/,idi_sim.py
deleted file mode 100644
index dfe9488..0000000
--- a/,idi_sim.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import os
-import numpy as np
-import pandas as pd
-from symusic import Score
-from concurrent.futures import ProcessPoolExecutor, as_completed
-
-semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
-
-def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
-    if(not a.shape[1] or not b.shape[1]):
-        return np.inf
-    a_onset, a_pitch = a
-    b_onset, b_pitch = b
-    a_onset = a_onset.astype(np.float32)
-    b_onset = b_onset.astype(np.float32)
-    a_pitch = a_pitch.astype(np.uint8)
-    b_pitch = b_pitch.astype(np.uint8)
-
-    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
-    if(oti):
-        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
-        dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
-        a2b = dist_matrix.min(2)
-        b2a = dist_matrix.min(1)
-        dist = np.concatenate([a2b, b2a], axis=1)
-        return dist.sum(axis=1).min() / len(dist)
-    else:
-        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
-        dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
-        a2b = dist_matrix.min(1)
-        b2a = dist_matrix.min(0)
-        return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))
-
-
-def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
-    x = sorted(x)
-    end_time = x[-1][0]
-    out = [[] for _ in range(int(end_time // hop_size))]
-    for i in sorted(x):
-        segment = min(int(i[0] // hop_size), len(out) - 1)
-        while(i[0] >= segment * hop_size):
-            out[segment].append(i)
-            segment -= 1
-            if(segment < 0):
-                break
-    return out
-
-
-def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
-    a = midi_time_sliding_window(a)
-    b = midi_time_sliding_window(b)
-    dist = np.inf
-    for i in a:
-        for j in b:
-            cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
-            if(cur_dist < dist):
-                dist = cur_dist
-    return dist
-
-
-def extract_notes(filepath: str):
-    """Read a MIDI file and return a list of (time, pitch) tuples."""
-    try:
-        s = Score(filepath).to("quarter")
-        notes = []
-        for t in s.tracks:
-            notes.extend([(n.time, n.pitch) for n in t.notes])
-        return notes
-    except Exception as e:
-        print(f"Error reading {filepath}: {e}")
-        return []
-
-
-def compare_pair(file_a: str, file_b: str):
-    notes_a = extract_notes(file_a)
-    notes_b = extract_notes(file_b)
-    if not notes_a or not notes_b:
-        return (file_a, file_b, np.inf)
-    dist = midi_dist(notes_a, notes_b)
-    return (file_a, file_b, dist)
-
-
-def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
-    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
-    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
-
-    results = []
-    with ProcessPoolExecutor(max_workers=max_workers) as executor:
-        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
-        for fut in as_completed(futures):
-            results.append(fut.result())
-
-    # sort
-    results = sorted(results, key=lambda x: x[2])
-
-    # save
-    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
-    df.to_csv(out_csv, index=False)
-    print(f"Saved results to {out_csv}")
-
-
-if __name__ == "__main__":
-    dir_a = "folder_a"
-    dir_b = "folder_b"
-    batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)
\ No newline at end of file
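Putting the pieces together, one plausible end-to-end invocation of the new attribute-controlled path (the -generation_type, -attr_list, -dataset and min_p additions are from this diff; the remaining flags, e.g. -wandb_exp_dir, are assumed to exist in the parser already; <run_dir> is a placeholder):

    python generate-batch.py -generation_type attr-conditioned -wandb_exp_dir <run_dir> \
        -dataset msmidi -attr_list beat,duration -sampling_method min_p -threshold 0.05 -temperature 1.0

Note that -threshold doubles as min_p's alpha here, since sample_with_prob forwards it as min_p_sampling(logits, alpha=threshold).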