1013 update
.gitignore (vendored) | 1 +
@@ -13,3 +13,4 @@ wandb/
 .vscode/
 checkpoints/
 metadata/
+*.sf2
@@ -1358,6 +1358,7 @@ class Attention(Module):
         dim_latent_kv = None,
         latent_rope_subheads = None,
         onnxable = False,
+        use_gated_attention = False, # https://arxiv.org/abs/2505.06708
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
             enable_math = True,
@@ -1387,6 +1388,7 @@ class Attention(Module):
         k_dim = dim_head * kv_heads
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads
+        gated_dim = out_dim

         # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
         # for eventually supporting multi-latent attention (MLA)
@@ -1447,7 +1449,8 @@ class Attention(Module):

         self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, out_dim)
+            # self.to_v_gate = nn.Linear(dim, out_dim)
+            self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
             self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 10)
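Note: the to_v_gate change above implements output gating from the gated attention paper linked in the hunk (arXiv:2505.06708). A minimal sketch of the idea, with illustrative names rather than the repo's exact API:

    import torch
    import torch.nn as nn

    class GatedOutput(nn.Module):
        # elementwise sigmoid gate on the attention output (arXiv:2505.06708)
        def __init__(self, dim_in, out_dim):
            super().__init__()
            self.to_v_gate = nn.Linear(dim_in, out_dim)
            # zero weights + bias 10 => sigmoid(10) ~ 1.0, so the gate starts
            # fully open and the layer behaves like ungated attention at init
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 10)

        def forward(self, x, attn_out):
            # x: (b, n, dim_in) gate input; attn_out: (b, n, out_dim)
            return attn_out * torch.sigmoid(self.to_v_gate(x))

This mirrors why the diff initializes the gate bias to 10: the gate is a no-op at the start of training and only learns to close.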
@@ -530,6 +530,62 @@ class Melody(SymbolicMusicDataset):
         test_names.extend(song_dict[song_name])
         return train_names, valid_names, test_names, shuffled_tune_names

+class msmidi(SymbolicMusicDataset):
+    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+                 for_evaluation: bool = False):
+        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+                         for_evaluation=for_evaluation)
+
+    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+        '''
+        Irregular tunes are removed from the dataset for better generation quality.
+        These include tunes that are not quantized properly; most of them are expressive performance data.
+        '''
+        print("preprocessed tune_in_idx data is being loaded")
+        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+        if self.debug:
+            tune_in_idx_list = tune_in_idx_list[:5000]
+        tune_in_idx_dict = OrderedDict()
+        len_tunes = OrderedDict()
+        file_name_list = []
+        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+            irregular_tunes = json.load(f)
+        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+            if tune_in_idx_file.stem in irregular_tunes:
+                continue
+            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+            file_name_list.append(tune_in_idx_file.stem)
+        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+        return tune_in_idx_dict, len_tunes, file_name_list
+
+    def _get_split_list_from_tune_in_idx(self, ratio, seed):
+        '''
+        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+        '''
+        shuffled_tune_names = list(self.tune_in_idx.keys())
+        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+        song_dict = {}
+        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+            if song not in song_dict:
+                song_dict[song] = []
+            song_dict[song].append(orig_song)
+        unique_song_names = list(song_dict.keys())
+        random.seed(seed)
+        random.shuffle(unique_song_names)
+        num_train = int(len(unique_song_names)*ratio)
+        num_valid = int(len(unique_song_names)*(1-ratio)/2)
+        train_names = []
+        valid_names = []
+        test_names = []
+        for song_name in unique_song_names[:num_train]:
+            train_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train:num_train+num_valid]:
+            valid_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train+num_valid:]:
+            test_names.extend(song_dict[song_name])
+        return train_names, valid_names, test_names, shuffled_tune_names
+
 class IrishMan(SymbolicMusicDataset):
     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
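Note: the song-level split in _get_split_list_from_tune_in_idx keys on the name with trailing version suffixes stripped. A quick standalone check of the regex (tune names are hypothetical):

    import re

    stems = ["Yesterday", "Yesterday.1", "Yesterday.2", "LetItBe"]
    bases = [re.sub(r"\.\d+$", "", s) for s in stems]
    print(bases)  # ['Yesterday', 'Yesterday', 'Yesterday', 'LetItBe']
    # all versions of 'Yesterday' land in the same split bucket, so near-duplicate
    # arrangements cannot leak across train/valid/test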
@@ -648,62 +704,62 @@ class ariamidi(SymbolicMusicDataset):
         test_names.extend(song_dict[song_name])
         return train_names, valid_names, test_names, shuffled_tune_names

-class gigamidi(SymbolicMusicDataset):
-    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
-                 for_evaluation: bool = False):
-        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
-                         for_evaluation=for_evaluation)
+# class gigamidi(SymbolicMusicDataset):
+#     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+#                  for_evaluation: bool = False):
+#         super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+#                          for_evaluation=for_evaluation)

-    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
-        '''
-        Irregular tunes are removed from the dataset for better generation quality.
-        These include tunes that are not quantized properly; most of them are expressive performance data.
-        '''
-        print("preprocessed tune_in_idx data is being loaded")
-        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
-        if self.debug:
-            tune_in_idx_list = tune_in_idx_list[:5000]
-        tune_in_idx_dict = OrderedDict()
-        len_tunes = OrderedDict()
-        file_name_list = []
-        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
-            irregular_tunes = json.load(f)
-        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
-            if tune_in_idx_file.stem in irregular_tunes:
-                continue
-            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
-            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
-            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
-            file_name_list.append(tune_in_idx_file.stem)
-        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
-        return tune_in_idx_dict, len_tunes, file_name_list
+#     def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+#         '''
+#         Irregular tunes are removed from the dataset for better generation quality.
+#         These include tunes that are not quantized properly; most of them are expressive performance data.
+#         '''
+#         print("preprocessed tune_in_idx data is being loaded")
+#         tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+#         if self.debug:
+#             tune_in_idx_list = tune_in_idx_list[:5000]
+#         tune_in_idx_dict = OrderedDict()
+#         len_tunes = OrderedDict()
+#         file_name_list = []
+#         with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+#             irregular_tunes = json.load(f)
+#         for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+#             if tune_in_idx_file.stem in irregular_tunes:
+#                 continue
+#             tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+#             tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+#             len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+#             file_name_list.append(tune_in_idx_file.stem)
+#         print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+#         return tune_in_idx_dict, len_tunes, file_name_list

-    def _get_split_list_from_tune_in_idx(self, ratio, seed):
-        '''
-        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
-        '''
-        shuffled_tune_names = list(self.tune_in_idx.keys())
-        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
-        song_dict = {}
-        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
-            if song not in song_dict:
-                song_dict[song] = []
-            song_dict[song].append(orig_song)
-        unique_song_names = list(song_dict.keys())
-        random.seed(seed)
-        random.shuffle(unique_song_names)
-        num_train = int(len(unique_song_names)*ratio)
-        num_valid = int(len(unique_song_names)*(1-ratio)/2)
-        train_names = []
-        valid_names = []
-        test_names = []
-        for song_name in unique_song_names[:num_train]:
-            train_names.extend(song_dict[song_name])
-        for song_name in unique_song_names[num_train:num_train+num_valid]:
-            valid_names.extend(song_dict[song_name])
-        for song_name in unique_song_names[num_train+num_valid:]:
-            test_names.extend(song_dict[song_name])
-        return train_names, valid_names, test_names, shuffled_tune_names
+#     def _get_split_list_from_tune_in_idx(self, ratio, seed):
+#         '''
+#         As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+#         '''
+#         shuffled_tune_names = list(self.tune_in_idx.keys())
+#         song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+#         song_dict = {}
+#         for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+#             if song not in song_dict:
+#                 song_dict[song] = []
+#             song_dict[song].append(orig_song)
+#         unique_song_names = list(song_dict.keys())
+#         random.seed(seed)
+#         random.shuffle(unique_song_names)
+#         num_train = int(len(unique_song_names)*ratio)
+#         num_valid = int(len(unique_song_names)*(1-ratio)/2)
+#         train_names = []
+#         valid_names = []
+#         test_names = []
+#         for song_name in unique_song_names[:num_train]:
+#             train_names.extend(song_dict[song_name])
+#         for song_name in unique_song_names[num_train:num_train+num_valid]:
+#             valid_names.extend(song_dict[song_name])
+#         for song_name in unique_song_names[num_train+num_valid:]:
+#             test_names.extend(song_dict[song_name])
+#         return train_names, valid_names, test_names, shuffled_tune_names

 class ariamidi(SymbolicMusicDataset):
     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@@ -788,6 +844,9 @@ class gigamidi(SymbolicMusicDataset):
         for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
             if tune_in_idx_file.stem in irregular_tunes:
                 continue
+            if "drums-only" in tune_in_idx_file.stem:
+                print(f"skipping {tune_in_idx_file.stem} as it is a drums-only file")
+                continue
             tune_in_idx = np.load(tune_in_idx_file)['arr_0']
             tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
             len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
@@ -11,7 +11,7 @@ License: MIT, see the LICENSE file

 __all__ = ['FluidSynth']

-DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
+DEFAULT_SOUND_FONT = 'Alex_GM.sf2'
 DEFAULT_SAMPLE_RATE = 48000
 DEFAULT_GAIN = 0.05
 # DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
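Note: this file appears to be a vendored copy of midi2audio, so after the change the default soundfont resolves to 'Alex_GM.sf2'. A sketch of typical use under that assumption (paths are illustrative):

    from midi2audio import FluidSynth

    fs = FluidSynth(sample_rate=48000)  # falls back to DEFAULT_SOUND_FONT
    fs.midi_to_audio('generated.mid', 'generated.wav')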
@@ -2,7 +2,8 @@ defaults:
   # - nn_params: nb8_embSum_NMT
   # - nn_params: remi8
   # - nn_params: nb8_embSum_diff_t2m_150M_finetunning
-  - nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
+  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
+  - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
   # - nn_params: nb8_embSum_subPararell
   # - nn_params: nb8_embSum_diff_t2m_150M
@@ -14,7 +15,7 @@ defaults:
   # - nn_params: remi8_main12_head_16_dim512
   # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

-dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
+dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
 captions_path: dataset/midicaps/train_set.json

 # dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -30,20 +31,20 @@ tau: 0.5

 train_params:
   device: cuda
-  batch_size: 3
+  batch_size: 5
   grad_clip: 1.0
   num_iter: 300000 # total number of iterations
   num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
   num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
   iterations_per_training_cycle: 10 # number of iterations for logging training loss
-  iterations_per_validation_cycle: 5000 # number of iterations for validation process
+  iterations_per_validation_cycle: 3000 # number of iterations for validation process
   input_length: 3072 # input sequence length
   # you can use focal loss; if it's not used, set focal_gamma to 0
   focal_alpha: 1
   focal_gamma: 0
   # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
   scheduler: cosinelr
-  initial_lr: 0.00005
+  initial_lr: 0.0004
   decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
   num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
   warmup_steps: 2000 # number of warmup steps
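Note: taken together, the train_params above describe a cosine schedule with linear warmup that bottoms out at decay_step_rate * num_iter. A small sketch of those semantics (an illustration of the config's meaning, not the repo's train_utils implementation):

    import math

    def cosine_lr(step, initial_lr=4e-4, warmup_steps=2000,
                  num_iter=300000, decay_step_rate=0.8, min_lr=0.0):
        decay_steps = int(num_iter * decay_step_rate)
        if step < warmup_steps:
            return initial_lr * step / warmup_steps  # linear warmup
        if step >= decay_steps:
            return min_lr                            # floor after the decay window
        progress = (step - warmup_steps) / (decay_steps - warmup_steps)
        return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * progress))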
@@ -5,13 +5,13 @@ model_name: AmadeusModel
 input_embedder_name: SummationEmbedder
 main_decoder_name: XtransformerNewPretrainingDecoder
 sub_decoder_name: DiffusionDecoder
-model_dropout: 0
+model_dropout: 0.2
 input_embedder:
   num_layer: 1
   num_head: 8
 main_decoder:
   dim_model: 768
-  num_layer: 20
+  num_layer: 16
   num_head: 12
 sub_decoder:
   decout_window_size: 1 # 1 means no previous decoding output added
@@ -0,0 +1,19 @@
+encoding_scheme: nb
+num_features: 8
+vocab_name: MusicTokenVocabNB
+model_name: AmadeusModel
+input_embedder_name: SummationEmbedder
+main_decoder_name: XtransformerNewFinetunningDecoder
+sub_decoder_name: DiffusionDecoder
+model_dropout: 0
+input_embedder:
+  num_layer: 1
+  num_head: 8
+main_decoder:
+  dim_model: 1024
+  num_layer: 32
+  num_head: 16
+sub_decoder:
+  decout_window_size: 1 # 1 means no previous decoding output added
+  num_layer: 1
+feature_enricher_use: False
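Note: a rough sanity check that dim_model 1024 with 32 layers sits in the 600M class the config name suggests, assuming a standard transformer block costs about 12 * dim^2 parameters (attention ~4 d^2, GLU feedforward ~8 d^2); embeddings, the input embedder, and the sub-decoder are excluded:

    dim, layers = 1024, 32
    print(f"{layers * 12 * dim**2 / 1e6:.0f}M")  # ~403M for the main decoder stack alone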
@@ -0,0 +1,19 @@
+encoding_scheme: nb
+num_features: 8
+vocab_name: MusicTokenVocabNB
+model_name: AmadeusModel
+input_embedder_name: SummationEmbedder
+main_decoder_name: XtransformerNewPretrainingDecoder
+sub_decoder_name: DiffusionDecoder
+model_dropout: 0
+input_embedder:
+  num_layer: 1
+  num_head: 8
+main_decoder:
+  dim_model: 1024
+  num_layer: 32
+  num_head: 16
+sub_decoder:
+  decout_window_size: 1 # 1 means no previous decoding output added
+  num_layer: 1
+feature_enricher_use: False
@@ -729,9 +729,8 @@ class XtransformerNewPretrainingDecoder(nn.Module):
             rotary_pos_emb = True,
             ff_swish = True, # set this to True
             ff_glu = True, # set to true to use for all feedforwards
-            # shift_tokens = 1,
-            # attn_qk_norm = True,
-            # attn_qk_norm_dim_scale = True
+            attn_gate_values = True,
+            attn_qk_norm = True,
         )
         # add final dropout
         print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
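Note: attn_qk_norm, enabled above, corresponds to l2-normalizing queries and keys before the dot product so attention logits become bounded cosine similarities. A minimal sketch of the idea (not x-transformers' exact implementation):

    import torch.nn.functional as F

    def qk_norm_scores(q, k, scale=10.0):
        # normalize along the head dimension, then rescale by a fixed (or learned)
        # temperature; logit magnitude no longer grows with head dimension
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        return scale * q @ k.transpose(-2, -1)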
@@ -758,7 +757,7 @@ class XtransformerNewPretrainingDecoder(nn.Module):

     def _apply_xavier_init(self):
         for name, param in self.transformer_decoder.named_parameters():
-            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
+            if 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name:
                 torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

     def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
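Note: the widened condition relies on Python precedence ('and' binds tighter than 'or'), which here gives the intended result because only the 'to_v' clause can substring-match gate parameters. A quick check:

    name = 'layers.0.1.to_v_gate.weight'
    hit = 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name
    print(hit)  # False -- parsed as A or B or (C and D); the gate is skipped

    name = 'layers.0.1.to_v.weight'
    hit = 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name
    print(hit)  # True -- plain value projections still get xavier init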
@@ -924,3 +923,99 @@ class XtransformerFinetuningDecoder(nn.Module):
         hidden_vec = self.transformer_decoder(seq)
         hidden_vec = hidden_vec[:, context.shape[1]:, :]
         return hidden_vec
+
+class XtransformerNewFinetunningDecoder(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        heads: int,
+        dropout: float
+    ):
+        super().__init__()
+        self._make_decoder_layer(dim, depth, heads, dropout)
+        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
+        # frozen text encoder
+        for param in self.text_encoder.parameters():
+            param.requires_grad = False
+        if dim != 768:
+            self.text_project = nn.Linear(768, dim) # assuming T5 base hidden size is 768
+
+    def _make_decoder_layer(self, dim, depth, heads, dropout):
+        self.transformer_decoder = Decoder(
+            dim = dim,
+            depth = depth,
+            heads = heads,
+            attn_dropout = dropout,
+            ff_dropout = dropout,
+            attn_flash = True,
+            use_rmsnorm = True,
+            rotary_pos_emb = True,
+            ff_swish = True, # set this to True
+            ff_glu = True, # set to true to use for all feedforwards
+            attn_gate_values = True,
+            attn_qk_norm = True,
+        ) # add final dropout
+        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
+        self._apply_xavier_init()
+        print('Adding dropout after feedforward layer in x-transformer')
+        self._add_dropout_after_ff(dropout)
+        print('Adding dropout after attention layer in x-transformer')
+        self._add_dropout_after_attn(dropout)
+
+    def _add_dropout_after_attn(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'Attention' in str(type(layer[1])):
+                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
+                    layer[1].to_out.append(nn.Dropout(dropout))
+                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
+                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
+                else:
+                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')
+
+    def _add_dropout_after_ff(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'FeedForward' in str(type(layer[1])):
+                layer[1].ff.append(nn.Dropout(dropout))
+
+    def _apply_xavier_init(self):
+        for name, param in self.transformer_decoder.named_parameters():
+            if 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name:
+                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
+
+    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
+        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
+        if context_embedding is None:
+            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
+            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
+            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
+            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
+            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
+
+            context = self.text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+            ).last_hidden_state
+        else:
+            context = context_embedding
+        if hasattr(self, 'text_project'):
+            context = self.text_project(context)
+
+        # concatenate context with seq
+        seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
+        if cache is not None: # implementing run_one_step in inference
+            if cache.hiddens is None: cache = None
+            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
+            # cut to only return the seq part
+            return hidden_vec, intermediates
+        else:
+            if train:
+                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
+                # cut to only return the seq part
+                hidden_vec = hidden_vec[:, context.shape[1]:, :]
+                return hidden_vec, intermediates
+            else:
+                # cut to only return the seq part
+                hidden_vec = self.transformer_decoder(seq)
+                hidden_vec = hidden_vec[:, context.shape[1]:, :]
+                return hidden_vec
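Note: a sketch of driving the new finetuning decoder above end to end, assuming the flan-t5-base tokenizer and an already-embedded music-token sequence; shapes and the caption are illustrative:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
    context = tokenizer("a calm piano piece in C major", return_tensors='pt')

    decoder = XtransformerNewFinetunningDecoder(dim=1024, depth=32, heads=16, dropout=0.0)
    seq = torch.randn(1, 128, 1024)  # (B, T, dim) embedded music tokens

    # the frozen T5 encodes the caption, the result is projected to dim and
    # prepended; returned hidden states cover only the music-token positions
    hidden, intermediates = decoder(seq, train=True, context=context)
    print(hidden.shape)  # torch.Size([1, 128, 1024])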
@@ -113,7 +113,7 @@ class CorpusMaker():
         0 to 2000 means no limitation
         '''
         # last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
-        last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'Symphony': (60, 1500)}
+        last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (8, 3000), 'Symphony': (60, 1500)}
         try:
             self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
         except KeyError:
@@ -105,6 +105,7 @@ def load_resources(wandb_exp_dir, device):
     config = wandb_style_config_to_omega_config(config)

     # Load checkpoint to specified device
+    print("Loading checkpoint from:", ckpt_path)
     ckpt = torch.load(ckpt_path, map_location=device)
     model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
     model.load_state_dict(ckpt['model'], strict=False)
len_tunes/FinetuneDataset/len_nb8.png | BIN (new file, 24 KiB)
len_tunes/IrishMan/len_nb8.png        | BIN (new file, 22 KiB)
len_tunes/gigamidi/len_nb8.png        | BIN (new file, 24 KiB)
midi_stastic.py (new file) | 442 +
@@ -0,0 +1,442 @@
+#!/usr/bin/env python3
+"""
+MIDI Statistics Extractor
+
+Usage: python midi_statistics.py <path_to_directory> [options]
+
+This script traverses a directory and all subdirectories to find MIDI files,
+extracts musical features from each file using multi-threading for speed,
+and saves the results to CSV files.
+"""
+
+import argparse
+import pathlib
+import os
+import csv
+import json
+from multiprocessing import Pool
+from itertools import chain
+from math import ceil
+from functools import partial
+
+import numpy as np
+from numpy.lib.stride_tricks import sliding_window_view
+from symusic import Score
+import pandas as pd
+from tqdm import tqdm
+from numba import njit, prange
+
+
+@njit
+def merge_intervals(intervals: list[tuple[int, int]], threshold: int):
+    """Merge overlapping or close intervals."""
+    out = []
+    last_s, last_e = intervals[0]
+
+    for i in range(1, len(intervals)):
+        s, e = intervals[i]
+
+        if s - last_e <= threshold:
+            if e > last_e:
+                last_e = e
+        else:
+            out.append((last_s, last_e))
+            last_s, last_e = s, e
+
+    out.append((last_s, last_e))
+    return out
+
+
+@njit(fastmath=True)
+def note_distribution(events: list[tuple[float, int]], threshold: int = 2, segment_threshold: int = 0):
+    """Calculate polyphony rate and sounding segments."""
+    try:
+        events.sort()
+        active_notes = 0
+        polyphonic_steps = 0
+        total_steps = 0
+        last_time = None
+        last_state = False
+        last_seg_start = 0
+        sounding_segments = []
+
+        for time, change in events:
+            if last_time is not None and time != last_time:
+                if active_notes >= threshold:
+                    polyphonic_steps += (time - last_time)
+                if active_notes:
+                    total_steps += (time - last_time)
+            if (last_state != bool(active_notes)):
+                if (last_state):
+                    last_seg_start = time
+                else:
+                    sounding_segments.append((last_seg_start, time))
+
+            active_notes += change
+            last_state = bool(active_notes)
+            last_time = time
+
+        if (segment_threshold != 0):
+            sounding_segments = merge_intervals(sounding_segments, segment_threshold)
+
+        return polyphonic_steps / total_steps, total_steps, sounding_segments
+    except:
+        return None, None, None
+
+
+@njit(fastmath=True)
+def entropy(X: np.ndarray, base: float = 2.0) -> float:
+    """Calculate entropy, optimized with numba."""
+    N, M = X.shape
+    out = np.empty(N, dtype=np.float64)
+    log_base = np.log(base) if base > 0.0 else 1.0
+
+    for i in prange(N):
+        row = X[i]
+        total = np.nansum(row)
+        if total <= 0.0:
+            out[i] = 0.0
+            continue
+
+        mask = (~np.isnan(row)) & (row > 0.0)
+        probs = row[mask] / total
+        if probs.size == 0:
+            out[i] = 0.0
+        else:
+            H = -np.sum(probs * np.log(probs))
+            if base > 0.0:
+                H /= log_base
+            out[i] = H
+
+    nz = out > 0.0
+    if not np.any(nz):
+        return 0.0
+    return float(np.exp(np.mean(np.log(out[nz]))))
+
+
+@njit(fastmath=True)
+def n_gram_co_occurence_entropy(seq: list[list[int]], N: int = 5):
+    """Calculate n-gram co-occurrence entropy."""
+    counts = []
+
+    for seg in seq:
+        if len(seg) < 2:
+            continue
+
+        arr = np.asarray(seg, dtype=np.int64)
+
+        min_val = np.min(arr)
+        if min_val < 0:
+            arr = arr - min_val
+
+        vocabs = int(np.max(arr) + 1)
+
+        wlen = N if len(arr) >= N else len(arr)
+        nwin = len(arr) - wlen + 1
+
+        C = np.zeros((vocabs, vocabs), dtype=np.int64)
+
+        for start in range(nwin):
+            for i in range(wlen - 1):
+                a = int(arr[start + i])
+                for j in range(i + 1, wlen):
+                    b = int(arr[start + j])
+                    if a < vocabs and b < vocabs:
+                        C[a, b] += 1
+
+        for i in range(vocabs):
+            counts.append(int(C[i, i]))
+            for j in range(i + 1, vocabs):
+                counts.append(int(C[i, j]))
+
+    total = 0
+    for v in counts:
+        total += v
+
+    if total <= 0:
+        return 0.0
+
+    H = 0.0
+    for v in counts:
+        if v > 0:
+            p = v / total
+            H -= p * np.log(p)
+
+    return H
+
+
+def calc_pitch_distribution(pitches: np.ndarray, window_size: int = 32, hop_size: int = 16):
+    """Calculate pitch distribution features."""
+    sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(pitches) > window_size else (lambda x: x.reshape(1, -1))
+
+    used_pitches = np.unique(pitches)
+    n_pitches_used = len(used_pitches)
+    pitch_entropy = entropy(sw(pitches))
+    pitch_range = [int(min(used_pitches)), int(max(used_pitches))]
+
+    pitch_classes = pitches % 12
+    n_pitch_classes_used = len(np.unique(pitch_classes))
+    pitch_class_entropy = entropy(sw(pitch_classes))
+
+    return n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range
+
+
+def calc_rhythmic_entropy(ioi: np.ndarray, window_size: int = 32, hop_size: int = 16):
+    """Calculate rhythmic entropy."""
+    sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(ioi) > window_size else (lambda x: x.reshape(1, -1))
+    if (len(ioi) == 0):
+        return None
+    return entropy(sw(ioi))
+
+
+def extract_features(midi_path: pathlib.Path, tpq: int = 6):
+    """Extract features from a single MIDI file."""
+    try:
+        seg_threshold = tpq * 8
+        midi_id = midi_path.parent.name + '/' + midi_path.stem
+        score = Score(midi_path).resample(tpq)
+
+        track_features = []
+        for i, t in enumerate(score.tracks):
+            if (not len(t.notes)):
+                track_features.append((
+                    midi_id,  # midi_id
+                    i,  # track_id
+                    128 if t.is_drum else t.program,  # instrument
+
+                    0,  # end_time
+                    0,  # note_num
+                    None,  # sounding_interval
+
+                    None,  # note_density
+                    None,  # polyphony_rate
+                    None,  # rhythmic_entropy
+                    None,  # rhythmic_token_co_occurrence_entropy
+
+                    None,  # n_pitch_classes_used
+                    None,  # n_pitches_used
+                    None,  # pitch_class_entropy
+                    None,  # pitch_entropy
+                    None,  # pitch_range
+                    None   # interval_token_co_occurrence_entropy
+                ))
+                continue
+            t.sort()
+
+            features = t.notes.numpy()
+
+            ioi = np.diff(features['time'])
+            seg_points = np.where(ioi > tpq * seg_threshold)[0]
+
+            polyphony_rate, sounding_interval_length, sounding_segment = note_distribution(list(chain(*
+                [((note.start, 1), (note.end, -1)) for note in t.notes])))
+            rhythmic_entropy = calc_rhythmic_entropy(ioi)
+
+            rhythmic_token_co_occurrence_entropy = n_gram_co_occurence_entropy([i for i in np.split(ioi, seg_points) if np.all(i <= seg_threshold)])
+
+            if (t.is_drum or len(t.notes) < 2):
+                track_features.append((
+                    midi_id,  # midi_id
+                    i,  # track_id
+                    128 if t.is_drum else t.program,  # instrument
+
+                    t.end(),  # end_time
+                    len(t.notes),  # note_num
+                    sounding_interval_length,  # sounding_interval
+
+                    len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None,  # note_density
+                    polyphony_rate,  # polyphony_rate
+                    rhythmic_entropy,  # rhythmic_entropy
+                    rhythmic_token_co_occurrence_entropy,  # rhythmic_token_co_occurrence_entropy
+
+                    None,  # n_pitch_classes_used
+                    None,  # n_pitches_used
+                    None,  # pitch_class_entropy
+                    None,  # pitch_entropy
+                    None,  # pitch_range
+                    None   # interval_token_co_occurrence_entropy
+                ))
+            else:
+                n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range = calc_pitch_distribution(features['pitch'])
+                intervals = np.diff(features['pitch'])
+                track_features.append((
+                    midi_id,  # midi_id
+                    i,  # track_id
+                    t.program,  # instrument
+
+                    t.end(),  # end_time
+                    len(t.notes),  # note_num
+                    sounding_interval_length,  # sounding_interval
+
+                    len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None,  # note_density
+                    polyphony_rate,  # polyphony_rate
+                    rhythmic_entropy,  # rhythmic_entropy
+                    rhythmic_token_co_occurrence_entropy,  # rhythmic_token_co_occurrence_entropy
+
+                    n_pitch_classes_used,  # n_pitch_classes_used
+                    n_pitches_used,  # n_pitches_used
+                    pitch_class_entropy,  # pitch_class_entropy
+                    pitch_entropy,  # pitch_entropy
+                    json.dumps(pitch_range),  # pitch_range
+                    n_gram_co_occurence_entropy([p for i, p in zip(np.split(ioi, seg_points), np.split(intervals, seg_points)) if np.all(i <= seg_threshold)])  # interval_token_co_occurrence_entropy
+                ))
+
+        score_features = (
+            midi_id,  # midi_id
+            sum(tf[4] for tf in track_features) if track_features else 0,  # note_num
+            max(tf[3] for tf in track_features) if track_features else 0,  # end_time
+            json.dumps([[ks.time, ks.key, ks.tonality] for ks in score.key_signatures]),  # key
+            json.dumps([[ts.time, ts.numerator, ts.denominator] for ts in score.time_signatures]),  # time_signature
+            json.dumps([[t.time, t.qpm] for t in score.tempos])  # tempo
+        )
+
+        return score_features, track_features
+    except Exception as e:
+        print(f"Error processing {midi_path}: {e}")
+        return None, None
+
+
+def find_midi_files(directory: pathlib.Path):
+    """Find all MIDI files in directory and subdirectories."""
+    midi_extensions = {'.mid', '.midi', '.MID', '.MIDI'}
+    midi_files = []
+
+    # Use rglob to recursively find MIDI files
+    for file_path in directory.rglob('*'):
+        if file_path.is_file() and file_path.suffix in midi_extensions:
+            midi_files.append(file_path)
+
+    return midi_files
+
+
+def process_midi_files(directory: pathlib.Path, output_prefix: str = "midi_features",
+                       num_threads: int = 4, tpq: int = 6):
+    """Process MIDI files with multi-threading and save to CSV."""
+
+    # Find all MIDI files
+    print(f"Searching for MIDI files in: {directory}")
+    midi_files = find_midi_files(directory)
+
+    if not midi_files:
+        print(f"No MIDI files found in {directory}")
+        return
+
+    print(f"Found {len(midi_files)} MIDI files")
+
+    # Create extractor function with fixed parameters
+    extractor = partial(extract_features, tpq=tpq)
+
+    # Feature column names
+    score_feat_cols = ['midi_id', 'note_num', 'end_time', 'key', 'time_signature', 'tempo']
+    track_feat_cols = ['midi_id', 'track_id', 'instrument', 'end_time', 'note_num',
+                       'sounding_interval', 'note_density', 'polyphony_rate', 'rhythmic_entropy',
+                       'rhythmic_token_co_occurrence_entropy', 'n_pitch_classes_used',
+                       'n_pitches_used', 'pitch_class_entropy', 'pitch_entropy', 'pitch_range',
+                       'interval_token_co_occurrence_entropy']
+
+    # Process files with multiprocessing
+    print(f"Processing files with {num_threads} threads...")
+
+    with Pool(num_threads) as pool:
+        # Open CSV files for writing
+        with open(f'{output_prefix}_score_features.csv', 'w', newline='', encoding='utf-8') as score_csvfile:
+            score_writer = csv.writer(score_csvfile)
+            score_writer.writerow(score_feat_cols)
+
+            with open(f'{output_prefix}_track_features.csv', 'w', newline='', encoding='utf-8') as track_csvfile:
+                track_writer = csv.writer(track_csvfile)
+                track_writer.writerow(track_feat_cols)
+
+                # Process files with progress bar
+                processed_count = 0
+                skipped_count = 0
+
+                for score_feat, track_feats in tqdm(pool.imap_unordered(extractor, midi_files),
+                                                    total=len(midi_files),
+                                                    desc="Processing MIDI files"):
+                    if score_feat is None:
+                        skipped_count += 1
+                        continue
+
+                    processed_count += 1
+
+                    # Write score features
+                    score_writer.writerow(score_feat)
+
+                    # Write track features
+                    if track_feats:
+                        track_writer.writerows(track_feats)
+
+    print(f"\nProcessing complete!")
+    print(f"Successfully processed: {processed_count} files")
+    print(f"Skipped due to errors: {skipped_count} files")
+    print(f"Score features saved to: {output_prefix}_score_features.csv")
+    print(f"Track features saved to: {output_prefix}_track_features.csv")
+
+
+def main():
+    """Main function with command line argument parsing."""
+    parser = argparse.ArgumentParser(
+        description="Extract musical features from MIDI files and save to CSV",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  python midi_statistics.py /path/to/midi/files
+  python midi_statistics.py /path/to/midi/files --threads 8 --output my_features
+  python midi_statistics.py /path/to/midi/files --tpq 12 --threads 2
+
+Features extracted:
+- Score level: note count, end time, key signatures, time signatures, tempo
+- Track level: instrument, note density, polyphony rate, rhythmic entropy,
+  pitch distribution, and more
+        """
+    )
+
+    parser.add_argument('directory',
+                        help='Path to directory containing MIDI files')
+
+    parser.add_argument('--threads', '-t',
+                        type=int,
+                        default=4,
+                        help='Number of threads to use (default: 4)')
+
+    parser.add_argument('--output', '-o',
+                        type=str,
+                        default='midi_features',
+                        help='Output file prefix (default: midi_features)')
+
+    parser.add_argument('--tpq',
+                        type=int,
+                        default=6,
+                        help='Ticks per quarter note for resampling (default: 6)')
+
+    args = parser.parse_args()
+
+    # Validate directory
+    directory = pathlib.Path(args.directory)
+    if not directory.exists():
+        print(f"Error: Directory '{directory}' does not exist")
+        return 1
+
+    if not directory.is_dir():
+        print(f"Error: '{directory}' is not a directory")
+        return 1
+
+    # Validate threads
+    if args.threads < 1:
+        print("Error: Number of threads must be at least 1")
+        return 1
+
+    try:
+        process_midi_files(directory, args.output, args.threads, args.tpq)
+        return 0
+    except KeyboardInterrupt:
+        print("\nProcessing interrupted by user")
+        return 1
+    except Exception as e:
+        print(f"Error: {e}")
+        return 1
+
+
+if __name__ == "__main__":
+    exit(main())
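Note: the script writes two CSVs that share the midi_id key; a small sketch of joining them for analysis (default output prefix assumed):

    import pandas as pd

    scores = pd.read_csv("midi_features_score_features.csv")
    tracks = pd.read_csv("midi_features_track_features.csv")

    # average track-level metrics per file, joined onto score-level rows
    per_file = tracks.groupby("midi_id")[["polyphony_rate", "pitch_entropy"]].mean()
    merged = scores.merge(per_file, on="midi_id", how="left")
    print(merged.head())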
,idi_sim.py (new file) | 105 +
@@ -0,0 +1,105 @@
+import os
+import numpy as np
+import pandas as pd
+from symusic import Score
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
+
+def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
+    if (not a.shape[1] or not b.shape[1]):
+        return np.inf
+    a_onset, a_pitch = a
+    b_onset, b_pitch = b
+    a_onset = a_onset.astype(np.float32)
+    b_onset = b_onset.astype(np.float32)
+    a_pitch = a_pitch.astype(np.uint8)
+    b_pitch = b_pitch.astype(np.uint8)
+
+    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
+    if (oti):
+        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
+        dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
+        a2b = dist_matrix.min(2)
+        b2a = dist_matrix.min(1)
+        dist = np.concatenate([a2b, b2a], axis=1)
+        return dist.sum(axis=1).min() / len(dist)
+    else:
+        pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
+        dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
+        a2b = dist_matrix.min(1)
+        b2a = dist_matrix.min(0)
+        return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))
+
+
+def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
+    x = sorted(x)
+    end_time = x[-1][0]
+    out = [[] for _ in range(int(end_time // hop_size))]
+    for i in sorted(x):
+        segment = min(int(i[0] // hop_size), len(out) - 1)
+        while (i[0] >= segment * hop_size):
+            out[segment].append(i)
+            segment -= 1
+            if (segment < 0):
+                break
+    return out
+
+
+def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
+    a = midi_time_sliding_window(a)
+    b = midi_time_sliding_window(b)
+    dist = np.inf
+    for i in a:
+        for j in b:
+            cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
+            if (cur_dist < dist):
+                dist = cur_dist
+    return dist
+
+
+def extract_notes(filepath: str):
+    """Read a MIDI file and return a list of (time, pitch) tuples."""
+    try:
+        s = Score(filepath).to("quarter")
+        notes = []
+        for t in s.tracks:
+            notes.extend([(n.time, n.pitch) for n in t.notes])
+        return notes
+    except Exception as e:
+        print(f"Error reading {filepath}: {e}")
+        return []
+
+
+def compare_pair(file_a: str, file_b: str):
+    notes_a = extract_notes(file_a)
+    notes_b = extract_notes(file_b)
+    if not notes_a or not notes_b:
+        return (file_a, file_b, np.inf)
+    dist = midi_dist(notes_a, notes_b)
+    return (file_a, file_b, dist)
+
+
+def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
+    files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
+    files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
+
+    results = []
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
+        for fut in as_completed(futures):
+            results.append(fut.result())
+
+    # sort by distance
+    results = sorted(results, key=lambda x: x[2])
+
+    # save results
+    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
+    df.to_csv(out_csv, index=False)
+    print(f"Saved results to {out_csv}")
+
+
+if __name__ == "__main__":
+    dir_a = "folder_a"
+    dir_b = "folder_b"
+    batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)
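Note: in the oti branch, hausdorff_dist folds pitch differences mod 12 through semitone2degree and takes the best of all 12 transpositions. A worked check of that invariance, using the same degree table:

    import numpy as np

    semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
    a_pitch = np.array([60, 64, 67])  # C major triad
    b_pitch = np.array([62, 66, 69])  # the same shape transposed up 2 semitones

    costs = []
    for shift in range(12):
        d = semitone2degree[np.abs((a_pitch + shift)[None, :] - b_pitch[:, None]) % 12]
        costs.append(d.min(axis=1).sum() + d.min(axis=0).sum())
    print(min(costs))  # 0.0 at shift == 2: the OTI distance ignores transposition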