diff --git a/Amadeus/evaluation_utils.py b/Amadeus/evaluation_utils.py
index 37f9917..f5d7b13 100644
--- a/Amadeus/evaluation_utils.py
+++ b/Amadeus/evaluation_utils.py
@@ -96,7 +96,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
   num_features = config.nn_params.num_features

   # get vocab
-  vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
+  vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
   selected_vocab_name = vocab_name[encoding_scheme]

   vocab = getattr(vocab_utils, selected_vocab_name)(
@@ -477,7 +477,7 @@ class Evaluator:
     except KeyError:
       in_beat_resolution = 4 # Default resolution if dataset is not found

-    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
+    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
     decoder_name = midi_decoder_dict[encoding_scheme]
     decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@@ -499,7 +499,7 @@ class Evaluator:
     except KeyError:
       in_beat_resolution = 4 # Default resolution if dataset is not found

-    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
+    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
     decoder_name = midi_decoder_dict[encoding_scheme]
     decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
@@ -509,7 +509,7 @@ class Evaluator:
       generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
     decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
-    prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
+    prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=num_target_measures)
     decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))

   def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
@@ -520,7 +520,7 @@ class Evaluator:
       in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
     except KeyError:
       in_beat_resolution = 4 # Default resolution if dataset is not found
-    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
+    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
     decoder_name = midi_decoder_dict[encoding_scheme]
     decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
diff --git a/Amadeus/model_zoo.py b/Amadeus/model_zoo.py
index 7b2788d..3c50b5c 100644
--- a/Amadeus/model_zoo.py
+++ b/Amadeus/model_zoo.py
@@ -103,7 +103,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
         super().__init__()
         self.net = net
         # self.prediction_order = net.prediction_order
-        # self.attribute2idx = {key: idx for idx, key in enumerate(self.prediction_order)}
+
         self.attribute2idx_after = {'pitch': 0,
                                     'duration': 1,
                                     'velocity': 2,
@@ -113,6 +113,17 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
                                     'tempo': 6,
                                     'instrument': 7}
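+        # Editorial note (added for clarity): attribute2idx maps each feature to its
+        # column in the stored token tensor, while attribute2idx_after maps it to its
+        # slot in the sub-decoder's prediction order; the conditional-generation path
+        # below uses both dicts to permute condition_step into condition_step_rearranged.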
         self.attribute2idx = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
+        # if using the position attribute (Octuple encoding), switch both layouts accordingly
+        if 'position' in self.net.vocab.feature_list:
+            self.attribute2idx_after = {'pitch': 0,
+                                        'position': 1,
+                                        'bar': 2,
+                                        'velocity': 3,
+                                        'duration': 4,
+                                        'program': 5,
+                                        'tempo': 6,
+                                        'timesig': 7}
+            self.attribute2idx = {'pitch':0, 'position':1, 'bar':2, 'velocity':3, 'duration':4, 'program':5, 'tempo':6, 'timesig':7}

     def forward(self, input_seq:torch.Tensor, target:torch.Tensor,context=None):
         return self.net(input_seq, target, context=context)
@@ -161,10 +172,12 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
             conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
             # measure_bool = (condition[:,1] == 1) # measure tokens
             conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
-        elif self.net.vocab.encoding_scheme == 'nb':
+        elif self.net.vocab.encoding_scheme in ('nb', 'oct'):
             measure_bool = (condition[:,0] == 2) | (condition[:,0] >= 5) # Empty measure or where new measure starts
-            conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
-
+            try:
+                conditional_input_len = torch.where(measure_bool)[0][num_target_measures].item()
+            except IndexError:  # fewer measures than requested: condition on the whole sequence
+                conditional_input_len = condition.shape[0]
         if conditional_input_len == 0:
             conditional_input_len = 50
@@ -262,7 +275,7 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
             # print(self.attribute2idx)
             for attr, idx in self.attribute2idx.items():
                 if attr not in attr_list:
-                    condition_filtered[:, :, idx] = 126336
+                    condition_filtered[:, 1:, idx] = 126336
             # rearrange condition_filtered to match prediction order

         cache = LayerIntermediates()
@@ -286,8 +299,9 @@ class AmadeusModelAutoregressiveWrapper(nn.Module):
                 for attr, idx in self.attribute2idx.items():
                     new_idx = self.attribute2idx_after[attr]
                     condition_step_rearranged[:, :, new_idx] = condition_step[:, :, idx]
-                # print("condition_step shape:", condition_step.shape)
+                # print("condition_step:", condition_step)
                 _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context, condition_step=condition_step_rearranged)
+                # print("sampled_token:", sampled_token)
             else:
                 _, sampled_token, cache, hidden_vec = self._run_one_step(input_tensor, cache=cache, sampling_method=sampling_method, threshold=threshold, temperature=temperature,bos_hidden_vec=bos_hidden_vec, context=context)
             time_end = time.time()
diff --git a/Amadeus/sub_decoder_zoo.py b/Amadeus/sub_decoder_zoo.py
index 9179211..db7ba31 100644
--- a/Amadeus/sub_decoder_zoo.py
+++ b/Amadeus/sub_decoder_zoo.py
@@ -713,7 +713,7 @@ class DiffusionDecoder(SubDecoderClass):
                  dropout:float,
                  sub_decoder_enricher_use:bool,
                  MASK_IDX:int = 126336,
-                 denoising_steps:int = 8,
+                 denoising_steps:int = 6,
                  eps:float = 1e-3,
                  method:str = 'low-confidence', # or random or auto-regressive
                  ):
@@ -1029,7 +1029,338 @@ class DiffusionDecoder(SubDecoderClass):
                 # print("tokens", sampled_token_dict)
                 # set prob of the changed tokens to -inf
                 stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
-                print("stacked_logits_probs", stacked_logits_probs.clone())
+                # print("stacked_logits_probs", stacked_logits_probs.clone())
+
+                if self.method == 'low-confidence':
+                    _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
+                elif self.method == 'random':
+                    indices = torch.randint(0, stacked_logits_probs.shape[-1], (num_transfer_tokens[:, step],)).to(logit.device)
+                elif self.method == 'auto-regressive':
+                    indices = torch.tensor([[step]], device=logit.device)
+                # update the masked history
+                for i in range(b*t):
+                    for j in range(l):
+                        if j in indices[i]:
+                            # print(f"Step {step}: Updating token for feature {self.prediction_order[j]} at position {(i,j)}")
+                            masked_history[i][j] = False
+                            stored_logits_dict[self.prediction_order[j]] = logits_dict[self.prediction_order[j]].clone()
+                            stored_token_embeddings[i][j][:] = stacked_token_embeddings[i][j][:]
+                expand_masked_history = masked_history.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x num_sub_tokens x d_model
+                memory_tensor = torch.where(expand_masked_history, all_noise_tensor, stored_token_embeddings)
+                # stop early if all tokens are unmasked
+                if not expand_masked_history.any():
+                    # print("All tokens have been unmasked. Ending denoising process.")
+                    break
+            # print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
+            # get the final sampled tokens by looking up the unmasked embeddings
+            sampled_token_dict = {}
+            for idx, feature in enumerate(self.prediction_order):
+                sampled_token = self.emb_layer.get_token_by_emb(feature, memory_tensor[:, idx, :])
+                sampled_token_dict[feature] = sampled_token
+            # print("Final sampled tokens:")
+            # print(sampled_token_dict)
+            # print(condition_step)
+            return stored_logits_dict, sampled_token_dict
+
+        # ---- Training ---- #
+        _, masked_indices, p_mask = self._forward_process(target, mask_idx=self.MASK_idx) # (B*T) x (num_sub_tokens) x d_model
+        memory_tensor = self._prepare_embedding(memory_list, target) # (B*T) x (num_sub_tokens) x d_model
+        # apply layer norm
+
+        extend_masked_indices = masked_indices.unsqueeze(-1).expand(-1, -1, memory_tensor.shape[-1]) # (B*T) x (num_sub_tokens) x d_model
+        if worst_case: # mask everything, turning training into the fully parallel (worst) case
+            extend_masked_indices = torch.ones_like(extend_masked_indices).to(self.device)
+        memory_tensor = torch.where(extend_masked_indices, self.diffusion_mask_emb, memory_tensor)
+        memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
+        # all is embedding
+        # memory_tensor = self.layer_norm(memory_tensor)
+        # apply feature enricher to memory
+        if self.sub_decoder_enricher_use:
+            input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
+            input_dict = self.feature_enricher_layers(input_dict)
+            memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
+        # sub-decoder cross attention
+        input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
+        input_dict = self.sub_decoder_layers(input_dict)
+        attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
+        # get prob
+        for idx, feature in enumerate(self.prediction_order):
+            feature_pos = self.feature_order_in_output[feature]
+            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
+            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
+            logits_dict[feature] = logit
+        return logits_dict, (masked_indices, p_mask)
+
+class DiffusionDecoder(SubDecoderClass):
+    def __init__(
+        self,
+        prediction_order:list,
+        vocab:LangTokenVocab,
+        sub_decoder_depth:int,
+        dim:int,
+        heads:int,
+        dropout:float,
+        sub_decoder_enricher_use:bool,
+        MASK_IDX:int = 126336,
+        denoising_steps:int = 6,
+        eps:float = 1e-3,
+        method:str = 'low-confidence', # or 'random' or 'auto-regressive'
+    ):
+        '''
+        The power of cross-attention and the UniAudio-style self-attention lies in using the output of the main decoder (the hidden vec) directly in the sub-decoder.
+        As the output of the main decoder is a representation of the whole sequence,
+        it carries rich enough information to decode the sub-tokens even in a parallel manner,
+        so both architectures that consume the main decoder's output directly outperform the original self-attention sub-decoder.
+        '''
+        super().__init__(prediction_order, vocab, sub_decoder_depth, dim, heads, dropout, sub_decoder_enricher_use)
+        self.sub_decoder_enricher_use = sub_decoder_enricher_use
+        self.feature_order_in_output = {key: (idx-len(prediction_order)) for idx, key in enumerate(prediction_order)}
+
+        self.pos_enc = nn.Embedding(len(self.prediction_order), dim)
+        nn.init.zeros_(self.pos_enc.weight)
+
+        self.sub_decoder_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
+        self.diffusion_mask_emb = nn.Parameter(torch.empty(dim), requires_grad=True) # embedding of the mask token; its index (126336) is not in the vocab
+        nn.init.normal_(self.diffusion_mask_emb, mean=0.0, std=0.02)
+        self.MASK_idx = MASK_IDX
+        self.denoising_steps = denoising_steps
+        self.eps = eps
+        self.method = method
+
+        self.input_norm = nn.LayerNorm(dim)
+
+        self.feature_boost_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
+
+        if sub_decoder_enricher_use:
+            self.enricher_BOS_emb = nn.Parameter(torch.zeros(dim), requires_grad=True)
+        causal_mask = generate_SA_mask(len(prediction_order))
+        causal_ca_mask = generate_none_causality_mask(len(prediction_order), len(prediction_order)).to(self.device)
+        self.register_buffer('causal_mask', causal_mask)
+        self.register_buffer('causal_ca_mask', causal_ca_mask)
+
+        # get depth of the sub-decoder
+        if sub_decoder_depth > 1:
+            self.sub_decoder_layers = nn.Sequential(*[TransformerLayer(dim=dim, num_heads=heads, dropout=dropout) for _ in range(sub_decoder_depth)])
+        else:
+            self.sub_decoder_layers = nn.Sequential(TransformerLayer(dim=dim, num_heads=heads, dropout=dropout))
+        if sub_decoder_enricher_use:
+            self.feature_enricher_layers = nn.Sequential(FeatureEnricher(dim=dim, num_heads=heads, dropout=dropout))
+
+    # simplified version of the forward (noising) process of the diffusion model
+    def _forward_process(self, input_ids, eps=1e-3, mask_idx=None):
+        reshaped_input_ids = torch.reshape(input_ids, (-1, input_ids.shape[-1])) # B*T x num_sub_tokens
+        b, l = reshaped_input_ids.shape
+        t = torch.rand(b, device=input_ids.device)
+        p_mask = (1 - eps) * t + eps
+        p_mask = p_mask[:, None].repeat(1, l)
+
+        masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
+        # 126336 is used for the [MASK] token; note that it is not in the vocab
+        if mask_idx is not None:
+            noisy_batch = torch.where(masked_indices, mask_idx, reshaped_input_ids)
+        else:
+            noisy_batch = torch.where(masked_indices, 126336, reshaped_input_ids) # fall back to the default [MASK] id 126336
+        return noisy_batch, masked_indices, p_mask
+
+    def _apply_window_on_hidden_vec(self, hidden_vec):
+        BOS_emb = self.enricher_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
+        # through our experiments, we found that the size of the window doesn't affect the performance of the model much
+        window_size = 1
+        zero_vec = torch.zeros((hidden_vec.shape[0], window_size-1, hidden_vec.shape[2])).to(self.device) # B x (window_size-1) x d_model
+        cat_hidden_vec = torch.cat([zero_vec, hidden_vec], dim=1) # B x (window_size-1+T) x d_model
+        new_hidden_vec = cat_hidden_vec.unfold(1, window_size, 1).transpose(2, 3) # B x T x window_size x d_model
+        new_hidden_vec = new_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], window_size, -1)) # (B*T) x window_size x d_model
+        new_hidden_vec = torch.cat([BOS_emb, new_hidden_vec], dim=1) # (B*T) x (window_size+1) x d_model
+        return new_hidden_vec
+
+    def _apply_pos_enc(self, tgt):
+        pos = torch.arange(tgt.shape[1]).to(tgt.device) # num_sub_tokens
+        pos = pos.unsqueeze(0).repeat(tgt.shape[0], 1) # (B*T) x num_sub_tokens
+        tgt_pos = tgt + self.pos_enc(pos.long()) # (B*T) x num_sub_tokens x d_model
+        return tgt_pos
+
+    def _prepare_token_embedding_for_teacher_forcing(self, memory_list, target):
+        for _, feature in enumerate(self.prediction_order[:-1]):
+            feature_idx = self.vocab.feature_list.index(feature)
+            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
+            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
+            memory_list.append(feature_emb_reshape)
+        memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens-1) x d_model
+        return memory_tensor
+
+    # return a tensor filled with the mask-token embedding
+    def _get_noisy_tensor(self, target_shape):
+        new_target = torch.zeros(target_shape).to(self.device)
+        # fill every element of the tensor with the embedding of the mask token
+        new_target[:, :, :] = self.diffusion_mask_emb
+        return new_target
+
+    # prepare the embedding of the target
+    def _prepare_embedding(self, memory_list, target):
+        for _, feature in enumerate(self.prediction_order):
+            feature_idx = self.vocab.feature_list.index(feature)
+            feature_emb = self.emb_layer.get_emb_by_key(feature, target[..., feature_idx]) # B x T x emb_size
+            feature_emb_reshape = feature_emb.reshape((feature_emb.shape[0]*feature_emb.shape[1], 1, -1)) # (B*T) x 1 x emb_size
+            memory_list.append(feature_emb_reshape)
+        memory_tensor = torch.cat(memory_list, dim=1) # (B*T) x (BOS + num_sub_tokens) x d_model
+        return memory_tensor
+
+    def _prepare_memory_list(self, hidden_vec, target=None, add_BOS=True):
+        memory_list = [] # used for key and value in cross attention
+        BOS_emb = self.sub_decoder_BOS_emb.reshape(1,1,-1).repeat(hidden_vec.shape[0]*hidden_vec.shape[1], 1, 1) # (B*T) x 1 x d_model
+        if add_BOS:
+            if target is not None: # training
+                memory_list.append(BOS_emb)
+            else: # inference
+                memory_list.append(BOS_emb[-1:, :, :])
+        return memory_list
+
+    def _get_num_transfer_tokens(self, mask_index, steps):
+        '''
+        In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
+        Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
+        the expected number of tokens transitioned at each step should be consistent.
+
+        This function precomputes the number of tokens that need to be transitioned at each step.
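+
+        Illustrative example (added for clarity; the numbers are hypothetical):
+        with mask_num = 8 masked sub-tokens and steps = 6, base = 8 // 6 = 1 and
+        remainder = 8 % 6 = 2, so num_transfer_tokens = [2, 2, 1, 1, 1, 1]:
+        two tokens are unmasked in each of the first two steps, one per remaining step.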
+        '''
+        mask_num = mask_index.sum(dim=1,keepdim=True)
+        base = mask_num // steps
+        remainder = mask_num % steps
+
+        num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+        for i in range(mask_num.size(0)):
+            num_transfer_tokens[i, :remainder[i]] += 1
+
+        return num_transfer_tokens
+
+    def sample_from_logits(self, attn_output, hidden_vec, sampling_method=None, threshold=None, temperature=None, force_decode=False,step=None):
+        sampled_token_dict = {}
+        logits_dict = {}
+        candidate_token_embeddings = {}
+        candidate_token_probs = {}
+        b,t,d = hidden_vec.shape # B x T x d_model
+        # print("*"*8)
+        logits_list = []
+        for idx, feature in enumerate(self.prediction_order):
+            feature_pos = self.feature_order_in_output[feature]
+            logit = self.hidden2logit[f"layer_{feature}"](attn_output[:, feature_pos, :])
+            logit = logit.reshape((hidden_vec.shape[0], hidden_vec.shape[1], -1)) # B x T x vocab_size
+            logits_list.append(logit)
+        for idx, feature in enumerate(self.prediction_order):
+            logit = logits_list[idx] # B x T x vocab_size
+            sampled_token, prob = sample_with_prob(logit, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
+            if step==0 and force_decode:
+                if feature == 'velocity':
+                    sampled_token = torch.tensor([2]).to(logit.device)
+                    prob = torch.tensor([1.0]).to(logit.device)
+                else:
+                    prob = torch.tensor([0.0]).to(logit.device)
+            # print(feature, sampled_token, prob)
+            sampled_token_dict[feature] = sampled_token
+            logits_dict[feature] = logit
+            candidate_token_probs[feature] = prob
+            feature_emb = self.emb_layer.get_emb_by_key(feature, sampled_token)
+            feature_emb_reshape = feature_emb.reshape((1, 1, -1)) # (B*T) x 1 x emb_size
+            candidate_token_embeddings[feature] = feature_emb_reshape
+        stacked_logits_probs = torch.stack(list(candidate_token_probs.values()), dim=0).reshape((b*t, -1)) # (B*T) x num_sub_tokens
+        stacked_token_embeddings = torch.stack(list(candidate_token_embeddings.values()), dim=0).reshape((b*t, -1, d)) # (B*T) x num_sub_tokens x d_model
+        # print("sampled_token_dict", sampled_token_dict)
+        return sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings
+
+    def forward(self, input_dict, sampling_method=None, threshold=None, temperature=None, Force_decode=False, worst_case=False, condition_step=None):
+        logits_dict = {}
+        hidden_vec = input_dict['hidden_vec'] # B x T x d_model
+        target = input_dict['target'] # B x T x d_model
+        bos_hidden_vec = input_dict['bos_token_hidden'] # B x 1 x d_model, used for the first token in the sub-decoder
+
+        # apply window on hidden_vec for the enricher
+        if self.sub_decoder_enricher_use:
+            window_applied_hidden_vec = self._apply_window_on_hidden_vec(hidden_vec) # (B*T) x window_size x d_model
+        hidden_vec_reshape = hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1)) # (B*T) x 1 x d_model
+        input_seq = hidden_vec_reshape.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
+        input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
+
+        if bos_hidden_vec is None: # start of generation
+            if target is None:
+                bos_hidden_vec = input_seq_pos
+            else:
+                bos_hidden_vec = hidden_vec[:, 0, :].unsqueeze(1).repeat(1, hidden_vec.shape[1], 1) # B x T x d_model
+                bos_hidden_vec = bos_hidden_vec.reshape((hidden_vec.shape[0]*hidden_vec.shape[1], 1, -1))
+                bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1)
+        else:
+            bos_hidden_vec = bos_hidden_vec.repeat(1, len(self.prediction_order), 1) # (B*T) x num_sub_tokens x d_model
+
+        # input_seq_pos = input_seq
+        input_dict = {'input_seq': input_seq_pos, 'memory': bos_hidden_vec, 'memory_mask': self.causal_ca_mask}
+        boosted_input_dict = self.feature_boost_layers(input_dict) # (B*T) x num_sub_tokens x d_model
+        input_seq_pos = boosted_input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
+        # input_seq_pos = self.input_norm(input_seq_pos) # (B*T) x num_sub_tokens x d_model
+        # input_seq_pos = self._apply_pos_enc(input_seq) # (B*T) x num_sub_tokens x d_model
+        # prepare memory
+        memory_list = self._prepare_memory_list(hidden_vec=hidden_vec, target=target, add_BOS=False)
+        # ---- Generate (Inference) ---- #
+        if target is None:
+            sampled_token_dict = {}
+            b,t,d = hidden_vec.shape # B x T x d_model
+            l = len(self.prediction_order) # num_sub_tokens
+            memory_tensor = self._get_noisy_tensor(target_shape=(b*t, l, d))
+            all_noise_tensor = memory_tensor.clone() # (B*T) x num_sub_tokens x d_model
+
+            # tracks the masked positions; True means the token is still masked
+            masked_history = torch.ones((b*t, l), device=hidden_vec.device, dtype=torch.int64).bool()
+            # add attribute control here
+            stored_logits_dict = {}
+            stored_token_embeddings = torch.zeros((b*t, l, d), device=hidden_vec.device)
+            if condition_step is not None:
+                # print("shape of condition_step", condition_step.shape)
+                condition_step = condition_step.reshape((b*t, l))
+                for i in range(b*t):
+                    for j in range(l):
+                        token = condition_step[i][j]
+                        if condition_step[i][j] != self.MASK_idx:
+                            # print(f"Conditioning on token {token} for feature {self.prediction_order[j]} at position {(i,j)}")
+                            masked_history[i][j] = False
+                            memory_tensor[i][j][:] = self.emb_layer.get_emb_by_key(self.prediction_order[j], condition_step[i][j])
+                            stored_token_embeddings[i][j][:] = memory_tensor[i][j][:]
+                            # print(f"Embedded token for feature {self.prediction_order[j]} at position {(i,j)}")
+
+            num_transfer_tokens = self._get_num_transfer_tokens(masked_history, self.denoising_steps)
+            # denoising loop
+            # with torch.profiler.profile(
+            #     activities=[
+            #         torch.profiler.ProfilerActivity.CPU,
+            #         torch.profiler.ProfilerActivity.CUDA],
+            #     record_shapes=True,
+            #     profile_memory=True,
+            #     with_stack=True
+            # ) as prof:
+            for step in range(self.denoising_steps):
+                memory_tensor = self._apply_pos_enc(memory_tensor) # (B*T) x num_sub_tokens x d_model
+                if self.sub_decoder_enricher_use:
+                    input_dict = {'input_seq': memory_tensor, 'memory': window_applied_hidden_vec}
+                    input_dict = self.feature_enricher_layers(input_dict)
+                    memory_tensor = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
+                # input_dict = {'input_seq': input_seq_pos, 'memory': memory_tensor, 'memory_mask': self.causal_ca_mask}
+                input_dict = {'input_seq': memory_tensor, 'memory': input_seq_pos, 'memory_mask': self.causal_ca_mask}
+                input_dict = self.sub_decoder_layers(input_dict)
+                attn_output = input_dict['input_seq'] # (B*T) x num_sub_tokens x d_model
+                sampled_token_dict, logits_dict, candidate_token_probs, stacked_logits_probs, stacked_token_embeddings = self.sample_from_logits(attn_output, hidden_vec, sampling_method=sampling_method, threshold=threshold, temperature=temperature,
+                                                                                                                                                force_decode=Force_decode,
+                                                                                                                                                step=step)
+
+                # print("step", step)
+                # print("tokens", sampled_token_dict)
+                # set prob of the already-decoded tokens to -inf
+                stacked_logits_probs = torch.where(masked_history, stacked_logits_probs, -torch.inf)
+                # print("stacked_logits_probs", stacked_logits_probs.clone())
                 if self.method == 'low-confidence':
                     _, indices = torch.topk(stacked_logits_probs, k=int(num_transfer_tokens[:,step]), dim=-1)
diff --git a/Amadeus/symbolic_encoding/augmentor.py b/Amadeus/symbolic_encoding/augmentor.py
index 40c1839..5ad5407 100644
--- a/Amadeus/symbolic_encoding/augmentor.py
+++ b/Amadeus/symbolic_encoding/augmentor.py
@@ -22,6 +22,8 @@ class Augmentor:
     self.chord_idx = self.feature_list.index('chord')

   def _get_shift(self, segment):
+    if self.encoding_scheme == 'oct':
+      return 0 # no pitch transposition for oct: index 0 of the pitch vocab is the ignore token
     if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
       pitch_mask = segment != 0
diff --git a/Amadeus/symbolic_encoding/compile_utils.py b/Amadeus/symbolic_encoding/compile_utils.py
index 40eae76..e856e82 100644
--- a/Amadeus/symbolic_encoding/compile_utils.py
+++ b/Amadeus/symbolic_encoding/compile_utils.py
@@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
     for i in range(len(self.data_list)):
       tune_in_idx, tune_name = self.data_list[i]
       tune_in_idx = torch.LongTensor(tune_in_idx)
-      if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
+      if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
         eos_token = torch.LongTensor(self.eos_token)
       else:
         eos_token = torch.LongTensor(self.eos_token)
@@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
     for i in range(len(self.data_list)):
       tune_in_idx, tune_name = self.data_list[i]
       tune_in_idx = torch.LongTensor(tune_in_idx)
-      if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
+      if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
         eos_token = torch.LongTensor(self.eos_token)
       else:
         eos_token = torch.LongTensor(self.eos_token)
diff --git a/Amadeus/symbolic_encoding/data_utils.py b/Amadeus/symbolic_encoding/data_utils.py
index 26231f2..af0ebca 100644
--- a/Amadeus/symbolic_encoding/data_utils.py
+++ b/Amadeus/symbolic_encoding/data_utils.py
@@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
         print(f"Error encoding caption for tune {tune_name}: {e}")
         encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
       return segment, tensor_mask, tune_name, encoded_caption
-    if self.data_type == 'train':
-      augmented_segment = self.augmentor(segment)
-      return augmented_segment, tensor_mask, tune_name, encoded_caption
-    else:
-      return segment, tensor_mask, tune_name, encoded_caption

   def get_segments_with_tune_idx(self, tune_name, seg_order):
     '''
@@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
     self.data_type = data_type
     self.augmentor = augmentor
     self.eos_token = vocab.eos_token
+    self.vocab = vocab
     self.compile_function = VanillaTransformer_compiler(
       data_list=self.data_list,
       augmentor=self.augmentor,
@@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
       encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
     except Exception as e:
       encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
-    if self.data_type == 'train':
+    if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
       segment = self.augmentor(segment)
     # use input_ids in place of tune_name
     tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
diff --git a/Amadeus/symbolic_encoding/decoding_utils.py b/Amadeus/symbolic_encoding/decoding_utils.py
index 99312a3..82efc32 100644
--- a/Amadeus/symbolic_encoding/decoding_utils.py
+++ b/Amadeus/symbolic_encoding/decoding_utils.py
@@ -1,13 +1,17 @@
+import re
 import os, sys
 from pathlib import Path

 import matplotlib.pyplot as plt
 from collections import defaultdict
+import torch
 from music21 import converter
 import muspy
 import miditoolkit
 from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
+from symusic import Score
+from miditok import Octuple, TokenizerConfig

 from .midi2audio import FluidSynth
 from data_representation.constants import PROGRAM_INSTRUMENT_MAP
@@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
       music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
     midi_obj.dump(output_path)
     # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
-    save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
+    # save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
     return midi_obj
+
+class MidiDecoder4Octuple(MidiDecoder4REMI):
+    def __init__(self, vocab, in_beat_resolution, dataset_name):
+        super().__init__(vocab, in_beat_resolution, dataset_name)
+
+    def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
+        """
+        Input:
+            t: torch.Tensor of shape (1, N, M)
+        Purpose:
+            drop every row (sub-token tuple) that contains any of the special ids 0, 1, 2, 3
+        Returns:
+            torch.Tensor with the batch dimension kept, shape (1, N_filtered, M)
+        """
+        if t.dim() != 3:
+            raise ValueError("input tensor must be 3-dimensional (batch, seq_len, feature)")
+
+        # build a mask; True means the row contains none of 0, 1, 2, 3
+        exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
+
+        # check each row for these values
+        mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
+
+        # filter the rows and keep the batch dimension
+        filtered_tensor = t[0][mask].unsqueeze(0)
+
+        return filtered_tensor
+
+    def __call__(self, generated_output, output_path=None):
+        config = TokenizerConfig(
+            use_time_signatures=True,
+            use_tempos=True,
+            use_velocities=True,
+            use_programs=True,
+            remove_duplicated_notes=True,
+            delete_equal_successive_tempo_changes=True,
+        )
+        config.additional_params["max_bar_embedding"] = 512
+        tokenizer = Octuple(config)
+
+        output_path = Path(output_path)
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
+        generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
+        print(output_path)
+        try:
+            tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
+            tok_seq.dump_midi(output_path)
+        except Exception as e:
+            print(generated_output)
+            print(f" × Error while writing MIDI file: {output_path} -> {e}")
+            tok_seq = None
+
+        return tok_seq
diff --git a/Amadeus/symbolic_yamls/config-accelerate.yaml b/Amadeus/symbolic_yamls/config-accelerate.yaml
index d472105..7fd217f 100644
--- a/Amadeus/symbolic_yamls/config-accelerate.yaml
+++ b/Amadeus/symbolic_yamls/config-accelerate.yaml
@@ -1,8 +1,8 @@
 defaults:
   # - nn_params: nb8_embSum_NMT
   # - nn_params: remi8
-  # - nn_params: nb8_embSum_diff_t2m_150M_finetunning
-  - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
+  - nn_params: oct8_embSum_diff_t2m_150M_pretrainingv2
+  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
   # - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
   # - nn_params: nb8_embSum_subPararell
   # - nn_params: nb8_embSum_diff_t2m_150M
@@ -15,7 +15,7 @@ defaults:
   # - nn_params: remi8_main12_head_16_dim512
   # - nn_params: nb5_embSum_diff_main12head16dim768_sub3
-dataset: msmidi # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
+dataset: Melody # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
 captions_path: dataset/midicaps/train_set.json
 # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -31,7 +31,7 @@ tau: 0.5

 train_params:
   device: cuda
-  batch_size: 5
+  batch_size: 10
   grad_clip: 1.0
   num_iter: 300000 # total number of iterations
   num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
@@ -44,7 +44,7 @@ train_params:
   focal_gamma: 0
   # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
   scheduler: cosinelr
-  initial_lr: 0.0003
+  initial_lr: 0.00001
   decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
   num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
   warmup_steps: 2000 # number of warmup steps
diff --git a/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml
new file mode 100644
index 0000000..0e16086
--- /dev/null
+++ b/Amadeus/symbolic_yamls/nn_params/oct8_embSum_diff_t2m_150M_pretrainingv2.yaml
@@ -0,0 +1,19 @@
+encoding_scheme: oct
+num_features: 8
+vocab_name: MusicTokenVocabOct
+model_name: AmadeusModel
+input_embedder_name: SummationEmbedder
+main_decoder_name: XtransformerNewPretrainingDecoder
+sub_decoder_name: DiffusionDecoder
+model_dropout: 0.2
+input_embedder:
+  num_layer: 1
+  num_head: 8
+main_decoder:
+  dim_model: 768
+  num_layer: 16
+  num_head: 12
+sub_decoder:
+  decout_window_size: 1 # 1 means no previous decoding output added
+  num_layer: 1
+  feature_enricher_use: False
\ No newline at end of file
diff --git a/Amadeus/train_utils.py b/Amadeus/train_utils.py
index a98ce59..57fe28b 100644
--- a/Amadeus/train_utils.py
+++ b/Amadeus/train_utils.py
@@ -29,8 +29,12 @@ def adjust_prediction_order(encoding_scheme, num_features, target_feature, nn_pa
     7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
     8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]
   }
-
-  if encoding_scheme == 'remi':
+  oct_prediction_order = {
+    7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
+    8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
+  if encoding_scheme == 'oct':
+    prediction_order = oct_prediction_order[num_features]
+  elif encoding_scheme == 'remi':
     prediction_order = feature_prediction_order_dict[num_features]
   elif encoding_scheme == 'cp':
     if nn_params.get("partial_sequential_prediction", False):
@@ -239,11 +243,11 @@ class DiffusionLoss4CompoundToken():
         training_loss = self.get_nll_loss(logits_dict[key], shifted_tgt[..., idx], mask, mask_indices[..., idx], p_mask[..., idx])
         train_loss_list.append(training_loss)
         if valid:
-          if key == 'type':
+          if key == 'type' or key == 'timesig':
             log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=None, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
-          elif key == 'beat':
+          elif key == 'beat' or key == 'position' or key == 'bar':
             log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
-          elif key == 'chord' or key == 'tempo' or key == 'instrument':
+          elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
             log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=9999, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
           else:
             log_normal_loss = self.get_nll_loss_for_logging(logits_dict[key], shifted_tgt[..., idx], mask, ignore_token=0, conti_token=None, mask_indices=mask_indices[..., idx], p_mask=p_mask[..., idx])
diff --git a/Amadeus/trainer_accelerate.py b/Amadeus/trainer_accelerate.py
index 2e4224a..94054cf 100644
--- a/Amadeus/trainer_accelerate.py
+++ b/Amadeus/trainer_accelerate.py
@@ -74,7 +74,7 @@ class LanguageModelTrainer:
                  sampling_threshold: float, # Threshold for sampling decisions
                  sampling_temperature: float, # Temperature for controlling sampling randomness
                  config, # Configuration parameters (contains general, training, and inference settings)
-                 model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
+                 model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
                  # model_checkpoint: Union[str, None] = None, # Path to a pre-trained model checkpoint (optional)
                  ):
     # Save model, optimizer, and other configurations
@@ -104,7 +104,6 @@ class LanguageModelTrainer:
       checkpoint = torch.load(model_checkpoint, map_location='cpu')
       # print state dict keys
       print("Loading model checkpoint from", model_checkpoint)
-      print("Checkpoint keys:", checkpoint['model'].keys())
       if isinstance(self.model, DDP):
         self.model.module.load_state_dict(checkpoint['model'], strict=False)
       else:
@@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
     correct_guess_by_feature = defaultdict(int)
     num_tokens_by_feature = defaultdict(int)
     for idx, key in enumerate(self.vocab.feature_list):
-      if key == 'type':
+      if key == 'type' or key == 'timesig':
         num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
-      elif key == 'chord' or key == 'tempo' or key == 'instrument':
+      elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
         num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
       elif key == 'beat':
         # NB's beat vocab has Ignore and CONTI token
diff --git a/data_representation/octuple2tuneinidx.py b/data_representation/octuple2tuneinidx.py
new file mode 100644
index 0000000..3b7f282
--- /dev/null
+++ b/data_representation/octuple2tuneinidx.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+MIDI preprocessing script (parallel version).
+Features:
+1. Uses miditok's Octuple tokenizer.
+2. Keeps only MIDI files whose duration is between 8 and 2000 seconds.
+3. Defaults to 120 BPM when the tempo is missing and to 4/4 when the time signature is missing.
+4. Saves vocab.json.
+5. Walks the directory and tokenizes every MIDI file in parallel; each file is saved separately as {filename}.npz.
+"""
+
+import os
+import glob
+import struct
+import numpy as np
+from multiprocessing import RLock
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+from tqdm import tqdm
+from symusic import Score
+from miditok import Octuple, TokenizerConfig
+
+
+lock = RLock()
+
+def convert_event_dicts(dict_list):
+    """
+    Convert the list of event-vocab dicts into a structured mapping, preserving order.
+    Input: list[dict]
+        Each dict corresponds to one category, arranged in this fixed order:
+        0: Pitch/PitchDrum
+        1: Position
+        2: Bar
+        (+ Optional) Velocity
+        (+ Optional) Duration
+        (+ Optional) Program
+        (+ Optional) Tempo
+        (+ Optional) TimeSignature
+    Example output:
+    {
+        "pitch": {"0": 0, "1": "Pitch_60", ...},
+        "position": {...},
+        ...
+    }
+    """
+    keys_order = [
+        "pitch", "position", "bar",
+        "velocity", "duration", "program",
+        "tempo", "timesig"
+    ]
+
+    result = {}
+    for i, d in enumerate(dict_list):
+        if i >= len(keys_order):
+            break # ignore anything beyond the defined categories
+        category = keys_order[i]
+        result[category] = {str(v): k for k, v in d.items()}
+
+    return result
+
+def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
+    try:
+        score = Score(midi_path, ttype="tick")
+
+        if not (8 <= (duration := score.to('second').end()) <= 2000):
+            with lock:
+                print(f" × Duration out of range: {midi_path} -> {duration}s")
+            return
+
+        # tokenize
+        tok_seq = tokenizer(score)
+        token_ids = tok_seq.ids
+        # add sos token at the beginning
+        vocab = tokenizer.vocab
+        sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
+        token_ids.insert(0, sos_token)
+        token_ids.sort(key=lambda x: (x[2], x[1])) # sort by (bar, position): position is column 1, bar is column 2
+
+        # save an individual npz file
+        filename = os.path.splitext(os.path.basename(midi_path))[0]
+        save_path = os.path.join(output_dir, f"{filename}.npz")
+        np.savez_compressed(save_path, np.array(token_ids))
+        return save_path # report success so the caller can count the generated files
+    except Exception as e:
+        with lock:
+            print(f" × Error while processing file: {midi_path} -> {e}")
+
+
+def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
+    # === 1. Initialize the tokenizer and save the vocab ===
+    print("Initializing Octuple tokenizer...")
+    config = TokenizerConfig(
+        use_time_signatures=True,
+        use_tempos=True,
+        use_velocities=True,
+        use_programs=True,
+        remove_duplicated_notes=True,
+        delete_equal_successive_tempo_changes=True,
+    )
+    config.additional_params["max_bar_embedding"] = 512
+    tokenizer = Octuple(config)
+    vocab = tokenizer.vocab
+    vocab_structured = convert_event_dicts(vocab)
+    with open("vocab/oct_vocab.json", "w", encoding="utf-8") as f:
+        import json
+        json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
+    # === 2. Create the output directory ===
+    os.makedirs(output_dir, exist_ok=True)
+
+    # === 3. Collect MIDI files ===
+    midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
+                 glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
+    midi_paths = list(midi_paths)
+    print(f"Found {len(midi_paths)} MIDI files; processing with {num_threads} workers.\n")
+
+    # === 4. Process in parallel ===
+    results = []
+    with ProcessPoolExecutor(max_workers=num_threads) as executor:
+        futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
+
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            res = future.result()
+            if res:
+                results.append(res)
+
+    # === 5. Summarize results ===
+    print(f"\nDone: generated {len(results)} .npz files, saved under {output_dir}/.")
+
+
+if __name__ == "__main__":
+    midi_directory = "dataset/Melody" # change this to your MIDI directory
+    dataset_name = midi_directory.split("/")[-1]
+    tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
+    output_dir = tuneidx_prefix
+    preprocess_midi_directory(midi_directory, output_dir)
\ No newline at end of file
diff --git a/data_representation/test.py b/data_representation/test.py
new file mode 100644
index 0000000..08dafa7
--- /dev/null
+++ b/data_representation/test.py
@@ -0,0 +1,14 @@
+import numpy as np
+
+# load the npz file
+data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
+
+# inspect the stored keys
+print(data.files)
+# output: ['arr_0'] (np.savez_compressed was called with a positional array)
+
+# access the data
+sequence = data["arr_0"]
+
+print("token sequence length:", len(sequence))
+print("first 20 tokens:", sequence[:20])
\ No newline at end of file
diff --git a/data_representation/vocab_utils.py b/data_representation/vocab_utils.py
index 999385b..70638e9 100644
--- a/data_representation/vocab_utils.py
+++ b/data_representation/vocab_utils.py
@@ -1,5 +1,6 @@
 import pickle
 from pathlib import Path
+from re import L
 from typing import Union
 from multiprocessing import Pool, cpu_count
 from collections import defaultdict
@@ -58,8 +59,8 @@ class LangTokenVocab:
     if in_vocab_file_path is not None:
       with open(in_vocab_file_path, 'r') as f:
         idx2event_temp = json.load(f)
-      if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
-        for key in idx2event_temp.keys():
+      if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
+        for key in idx2event_temp.keys():
          idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
       elif self.encoding_scheme == 'remi':
         idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
@@ -71,13 +72,18 @@ class LangTokenVocab:

   # Extracts features depending on the number of features chosen (4, 5, 7, 8).
   def _get_features(self):
-    feature_args = {
-      4: ["type", "beat", "pitch", "duration"],
-      5: ["type", "beat", "instrument", "pitch", "duration"],
-      7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
-      8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
-    self.feature_list = feature_args[self.num_features]
-
+    if self.encoding_scheme != 'oct':
+      feature_args = {
+        4: ["type", "beat", "pitch", "duration"],
+        5: ["type", "beat", "instrument", "pitch", "duration"],
+        7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
+        8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
+      self.feature_list = feature_args[self.num_features]
+    else:
+      feature_args = {
+        7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
+        8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
+      self.feature_list = feature_args[self.num_features]

   # Saves the current vocabulary to a specified JSON path.
  def save_vocab(self, json_path):
    with open(json_path, 'w') as f:
@@ -93,13 +99,17 @@ class LangTokenVocab:
       self.sos_token = [self.event2idx['SOS_None']]
       self.eos_token = [[self.event2idx['EOS_None']]]
     else:
-      self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
-      self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
-
+      if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
+        self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
+        self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
+      else: # oct
+        self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
+        self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]
+
   # Generates vocabularies by either loading from a file or creating them based on the event data.
   def _get_vocab(self, event_data, unique_vocabs=None):
     # make new vocab from given event_data
-    if event_data is not None:
+    if event_data is not None and self.encoding_scheme != 'oct':
       unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
       unique_vocabs = sorted(unique_char_list)
       unique_vocabs.remove('SOS_None')
@@ -119,6 +129,7 @@ class LangTokenVocab:
     # load premade vocab
     else:
       idx2event = unique_vocabs
+      print(idx2event)
       event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
     return idx2event, event2idx
@@ -392,4 +403,47 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
     unique_vocabs.insert(3, 'SSS')
     unique_vocabs.insert(4, 'SSN')
     unique_vocabs.insert(5, 'SNN')
-    return unique_vocabs
\ No newline at end of file
+    return unique_vocabs
+
+
+class MusicTokenVocabOct(LangTokenVocab):
+  def __init__(
+    self,
+    in_vocab_file_path:Union[Path, None],
+    event_data: list,
+    encoding_scheme: str,
+    num_features: int
+  ):
+    super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
+
+  def _get_vocab(self, event_data, unique_vocabs=None):
+    if event_data is not None:
+      # Create vocab mappings (event2idx, idx2event) from the provided event data
+      print('start to get unique vocab')
+      event2idx = {}
+      idx2event = {}
+      unique_vocabs = defaultdict(set)
+      # Use multiprocessing to extract unique vocabularies for each event
+      with Pool(16) as p:
+        results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
+      # Combine results from different processes
+      for result in results:
+        for key in self.feature_list:
+          unique_vocabs[key].update(result[key])
+      # Process each feature type
+      for key in self.feature_list:
+        unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
+        # Create event2idx and idx2event mappings for each feature
+        event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
+        idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
+      return idx2event, event2idx
+    else:
+      # If no event data, simply map the unique vocab to indices
+      event2idx = {}
+      for key in self.feature_list:
+        event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
+      return unique_vocabs, event2idx
+
+  def get_vocab_size(self):
+    # Return the size of the vocabulary for each feature
+    return {key: len(self.idx2event[key]) for key in self.feature_list}
\ No newline at end of file
diff --git a/generate-batch.py b/generate-batch.py
index 607f547..e533482 100644
--- a/generate-batch.py
+++ b/generate-batch.py
@@ -33,7 +33,9 @@ def get_argument_parser():
     parser.add_argument(
         "-attr_list",
         type=str,
-        default="beat,duration",
+        # default="beat,duration,instrument,tempo",
+        default="pitch",
+        # default='bar,position,velocity,duration,program,tempo,timesig',
         help="attribute list for attribute-controlled generation",
     )
     parser.add_argument(
@@ -69,7 +71,7 @@ def get_argument_parser():
     parser.add_argument(
         "-num_target_measure",
         type=int,
-        default=4,
+        default=128,
         help="number of target measures for conditioned generation",
     )
     parser.add_argument(
@@ -86,13 +88,13 @@ def get_argument_parser():
     parser.add_argument(
         "-num_processes",
         type=int,
-        default=1,
+        default=4,
         help="number of processes to use",
     )
     parser.add_argument(
         "-gpu_ids",
         type=str,
-        default="1,2,3,5",
+        default="0,1,2,3,5",
         help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
     )
     parser.add_argument(
@@ -203,6 +205,7 @@ def attr_conditioned_worker(process_idx, gpu_id, args):
     # end_idx = min(chunk_size, len(selected_data))
     # data_slice = selected_data[start_idx:end_idx]
     data_slice = selected_data
+    print("data_slice length:", len(data_slice))

     # Create output directory with process index
     base_path = Path('wandb') / args.wandb_exp_dir / \
diff --git a/len_tunes/IrishMan/len_oct8.png b/len_tunes/IrishMan/len_oct8.png
new file mode 100644
index 0000000..009e685
Binary files /dev/null and b/len_tunes/IrishMan/len_oct8.png differ
diff --git a/len_tunes/Melody/len_nb8.png b/len_tunes/Melody/len_nb8.png
index 01ef6ae..c9c29a1 100644
Binary files a/len_tunes/Melody/len_nb8.png and b/len_tunes/Melody/len_nb8.png differ
diff --git a/len_tunes/Melody/len_oct8.png b/len_tunes/Melody/len_oct8.png
new file mode 100644
index 0000000..f83eb5c
Binary files /dev/null and b/len_tunes/Melody/len_oct8.png differ
diff --git a/len_tunes/msmidi/len_oct8.png b/len_tunes/msmidi/len_oct8.png
new file mode 100644
index 0000000..a0a4bdb
Binary files /dev/null and b/len_tunes/msmidi/len_oct8.png differ
diff --git a/midi_sim.py b/midi_sim.py
index 44c5ba5..7ae08c2 100644
--- a/midi_sim.py
+++ b/midi_sim.py
@@ -129,6 +129,6 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",

 if __name__ == "__main__":
-    dir_a = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
+    dir_a = "wandb/run-20251027_161354-f9j1mwp2/uncond_min_p_t0.05_temp1.25_epochch8"
     dir_b = "dataset/Melody"
     batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
\ No newline at end of file
diff --git a/train_accelerate.py b/train_accelerate.py
index 7279114..d286d9f 100644
--- a/train_accelerate.py
+++ b/train_accelerate.py
@@ -8,9 +8,6 @@ import torch
 import torch.multiprocessing as mp
 from torch.distributed import init_process_group, destroy_process_group

-from accelerate import Accelerator
-from accelerate.utils import set_seed
-
 import wandb
 import hydra
 from hydra.core.hydra_config import HydraConfig
@@ -20,6 +17,8 @@ from omegaconf import DictConfig, OmegaConf
 from accelerate import Accelerator
 from accelerate.utils import set_seed

+from miditok import Octuple, TokenizerConfig
+
 from Amadeus.symbolic_encoding import data_utils, decoding_utils
 from Amadeus.symbolic_encoding.data_utils import get_emb_total_size
 from Amadeus import model_zoo, trainer_accelerate as trainer
@@ -99,7 +98,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
   out_vocab_path = Path(save_dir) /
f'vocab_{dataset_name}_{encoding_scheme}{num_features}.json' # get vocab - vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'} + vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'} selected_vocab_name = vocab_name[encoding_scheme] vocab = getattr(vocab_utils, selected_vocab_name)( @@ -159,7 +158,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La focal_gamma = config.train_params.focal_gamma if encoding_scheme == 'remi': loss_fn = NLLLoss4REMI(focal_alpha=focal_alpha, focal_gamma=focal_gamma) - elif encoding_scheme in ['cp', 'nb']: + elif encoding_scheme in ['cp', 'nb', 'oct']: if config.use_diff is False: loss_fn = NLLLoss4CompoundToken(feature_list=symbolic_dataset.vocab.feature_list, focal_alpha=focal_alpha, focal_gamma=focal_gamma) else: @@ -181,11 +180,11 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La in_beat_resolution = in_beat_resolution_dict[dataset_name] except KeyError: in_beat_resolution = 4 - midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'} + midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'} midi_decoder = getattr(decoding_utils, midi_decoder_dict[encoding_scheme])(vocab=symbolic_dataset.vocab, in_beat_resolution=in_beat_resolution, dataset_name=dataset_name) # Select trainer class based on encoding scheme - trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken'} + trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken', 'oct':'LanguageModelTrainer4CompoundToken'} trainer_option = trainer_option_dict[encoding_scheme] sampling_method = None sampling_threshold = 0.99 diff --git a/vocab.json b/vocab.json new file mode 100644 index 0000000..1a6e6ec --- /dev/null +++ b/vocab.json @@ -0,0 +1,442 @@ +{ + "config": { + "pitch_range": [ + 21, + 109 + ], + "beat_res": { + "0_4": 8, + "4_12": 4 + }, + "num_velocities": 32, + "remove_duplicated_notes": false, + "encode_ids_split": "bar", + "special_tokens": [ + "PAD_None", + "BOS_None", + "EOS_None", + "MASK_None" + ], + "use_velocities": true, + "use_note_duration_programs": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + -1 + ], + "use_chords": false, + "use_rests": false, + "use_tempos": true, + "use_time_signatures": true, + "use_sustain_pedals": false, + "use_pitch_bends": false, + "use_programs": false, + "use_pitch_intervals": false, + "use_pitchdrum_tokens": true, + "default_note_duration": 0.5, + "programs": [ + 0, + 1, + 2, + 3, + 4, 
+ 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + -1 + ], + "one_token_stream_for_programs": false, + "program_changes": false, + "beat_res_rest": { + "0_1": 8, + "1_2": 4, + "2_12": 2 + }, + "chord_maps": { + "min": [ + 0, + 3, + 7 + ], + "maj": [ + 0, + 4, + 7 + ], + "dim": [ + 0, + 3, + 6 + ], + "aug": [ + 0, + 4, + 8 + ], + "sus2": [ + 0, + 2, + 7 + ], + "sus4": [ + 0, + 5, + 7 + ], + "7dom": [ + 0, + 4, + 7, + 10 + ], + "7min": [ + 0, + 3, + 7, + 10 + ], + "7maj": [ + 0, + 4, + 7, + 11 + ], + "7halfdim": [ + 0, + 3, + 6, + 10 + ], + "7dim": [ + 0, + 3, + 6, + 9 + ], + "7aug": [ + 0, + 4, + 8, + 11 + ], + "9maj": [ + 0, + 4, + 7, + 10, + 14 + ], + "9min": [ + 0, + 4, + 7, + 10, + 13 + ] + }, + "chord_tokens_with_root_note": false, + "chord_unknown": null, + "num_tempos": 32, + "tempo_range": [ + 40, + 250 + ], + "log_tempos": false, + "delete_equal_successive_tempo_changes": true, + "time_signature_range": { + "8": [ + 3, + 12, + 6 + ], + "4": [ + 5, + 6, + 3, + 2, + 1, + 4 + ] + }, + "delete_equal_successive_time_sig_changes": false, + "sustain_pedal_duration": false, + "pitch_bend_range": [ + -8192, + 8191, + 32 + ], + "max_pitch_interval": 16, + "pitch_intervals_max_time_dist": 1, + "drums_pitch_range": [ + 27, + 88 + ], + "ac_polyphony_track": false, + "ac_polyphony_bar": false, + "ac_polyphony_min": 1, + "ac_polyphony_max": 6, + "ac_pitch_class_bar": false, + "ac_note_density_track": false, + "ac_note_density_track_min": 0, + "ac_note_density_track_max": 18, + "ac_note_density_bar": false, + "ac_note_density_bar_max": 18, + "ac_note_duration_bar": false, + "ac_note_duration_track": false, + "ac_repetition_track": false, + "ac_repetition_track_num_bins": 10, + "ac_repetition_track_num_consec_bars": 4, + "additional_params": { + "max_bar_embedding": 60 + } + }, + "tokenization": "Octuple", + "miditok_version": "3.0.6.post1", + "symusic_version": "0.5.8", + "hf_tokenizers_version": "0.21.2" +} \ No newline at end of file