diff --git a/.gitignore b/.gitignore index 8caba3c..3ae93de 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ wandb/ .vscode/ checkpoints/ metadata/ +*.sf2 diff --git a/Amadeus/custom_x_transformers.py b/Amadeus/custom_x_transformers.py index 64759b0..fbfb538 100644 --- a/Amadeus/custom_x_transformers.py +++ b/Amadeus/custom_x_transformers.py @@ -1358,6 +1358,7 @@ class Attention(Module): dim_latent_kv = None, latent_rope_subheads = None, onnxable = False, + use_gated_attention = False, # https://arxiv.org/abs/2505.06708 attend_sdp_kwargs: dict = dict( enable_flash = True, enable_math = True, @@ -1387,6 +1388,7 @@ class Attention(Module): k_dim = dim_head * kv_heads v_dim = value_dim_head * kv_heads out_dim = value_dim_head * heads + gated_dim = out_dim # determine input dimensions to qkv based on whether intermediate latent q and kv are being used # for eventually supporting multi-latent attention (MLA) @@ -1447,7 +1449,8 @@ class Attention(Module): self.to_v_gate = None if gate_values: - self.to_v_gate = nn.Linear(dim, out_dim) + # self.to_v_gate = nn.Linear(dim, out_dim) + self.to_v_gate = nn.Linear(dim_kv_input, gated_dim) self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid nn.init.constant_(self.to_v_gate.weight, 0) nn.init.constant_(self.to_v_gate.bias, 10) diff --git a/Amadeus/sampling_utils.py b/Amadeus/sampling_utils.py index 28f652b..c5742ca 100644 --- a/Amadeus/sampling_utils.py +++ b/Amadeus/sampling_utils.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F - + def top_p_sampling(logits, thres=0.9): sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) @@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature): # temporarily apply the sampling method to logits logits = logits / temperature # logits = add_gumbel_noise(logits, temperature) - + if sampling_method == "top_p": modified_logits = top_p_sampling(logits, thres=threshold) elif sampling_method == "typical": diff --git a/Amadeus/symbolic_encoding/data_utils.py b/Amadeus/symbolic_encoding/data_utils.py index 0d23e84..23e4583 100644 --- a/Amadeus/symbolic_encoding/data_utils.py +++ b/Amadeus/symbolic_encoding/data_utils.py @@ -530,6 +530,62 @@ class Melody(SymbolicMusicDataset): test_names.extend(song_dict[song_name]) return train_names, valid_names, test_names, shuffled_tune_names +class msmidi(SymbolicMusicDataset): + def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, + for_evaluation: bool = False): + super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path, + for_evaluation=for_evaluation) + + def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]: + ''' + Irregular tunes are removed from the dataset for better generation quality + It includes tunes that are not quantized properly, mostly theay are expressive performance data + ''' + print("preprocessed tune_in_idx data is being loaded") + tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz"))) + if self.debug: + tune_in_idx_list = tune_in_idx_list[:5000] + tune_in_idx_dict = OrderedDict() + len_tunes = OrderedDict() + file_name_list = [] + with open("metadata/LakhClean_irregular_tunes.json", "r") as f: + irregular_tunes = json.load(f) + for 
tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)): + if tune_in_idx_file.stem in irregular_tunes: + continue + tune_in_idx = np.load(tune_in_idx_file)['arr_0'] + tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx + len_tunes[tune_in_idx_file.stem] = len(tune_in_idx) + file_name_list.append(tune_in_idx_file.stem) + print(f"number of loaded tunes: {len(tune_in_idx_dict)}") + return tune_in_idx_dict, len_tunes, file_name_list + + def _get_split_list_from_tune_in_idx(self, ratio, seed): + ''' + As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name + ''' + shuffled_tune_names = list(self.tune_in_idx.keys()) + song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] + song_dict = {} + for song, orig_song in zip(song_names_without_version, shuffled_tune_names): + if song not in song_dict: + song_dict[song] = [] + song_dict[song].append(orig_song) + unique_song_names = list(song_dict.keys()) + random.seed(seed) + random.shuffle(unique_song_names) + num_train = int(len(unique_song_names)*ratio) + num_valid = int(len(unique_song_names)*(1-ratio)/2) + train_names = [] + valid_names = [] + test_names = [] + for song_name in unique_song_names[:num_train]: + train_names.extend(song_dict[song_name]) + for song_name in unique_song_names[num_train:num_train+num_valid]: + valid_names.extend(song_dict[song_name]) + for song_name in unique_song_names[num_train+num_valid:]: + test_names.extend(song_dict[song_name]) + return train_names, valid_names, test_names, shuffled_tune_names class IrishMan(SymbolicMusicDataset): def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, @@ -648,62 +704,62 @@ class ariamidi(SymbolicMusicDataset): test_names.extend(song_dict[song_name]) return train_names, valid_names, test_names, shuffled_tune_names -class gigamidi(SymbolicMusicDataset): - def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, - for_evaluation: bool = False): - super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path, - for_evaluation=for_evaluation) +# class gigamidi(SymbolicMusicDataset): +# def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, +# for_evaluation: bool = False): +# super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path, +# for_evaluation=for_evaluation) - def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]: - ''' - Irregular tunes are removed from the dataset for better generation quality - It includes tunes that are not quantized properly, mostly theay are expressive performance data - ''' - print("preprocessed tune_in_idx data is being loaded") - tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz"))) - if self.debug: - tune_in_idx_list = tune_in_idx_list[:5000] - tune_in_idx_dict = OrderedDict() - len_tunes = OrderedDict() - file_name_list = [] - with open("metadata/LakhClean_irregular_tunes.json", "r") as f: - irregular_tunes = json.load(f) - for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)): - if tune_in_idx_file.stem in irregular_tunes: - continue - tune_in_idx = 
np.load(tune_in_idx_file)['arr_0'] - tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx - len_tunes[tune_in_idx_file.stem] = len(tune_in_idx) - file_name_list.append(tune_in_idx_file.stem) - print(f"number of loaded tunes: {len(tune_in_idx_dict)}") - return tune_in_idx_dict, len_tunes, file_name_list +# def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]: +# ''' +# Irregular tunes are removed from the dataset for better generation quality +# It includes tunes that are not quantized properly, mostly theay are expressive performance data +# ''' +# print("preprocessed tune_in_idx data is being loaded") +# tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz"))) +# if self.debug: +# tune_in_idx_list = tune_in_idx_list[:5000] +# tune_in_idx_dict = OrderedDict() +# len_tunes = OrderedDict() +# file_name_list = [] +# with open("metadata/LakhClean_irregular_tunes.json", "r") as f: +# irregular_tunes = json.load(f) +# for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)): +# if tune_in_idx_file.stem in irregular_tunes: +# continue +# tune_in_idx = np.load(tune_in_idx_file)['arr_0'] +# tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx +# len_tunes[tune_in_idx_file.stem] = len(tune_in_idx) +# file_name_list.append(tune_in_idx_file.stem) +# print(f"number of loaded tunes: {len(tune_in_idx_dict)}") +# return tune_in_idx_dict, len_tunes, file_name_list - def _get_split_list_from_tune_in_idx(self, ratio, seed): - ''' - As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name - ''' - shuffled_tune_names = list(self.tune_in_idx.keys()) - song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] - song_dict = {} - for song, orig_song in zip(song_names_without_version, shuffled_tune_names): - if song not in song_dict: - song_dict[song] = [] - song_dict[song].append(orig_song) - unique_song_names = list(song_dict.keys()) - random.seed(seed) - random.shuffle(unique_song_names) - num_train = int(len(unique_song_names)*ratio) - num_valid = int(len(unique_song_names)*(1-ratio)/2) - train_names = [] - valid_names = [] - test_names = [] - for song_name in unique_song_names[:num_train]: - train_names.extend(song_dict[song_name]) - for song_name in unique_song_names[num_train:num_train+num_valid]: - valid_names.extend(song_dict[song_name]) - for song_name in unique_song_names[num_train+num_valid:]: - test_names.extend(song_dict[song_name]) - return train_names, valid_names, test_names, shuffled_tune_names +# def _get_split_list_from_tune_in_idx(self, ratio, seed): +# ''' +# As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name +# ''' +# shuffled_tune_names = list(self.tune_in_idx.keys()) +# song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] +# song_dict = {} +# for song, orig_song in zip(song_names_without_version, shuffled_tune_names): +# if song not in song_dict: +# song_dict[song] = [] +# song_dict[song].append(orig_song) +# unique_song_names = list(song_dict.keys()) +# random.seed(seed) +# random.shuffle(unique_song_names) +# num_train = int(len(unique_song_names)*ratio) +# num_valid = int(len(unique_song_names)*(1-ratio)/2) +# train_names = [] +# valid_names = [] +# test_names = [] +# for song_name in unique_song_names[:num_train]: +# 
train_names.extend(song_dict[song_name]) +# for song_name in unique_song_names[num_train:num_train+num_valid]: +# valid_names.extend(song_dict[song_name]) +# for song_name in unique_song_names[num_train+num_valid:]: +# test_names.extend(song_dict[song_name]) +# return train_names, valid_names, test_names, shuffled_tune_names class ariamidi(SymbolicMusicDataset): def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, @@ -788,6 +844,9 @@ class gigamidi(SymbolicMusicDataset): for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)): if tune_in_idx_file.stem in irregular_tunes: continue + if "drums-only" in tune_in_idx_file.stem: + print(f"skipping {tune_in_idx_file.stem} as it is a drums-only file") + continue tune_in_idx = np.load(tune_in_idx_file)['arr_0'] tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx len_tunes[tune_in_idx_file.stem] = len(tune_in_idx) diff --git a/Amadeus/symbolic_encoding/midi2audio.py b/Amadeus/symbolic_encoding/midi2audio.py index ddbae0f..a037a0d 100644 --- a/Amadeus/symbolic_encoding/midi2audio.py +++ b/Amadeus/symbolic_encoding/midi2audio.py @@ -11,7 +11,7 @@ License: MIT, see the LICENSE file __all__ = ['FluidSynth'] -DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2' +DEFAULT_SOUND_FONT = 'Alex_GM.sf2' DEFAULT_SAMPLE_RATE = 48000 DEFAULT_GAIN = 0.05 # DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2" diff --git a/Amadeus/symbolic_yamls/config-accelerate.yaml b/Amadeus/symbolic_yamls/config-accelerate.yaml index fef73f4..5dc78b8 100644 --- a/Amadeus/symbolic_yamls/config-accelerate.yaml +++ b/Amadeus/symbolic_yamls/config-accelerate.yaml @@ -2,7 +2,8 @@ defaults: # - nn_params: nb8_embSum_NMT # - nn_params: remi8 # - nn_params: nb8_embSum_diff_t2m_150M_finetunning - - nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2 + # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2 + - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2 # - nn_params: nb8_embSum_subPararell # - nn_params: nb8_embSum_diff_t2m_150M @@ -14,7 +15,7 @@ defaults: # - nn_params: remi8_main12_head_16_dim512 # - nn_params: nb5_embSum_diff_main12head16dim768_sub3 -dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset +dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean,PretrainingDataset FinetuneDataset captions_path: dataset/midicaps/train_set.json # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean @@ -30,20 +31,20 @@ tau: 0.5 train_params: device: cuda - batch_size: 3 + batch_size: 5 grad_clip: 1.0 num_iter: 300000 # total number of iterations num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint iterations_per_training_cycle: 10 # number of iterations for logging training loss - iterations_per_validation_cycle: 5000 # number of iterations for validation process + iterations_per_validation_cycle: 3000 # number of iterations for validation process input_length: 3072 # input sequence length3072 # you can use focal loss, it it's not used, set focal_gamma to 0 focal_alpha: 1 focal_gamma: 0 # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details scheduler : cosinelr - initial_lr: 0.00005 + 
initial_lr: 0.0004 decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts' warmup_steps: 2000 #number of warmup steps diff --git a/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_150M_pretrainingv2.yaml b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_150M_pretrainingv2.yaml index ada3f49..d7d7550 100644 --- a/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_150M_pretrainingv2.yaml +++ b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_150M_pretrainingv2.yaml @@ -5,13 +5,13 @@ model_name: AmadeusModel input_embedder_name: SummationEmbedder main_decoder_name: XtransformerNewPretrainingDecoder sub_decoder_name: DiffusionDecoder -model_dropout: 0 +model_dropout: 0.2 input_embedder: num_layer: 1 num_head: 8 main_decoder: dim_model: 768 - num_layer: 20 + num_layer: 16 num_head: 12 sub_decoder: decout_window_size: 1 # 1 means no previous decoding output added diff --git a/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_finetunningv2.yaml b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_finetunningv2.yaml new file mode 100644 index 0000000..90406c4 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_finetunningv2.yaml @@ -0,0 +1,19 @@ +encoding_scheme: nb +num_features: 8 +vocab_name: MusicTokenVocabNB +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewFinetunningDecoder +sub_decoder_name: DiffusionDecoder +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 1024 + num_layer: 32 + num_head: 16 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_pretrainingv2.yaml b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_pretrainingv2.yaml new file mode 100644 index 0000000..9a39556 --- /dev/null +++ b/Amadeus/symbolic_yamls/nn_params/nb8_embSum_diff_t2m_600M_pretrainingv2.yaml @@ -0,0 +1,19 @@ +encoding_scheme: nb +num_features: 8 +vocab_name: MusicTokenVocabNB +model_name: AmadeusModel +input_embedder_name: SummationEmbedder +main_decoder_name: XtransformerNewPretrainingDecoder +sub_decoder_name: DiffusionDecoder +model_dropout: 0 +input_embedder: + num_layer: 1 + num_head: 8 +main_decoder: + dim_model: 1024 + num_layer: 32 + num_head: 16 +sub_decoder: + decout_window_size: 1 # 1 means no previous decoding output added + num_layer: 1 + feature_enricher_use: False \ No newline at end of file diff --git a/Amadeus/transformer_utils.py b/Amadeus/transformer_utils.py index 6fb776a..f47ee17 100644 --- a/Amadeus/transformer_utils.py +++ b/Amadeus/transformer_utils.py @@ -729,9 +729,8 @@ class XtransformerNewPretrainingDecoder(nn.Module): rotary_pos_emb = True, ff_swish = True, # set this to True ff_glu = True, # set to true to use for all feedforwards - # shift_tokens = 1, - # attn_qk_norm = True, - # attn_qk_norm_dim_scale = True + attn_gate_values = True, + attn_qk_norm = True, ) # add final dropout print('Applying Xavier Uniform Init to x-transformer following torch.Transformer') @@ -758,7 +757,7 @@ class XtransformerNewPretrainingDecoder(nn.Module): def _apply_xavier_init(self): for name, param in self.transformer_decoder.named_parameters(): - if 'to_q' in name or 'to_k' in name or 'to_v' in name: + if 'to_q' in name or 'to_k' in name or 'to_v' 
in name and 'to_v_gate' not in name: torch.nn.init.xavier_uniform_(param, gain=0.5**0.5) def forward(self, seq, cache=None,train=False,context=None, context_embedding=None): @@ -906,6 +905,102 @@ class XtransformerFinetuningDecoder(nn.Module): else: context = context_embedding + # concatenate context with seq + seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size + if cache is not None: # implementing run_one_step in inference + if cache.hiddens is None: cache = None + hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True) + # cut to only return the seq part + return hidden_vec, intermediates + else: + if train: + hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True) + # cut to only return the seq part + hidden_vec = hidden_vec[:, context.shape[1]:, :] + return hidden_vec, intermediates + else: + # cut to only return the seq part + hidden_vec = self.transformer_decoder(seq) + hidden_vec = hidden_vec[:, context.shape[1]:, :] + return hidden_vec + +class XtransformerNewFinetunningDecoder(nn.Module): + def __init__( + self, + dim:int, + depth:int, + heads:int, + dropout:float + ): + super().__init__() + self._make_decoder_layer(dim, depth, heads, dropout) + self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base') + # frozen text encoder + for param in self.text_encoder.parameters(): + param.requires_grad = False + if dim != 768: + self.text_project = nn.Linear(768, dim) # assuming T5 base hidden size is 768 + + def _make_decoder_layer(self, dim, depth, heads, dropout): + self.transformer_decoder = Decoder( + dim = dim, + depth = depth, + heads = heads, + attn_dropout = dropout, + ff_dropout = dropout, + attn_flash = True, + use_rmsnorm=True, + rotary_pos_emb = True, + ff_swish = True, # set this to True + ff_glu = True, # set to true to use for all feedforwards + attn_gate_values = True, + attn_qk_norm = True, + ) # add final dropout + print('Applying Xavier Uniform Init to x-transformer following torch.Transformer') + self._apply_xavier_init() + print('Adding dropout after feedforward layer in x-transformer') + self._add_dropout_after_ff(dropout) + print('Adding dropout after attention layer in x-transformer') + self._add_dropout_after_attn(dropout) + + def _add_dropout_after_attn(self, dropout): + for layer in self.transformer_decoder.layers: + if 'Attention' in str(type(layer[1])): + if isinstance(layer[1].to_out, nn.Sequential): # if GLU + layer[1].to_out.append(nn.Dropout(dropout)) + elif isinstance(layer[1].to_out, nn.Linear): # if simple linear + layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout)) + else: + raise ValueError('to_out should be either nn.Sequential or nn.Linear') + + def _add_dropout_after_ff(self, dropout): + for layer in self.transformer_decoder.layers: + if 'FeedForward' in str(type(layer[1])): + layer[1].ff.append(nn.Dropout(dropout)) + + def _apply_xavier_init(self): + for name, param in self.transformer_decoder.named_parameters(): + if 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name: + torch.nn.init.xavier_uniform_(param, gain=0.5**0.5) + + def forward(self, seq, cache=None,train=False,context=None,context_embedding=None): + assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder' + if context_embedding is None: + input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids'] + attention_mask = 
context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask'] + assert input_ids is not None, 'input_ids should be provided for prefix decoder' + assert attention_mask is not None, 'attention_mask should be provided for prefix decoder' + assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder' + + context = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ).last_hidden_state + else: + context = context_embedding + if hasattr(self, 'text_project'): + context = self.text_project(context) + # concatenate context with seq seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size if cache is not None: # implementing run_one_step in inference diff --git a/data_representation/step1_midi2corpus_fined.py b/data_representation/step1_midi2corpus_fined.py index e42cbfa..49c8f25 100644 --- a/data_representation/step1_midi2corpus_fined.py +++ b/data_representation/step1_midi2corpus_fined.py @@ -113,7 +113,7 @@ class CorpusMaker(): 0 to 2000 means no limitation ''' # last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)} - last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'Symphony': (60, 1500)} + last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (8, 3000), 'Symphony': (60, 1500)} try: self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name] except KeyError: diff --git a/generate-batch.py b/generate-batch.py index f43f1c0..f371b3f 100644 --- a/generate-batch.py +++ b/generate-batch.py @@ -105,6 +105,7 @@ def load_resources(wandb_exp_dir, device): config = wandb_style_config_to_omega_config(config) # Load checkpoint to specified device + print("Loading checkpoint from:", ckpt_path) ckpt = torch.load(ckpt_path, map_location=device) model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path) model.load_state_dict(ckpt['model'], strict=False) diff --git a/len_tunes/FinetuneDataset/len_nb8.png b/len_tunes/FinetuneDataset/len_nb8.png new file mode 100644 index 0000000..0fe1b87 Binary files /dev/null and b/len_tunes/FinetuneDataset/len_nb8.png differ diff --git a/len_tunes/IrishMan/len_nb8.png b/len_tunes/IrishMan/len_nb8.png new file mode 100644 index 0000000..fab7942 Binary files /dev/null and b/len_tunes/IrishMan/len_nb8.png differ diff --git a/len_tunes/gigamidi/len_nb8.png b/len_tunes/gigamidi/len_nb8.png new file mode 100644 index 0000000..cba11fc Binary files /dev/null and b/len_tunes/gigamidi/len_nb8.png differ diff --git a/midi_stastic.py b/midi_stastic.py new file mode 100644 index 0000000..eae24ba --- /dev/null +++ b/midi_stastic.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python3 +""" +MIDI Statistics Extractor + +Usage: python midi_statistics.py [options] + +This script traverses a directory and all subdirectories to find MID files, +extracts musical features from each file using multi-threading for speed, +and saves the results to CSV files. 
+""" + +import argparse +import pathlib +import os +import csv +import json +from multiprocessing import Pool +from itertools import chain +from math import ceil +from functools import partial + +import numpy as np +from numpy.lib.stride_tricks import sliding_window_view +from symusic import Score +import pandas as pd +from tqdm import tqdm +from numba import njit, prange + + +@njit +def merge_intervals(intervals: list[tuple[int, int]], threshold: int): + """Merge overlapping or close intervals.""" + out = [] + last_s, last_e = intervals[0] + + for i in range(1, len(intervals)): + s, e = intervals[i] + + if s - last_e <= threshold: + if e > last_e: + last_e = e + else: + out.append((last_s, last_e)) + last_s, last_e = s, e + + out.append((last_s, last_e)) + return out + + +@njit(fastmath=True) +def note_distribution(events: list[tuple[float, int]], threshold: int = 2, segment_threshold: int = 0): + """Calculate polyphony rate and sounding segments.""" + try: + events.sort() + active_notes = 0 + polyphonic_steps = 0 + total_steps = 0 + last_time = None + last_state = False + last_seg_start = 0 + sounding_segments = [] + + for time, change in events: + if last_time is not None and time != last_time: + if active_notes >= threshold: + polyphonic_steps += (time - last_time) + if active_notes: + total_steps += (time - last_time) + if(last_state != bool(active_notes)): + if(last_state): + last_seg_start = time + else: + sounding_segments.append((last_seg_start, time)) + + active_notes += change + last_state = bool(active_notes) + last_time = time + + if(segment_threshold != 0): + sounding_segments = merge_intervals(sounding_segments, segment_threshold) + + return polyphonic_steps / total_steps, total_steps, sounding_segments + except: + return None, None, None + + +@njit(fastmath=True) +def entropy(X: np.ndarray, base: float = 2.0) -> float: + """Calculate entropy function optimized with numba.""" + N, M = X.shape + out = np.empty(N, dtype=np.float64) + log_base = np.log(base) if base > 0.0 else 1.0 + + for i in prange(N): + row = X[i] + total = np.nansum(row) + if total <= 0.0: + out[i] = 0.0 + continue + + mask = (~np.isnan(row)) & (row > 0.0) + probs = row[mask] / total + if probs.size == 0: + out[i] = 0.0 + else: + H = -np.sum(probs * np.log(probs)) + if base > 0.0: + H /= log_base + out[i] = H + + nz = out > 0.0 + if not np.any(nz): + return 0.0 + return float(np.exp(np.mean(np.log(out[nz])))) + + +@njit(fastmath=True) +def n_gram_co_occurence_entropy(seq: list[list[int]], N: int = 5): + """Calculate n-gram co-occurrence entropy.""" + counts = [] + + for seg in seq: + if len(seg) < 2: + continue + + arr = np.asarray(seg, dtype=np.int64) + + min_val = np.min(arr) + if min_val < 0: + arr = arr - min_val + + vocabs = int(np.max(arr) + 1) + + wlen = N if len(arr) >= N else len(arr) + nwin = len(arr) - wlen + 1 + + C = np.zeros((vocabs, vocabs), dtype=np.int64) + + for start in range(nwin): + for i in range(wlen - 1): + a = int(arr[start + i]) + for j in range(i + 1, wlen): + b = int(arr[start + j]) + if a < vocabs and b < vocabs: + C[a, b] += 1 + + for i in range(vocabs): + counts.append(int(C[i, i])) + for j in range(i + 1, vocabs): + counts.append(int(C[i, j])) + + total = 0 + for v in counts: + total += v + + if total <= 0: + return 0.0 + + H = 0.0 + for v in counts: + if v > 0: + p = v / total + H -= p * np.log(p) + + return H + + +def calc_pitch_distribution(pitches: np.ndarray, window_size: int = 32, hop_size: int = 16): + """Calculate pitch distribution features.""" + sw = (lambda x: 
sliding_window_view(x, window_size)[::hop_size, :]) if len(pitches) > window_size else (lambda x: x.reshape(1, -1)) + + used_pitches = np.unique(pitches) + n_pitches_used = len(used_pitches) + pitch_entropy = entropy(sw(pitches)) + pitch_range = [int(min(used_pitches)), int(max(used_pitches))] + + pitch_classes = pitches % 12 + n_pitch_classes_used = len(np.unique(pitch_classes)) + pitch_class_entropy = entropy(sw(pitch_classes)) + + return n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range + + +def calc_rhythmic_entropy(ioi: np.ndarray, window_size: int = 32, hop_size: int = 16): + """Calculate rhythmic entropy.""" + sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(ioi) > window_size else (lambda x: x.reshape(1, -1)) + if(len(ioi) == 0): + return None + return entropy(sw(ioi)) + + +def extract_features(midi_path: pathlib.Path, tpq: int = 6): + """Extract features from a single MIDI file.""" + try: + seg_threshold = tpq * 8 + midi_id = midi_path.parent.name + '/' + midi_path.stem + score = Score(midi_path).resample(tpq) + + track_features = [] + for i, t in enumerate(score.tracks): + if(not len(t.notes)): + track_features.append(( + midi_id, # midi_id + i, # track_id + 128 if t.is_drum else t.program, # instrument + + 0, # end_time + 0, # note_num + None, # sounding_interval + + None, # note_density + None, # polyphony_rate + None, # rhythmic_entropy + None, # rhythmic_token_co_occurrence_entropy + + None, # n_pitch_classes_used + None, # n_pitches_used + None, # pitch_class_entropy + None, # pitch_entropy + None, # pitch_range + None # interval_token_co_occurrence_entropy + )) + continue + t.sort() + + features = t.notes.numpy() + + ioi = np.diff(features['time']) + seg_points = np.where(ioi > tpq * seg_threshold)[0] + + polyphony_rate, sounding_interval_length, sounding_segment = note_distribution(list(chain(* + [((note.start, 1), (note.end, -1)) for note in t.notes]))) + rhythmic_entropy = calc_rhythmic_entropy(ioi) + + rhythmic_token_co_occurrence_entropy = n_gram_co_occurence_entropy([i for i in np.split(ioi, seg_points) if np.all(i) <= seg_threshold]) + + if(t.is_drum or len(t.notes) < 2): + track_features.append(( + midi_id, # midi_id + i, # track_id + 128 if t.is_drum else t.program, # instrument + + t.end(), # end_time + len(t.notes), # note_num + sounding_interval_length, # sounding_interval + + len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density + polyphony_rate, # polyphony_rate + rhythmic_entropy, # rhythmic_entropy + rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy + + None, # n_pitch_classes_used + None, # n_pitches_used + None, # pitch_class_entropy + None, # pitch_entropy + None, # pitch_range + None # interval_token_co_occurrence_entropy + )) + else: + n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range = calc_pitch_distribution(features['pitch']) + intervals = np.diff(features['pitch']) + track_features.append(( + midi_id, # midi_id + i, # track_id + t.program, # instrument + + t.end(), # end_time + len(t.notes), # note_num + sounding_interval_length, # sounding_interval + + len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density + polyphony_rate, # polyphony_rate + rhythmic_entropy, # rhythmic_entropy + rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy + + n_pitch_classes_used, # n_pitch_classes_used + n_pitches_used, # 
n_pitches_used
+                    pitch_class_entropy, # pitch_class_entropy
+                    pitch_entropy, # pitch_entropy
+                    json.dumps(pitch_range), # pitch_range
+                    n_gram_co_occurence_entropy([p for i, p in zip(np.split(ioi, seg_points), np.split(intervals, seg_points)) if np.all(i) <= seg_threshold]) # interval_token_co_occurrence_entropy
+                ))
+
+        score_features = (
+            midi_id, # midi_id
+            sum(tf[4] for tf in track_features) if track_features else 0, # note_num
+            max(tf[3] for tf in track_features) if track_features else 0, # end_time
+            json.dumps([[ks.time, ks.key, ks.tonality] for ks in score.key_signatures]), # key
+            json.dumps([[ts.time, ts.numerator, ts.denominator] for ts in score.time_signatures]), # time_signature
+            json.dumps([[t.time, t.qpm] for t in score.tempos]) # tempo
+        )
+
+        return score_features, track_features
+    except Exception as e:
+        print(f"Error processing {midi_path}: {e}")
+        return None, None
+
+
+def find_midi_files(directory: pathlib.Path):
+    """Find all MIDI files in directory and subdirectories."""
+    midi_extensions = {'.mid', '.midi', '.MID', '.MIDI'}
+    midi_files = []
+
+    # Use rglob to recursively find MIDI files
+    for file_path in directory.rglob('*'):
+        if file_path.is_file() and file_path.suffix in midi_extensions:
+            midi_files.append(file_path)
+
+    return midi_files
+
+
+def process_midi_files(directory: pathlib.Path, output_prefix: str = "midi_features",
+                       num_threads: int = 4, tpq: int = 6):
+    """Process MIDI files with multi-threading and save to CSV."""
+
+    # Find all MIDI files
+    print(f"Searching for MIDI files in: {directory}")
+    midi_files = find_midi_files(directory)
+
+    if not midi_files:
+        print(f"No MIDI files found in {directory}")
+        return
+
+    print(f"Found {len(midi_files)} MIDI files")
+
+    # Create extractor function with fixed parameters
+    extractor = partial(extract_features, tpq=tpq)
+
+    # Feature column names
+    score_feat_cols = ['midi_id', 'note_num', 'end_time', 'key', 'time_signature', 'tempo']
+    track_feat_cols = ['midi_id', 'track_id', 'instrument', 'end_time', 'note_num',
+                       'sounding_interval', 'note_density', 'polyphony_rate', 'rhythmic_entropy',
+                       'rhythmic_token_co_occurrence_entropy', 'n_pitch_classes_used',
+                       'n_pitches_used', 'pitch_class_entropy', 'pitch_entropy', 'pitch_range',
+                       'interval_token_co_occurrence_entropy']
+
+    # Process files with multiprocessing
+    print(f"Processing files with {num_threads} threads...")
+
+    with Pool(num_threads) as pool:
+        # Open CSV files for writing
+        with open(f'{output_prefix}_score_features.csv', 'w', newline='', encoding='utf-8') as score_csvfile:
+            score_writer = csv.writer(score_csvfile)
+            score_writer.writerow(score_feat_cols)
+
+            with open(f'{output_prefix}_track_features.csv', 'w', newline='', encoding='utf-8') as track_csvfile:
+                track_writer = csv.writer(track_csvfile)
+                track_writer.writerow(track_feat_cols)
+
+                # Process files with progress bar
+                processed_count = 0
+                skipped_count = 0
+
+                for score_feat, track_feats in tqdm(pool.imap_unordered(extractor, midi_files),
+                                                    total=len(midi_files),
+                                                    desc="Processing MIDI files"):
+                    if score_feat is None:  # extract_features returns (None, None) on error
+                        skipped_count += 1
+                        continue
+
+                    processed_count += 1
+
+                    # Write score features
+                    score_writer.writerow(score_feat)
+
+                    # Write track features
+                    if track_feats:
+                        track_writer.writerows(track_feats)
+
+    print(f"\nProcessing complete!")
+    print(f"Successfully processed: {processed_count} files")
+    print(f"Skipped due to errors: {skipped_count} files")
+    print(f"Score features saved to: {output_prefix}_score_features.csv")
+    print(f"Track 
features saved to: {output_prefix}_track_features.csv") + + +def main(): + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Extract musical features from MIDI files and save to CSV", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python midi_statistics.py /path/to/midi/files + python midi_statistics.py /path/to/midi/files --threads 8 --output my_features + python midi_statistics.py /path/to/midi/files --tpq 12 --threads 2 + +Features extracted: + - Score level: note count, end time, key signatures, time signatures, tempo + - Track level: instrument, note density, polyphony rate, rhythmic entropy, + pitch distribution, and more + """ + ) + + parser.add_argument('directory', + help='Path to directory containing MIDI files') + + parser.add_argument('--threads', '-t', + type=int, + default=4, + help='Number of threads to use (default: 4)') + + parser.add_argument('--output', '-o', + type=str, + default='midi_features', + help='Output file prefix (default: midi_features)') + + parser.add_argument('--tpq', + type=int, + default=6, + help='Ticks per quarter note for resampling (default: 6)') + + args = parser.parse_args() + + # Validate directory + directory = pathlib.Path(args.directory) + if not directory.exists(): + print(f"Error: Directory '{directory}' does not exist") + return 1 + + if not directory.is_dir(): + print(f"Error: '{directory}' is not a directory") + return 1 + + # Validate threads + if args.threads < 1: + print("Error: Number of threads must be at least 1") + return 1 + + try: + process_midi_files(directory, args.output, args.threads, args.tpq) + return 0 + except KeyboardInterrupt: + print("\nProcessing interrupted by user") + return 1 + except Exception as e: + print(f"Error: {e}") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/,idi_sim.py b/,idi_sim.py new file mode 100644 index 0000000..dfe9488 --- /dev/null +++ b/,idi_sim.py @@ -0,0 +1,105 @@ +import os +import numpy as np +import pandas as pd +from symusic import Score +from concurrent.futures import ProcessPoolExecutor, as_completed + +semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2]) + +def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True): + if(not a.shape[1] or not b.shape[1]): + return np.inf + a_onset, a_pitch = a + b_onset, b_pitch = b + a_onset = a_onset.astype(np.float32) + b_onset = b_onset.astype(np.float32) + a_pitch = a_pitch.astype(np.uint8) + b_pitch = b_pitch.astype(np.uint8) + + onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1)) + if(oti): + pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12] + dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight) + a2b = dist_matrix.min(2) + b2a = dist_matrix.min(1) + dist = np.concatenate([a2b, b2a], axis=1) + return dist.sum(axis=1).min() / len(dist) + else: + pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12] + dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight) + a2b = dist_matrix.min(1) + b2a = dist_matrix.min(0) + return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1])) + + +def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.): + x = sorted(x) + end_time = x[-1][0] + out = 
[[] for _ in range(max(1, int(end_time // hop_size)))]  # at least one window, even for very short inputs
+    for i in sorted(x):
+        segment = min(int(i[0] // hop_size), len(out) - 1)
+        while(i[0] >= segment * hop_size):
+            out[segment].append(i)
+            segment -= 1
+            if(segment < 0):
+                break
+    return out
+
+
+def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
+    a = midi_time_sliding_window(a, window_size, hop_size)
+    b = midi_time_sliding_window(b, window_size, hop_size)
+    dist = np.inf
+    for i in a:
+        for j in b:
+            cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
+            if(cur_dist < dist):
+                dist = cur_dist
+    return dist
+
+
+def extract_notes(filepath: str):
+    """Read a MIDI file and return a list of (time, pitch) tuples."""
+    try:
+        s = Score(filepath).to("quarter")
+        notes = []
+        for t in s.tracks:
+            notes.extend([(n.time, n.pitch) for n in t.notes])
+        return notes
+    except Exception as e:
+        print(f"Error reading {filepath}: {e}")
+        return []
+
+
+def compare_pair(file_a: str, file_b: str):
+    notes_a = extract_notes(file_a)
+    notes_b = extract_notes(file_b)
+    if not notes_a or not notes_b:
+        return (file_a, file_b, np.inf)
+    dist = midi_dist(notes_a, notes_b)
+    return (file_a, file_b, dist)
+
+
+def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
+    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
+    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
+
+    results = []
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
+        for fut in as_completed(futures):
+            results.append(fut.result())
+
+    # sort pairs by ascending distance (most similar first)
+    results = sorted(results, key=lambda x: x[2])
+
+    # save results to CSV
+    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
+    df.to_csv(out_csv, index=False)
+    print(f"Saved results to {out_csv}")
+
+
+if __name__ == "__main__":
+    dir_a = "folder_a"
+    dir_b = "folder_b"
+    batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)
\ No newline at end of file
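The similarity script added above computes a windowed, optionally transposition-invariant (OTI) Hausdorff distance over (onset, pitch) pairs, so an exact transposition of the same material should score zero. A minimal sanity-check sketch, assuming midi_dist from the new script is in scope (e.g., pasted into the same module); the two five-note fragments are made up for illustration and are not part of the patch:

    # fragment_b is fragment_a transposed up four semitones; with oti=True the
    # pitch term is minimised over all 12 transpositions, so both distances are 0.0
    fragment_a = [(0.0, 60), (2.0, 64), (4.0, 67), (6.0, 72), (8.0, 76)]
    fragment_b = [(t, p + 4) for t, p in fragment_a]
    print(midi_dist(fragment_a, fragment_a))  # identical material  -> 0.0
    print(midi_dist(fragment_a, fragment_b))  # transposed copy     -> 0.0

Unrelated fragments (different rhythms and contours) should instead return a strictly positive distance, which is what batch_compare uses to rank generated files against references.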