1013 update

FelixChan
2025-10-13 17:56:36 +08:00
parent d077e3210e
commit d6b68ef90b
17 changed files with 815 additions and 70 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@ wandb/
.vscode/
checkpoints/
metadata/
*.sf2

View File

@ -1358,6 +1358,7 @@ class Attention(Module):
dim_latent_kv = None,
latent_rope_subheads = None,
onnxable = False,
use_gated_attention = False, # https://arxiv.org/abs/2505.06708
attend_sdp_kwargs: dict = dict(
enable_flash = True,
enable_math = True,
@ -1387,6 +1388,7 @@ class Attention(Module):
k_dim = dim_head * kv_heads
v_dim = value_dim_head * kv_heads
out_dim = value_dim_head * heads
gated_dim = out_dim
# determine input dimensions to qkv based on whether intermediate latent q and kv are being used
# for eventually supporting multi-latent attention (MLA)
@ -1447,7 +1449,8 @@ class Attention(Module):
self.to_v_gate = None
if gate_values:
self.to_v_gate = nn.Linear(dim, out_dim)
# self.to_v_gate = nn.Linear(dim, out_dim)
self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 10)
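Note: a minimal sketch of the output gating wired up above (following the linked gated-attention paper; the class and names here are illustrative, not the x-transformers internals). The weight-0 / bias-10 init mirrors the diff, so the gate starts out near-identity:

import torch
import torch.nn as nn

class GatedAttentionOutput(nn.Module):
    # illustrative only: gate the attention output with a sigmoid of the block input
    def __init__(self, dim, out_dim):
        super().__init__()
        self.to_gate = nn.Linear(dim, out_dim)
        nn.init.constant_(self.to_gate.weight, 0)
        nn.init.constant_(self.to_gate.bias, 10)  # sigmoid(10) ~ 1, near-identity at init

    def forward(self, x, attn_out):
        # x: (B, T, dim) block input; attn_out: (B, T, out_dim) attention output
        return attn_out * torch.sigmoid(self.to_gate(x))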

View File

@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
def top_p_sampling(logits, thres=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
# temporarily apply the sampling method to logits
logits = logits / temperature
# logits = add_gumbel_noise(logits, temperature)
if sampling_method == "top_p":
modified_logits = top_p_sampling(logits, thres=threshold)
elif sampling_method == "typical":
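Note: the hunks above show only fragments of the sampling utilities; a self-contained sketch of the top-p (nucleus) filter consistent with them, assuming logits over the last dimension, is:

import torch
import torch.nn.functional as F

def top_p_filter(logits, thres=0.9):
    # hypothetical standalone version of the top_p_sampling above
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > thres
    # shift right so the first token crossing the threshold is kept
    remove = F.pad(remove, (1, -1), value=False)
    sorted_logits[remove] = float('-inf')
    # scatter the filtered logits back to the original vocabulary order
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)

A sample is then drawn from the filtered distribution, e.g. torch.multinomial(F.softmax(top_p_filter(logits), dim=-1), 1).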

View File

@ -530,6 +530,62 @@ class Melody(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class msmidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset by song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
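Note: the version grouping above is plain suffix stripping; a quick illustration of why duplicate arrangements never straddle splits:

import re

names = ["SongA", "SongA.1", "SongA.2", "SongB"]
print([re.sub(r"\.\d+$", "", n) for n in names])
# ['SongA', 'SongA', 'SongA', 'SongB'] -> every version of SongA lands in the same split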
class IrishMan(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@ -648,62 +704,62 @@ class ariamidi(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
# class gigamidi(SymbolicMusicDataset):
# def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
# for_evaluation: bool = False):
# super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
# for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
# def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
# '''
# Irregular tunes are removed from the dataset for better generation quality
# It includes tunes that are not quantized properly, mostly theay are expressive performance data
# '''
# print("preprocessed tune_in_idx data is being loaded")
# tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
# if self.debug:
# tune_in_idx_list = tune_in_idx_list[:5000]
# tune_in_idx_dict = OrderedDict()
# len_tunes = OrderedDict()
# file_name_list = []
# with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
# irregular_tunes = json.load(f)
# for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
# if tune_in_idx_file.stem in irregular_tunes:
# continue
# tune_in_idx = np.load(tune_in_idx_file)['arr_0']
# tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
# len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
# file_name_list.append(tune_in_idx_file.stem)
# print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
# return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset by song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
# def _get_split_list_from_tune_in_idx(self, ratio, seed):
# '''
# As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
# '''
# shuffled_tune_names = list(self.tune_in_idx.keys())
# song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
# song_dict = {}
# for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
# if song not in song_dict:
# song_dict[song] = []
# song_dict[song].append(orig_song)
# unique_song_names = list(song_dict.keys())
# random.seed(seed)
# random.shuffle(unique_song_names)
# num_train = int(len(unique_song_names)*ratio)
# num_valid = int(len(unique_song_names)*(1-ratio)/2)
# train_names = []
# valid_names = []
# test_names = []
# for song_name in unique_song_names[:num_train]:
# train_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train:num_train+num_valid]:
# valid_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train+num_valid:]:
# test_names.extend(song_dict[song_name])
# return train_names, valid_names, test_names, shuffled_tune_names
class ariamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@ -788,6 +844,9 @@ class gigamidi(SymbolicMusicDataset):
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
if "drums-only" in tune_in_idx_file.stem:
print(f"skipping {tune_in_idx_file.stem} as it is a drums-only file")
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)

View File

@ -11,7 +11,7 @@ License: MIT, see the LICENSE file
__all__ = ['FluidSynth']
DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SOUND_FONT = 'Alex_GM.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"

View File

@ -2,7 +2,8 @@ defaults:
# - nn_params: nb8_embSum_NMT
# - nn_params: remi8
# - nn_params: nb8_embSum_diff_t2m_150M_finetunning
- nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
# - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
- nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
# - nn_params: nb8_embSum_subPararell
# - nn_params: nb8_embSum_diff_t2m_150M
@ -14,7 +15,7 @@ defaults:
# - nn_params: remi8_main12_head_16_dim512
# - nn_params: nb5_embSum_diff_main12head16dim768_sub3
dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json
# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@ -30,20 +31,20 @@ tau: 0.5
train_params:
device: cuda
batch_size: 3
batch_size: 5
grad_clip: 1.0
num_iter: 300000 # total number of iterations
num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
iterations_per_training_cycle: 10 # number of iterations for logging training loss
iterations_per_validation_cycle: 5000 # number of iterations for validation process
iterations_per_validation_cycle: 3000 # number of iterations for validation process
input_length: 3072 # input sequence length
# you can use focal loss; if it's not used, set focal_gamma to 0
focal_alpha: 1
focal_gamma: 0
# learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using', please check train_utils.py for more details
scheduler : cosinelr
initial_lr: 0.00005
initial_lr: 0.0004
decay_step_rate: 0.8 # means it will reach its lowest point at decay_step_rate * total_num_iter
num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
warmup_steps: 2000 #number of warmup steps
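Note: a minimal sketch of the cosinelr schedule as configured above (linear warmup for warmup_steps, cosine decay bottoming out at decay_step_rate * num_iter); the repo's train_utils.py is authoritative, and the min_lr_ratio floor here is an assumption:

import math

def cosine_lr(step, initial_lr=0.0004, warmup_steps=2000,
              num_iter=300000, decay_step_rate=0.8, min_lr_ratio=0.1):
    decay_steps = int(num_iter * decay_step_rate)
    if step < warmup_steps:
        return initial_lr * step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    floor = initial_lr * min_lr_ratio
    return floor + (initial_lr - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))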

View File

@ -5,13 +5,13 @@ model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
model_dropout: 0.2
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 768
num_layer: 20
num_layer: 16
num_head: 12
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewFinetunningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1024
num_layer: 32
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
num_layer: 1
num_head: 8
main_decoder:
dim_model: 1024
num_layer: 32
num_head: 16
sub_decoder:
decout_window_size: 1 # 1 means no previous decoding output added
num_layer: 1
feature_enricher_use: False

View File

@ -729,9 +729,8 @@ class XtransformerNewPretrainingDecoder(nn.Module):
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
# shift_tokens = 1,
# attn_qk_norm = True,
# attn_qk_norm_dim_scale = True
attn_gate_values = True,
attn_qk_norm = True,
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
@ -758,7 +757,7 @@ class XtransformerNewPretrainingDecoder(nn.Module):
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
if ('to_q' in name or 'to_k' in name or 'to_v' in name) and 'to_v_gate' not in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None, context_embedding=None):
@ -906,6 +905,102 @@ class XtransformerFinetuningDecoder(nn.Module):
else:
context = context_embedding
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
class XtransformerNewFinetunningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
if dim != 768:
self.text_project = nn.Linear(768, dim) # assuming T5 base hidden size is 768
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
attn_gate_values = True,
attn_qk_norm = True,
) # add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if ('to_q' in name or 'to_k' in name or 'to_v' in name) and 'to_v_gate' not in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
if hasattr(self, 'text_project'):
context = self.text_project(context)
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference

View File

@ -113,7 +113,7 @@ class CorpusMaker():
0 to 2000 means no limitation
'''
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'Symphony': (60, 1500)}
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (8, 3000), 'Symphony': (60, 1500)}
try:
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
except KeyError:

View File

@ -105,6 +105,7 @@ def load_resources(wandb_exp_dir, device):
config = wandb_style_config_to_omega_config(config)
# Load checkpoint to specified device
print("Loading checkpoint from:", ckpt_path)
ckpt = torch.load(ckpt_path, map_location=device)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
model.load_state_dict(ckpt['model'], strict=False)
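Note: strict=False silently ignores mismatched keys, so it is worth surfacing them; load_state_dict returns the missing and unexpected key lists:

missing, unexpected = model.load_state_dict(ckpt['model'], strict=False)
if missing:
    print("missing keys:", missing)
if unexpected:
    print("unexpected keys:", unexpected)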

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

442
midi_statistics.py Normal file
View File

@ -0,0 +1,442 @@
#!/usr/bin/env python3
"""
MIDI Statistics Extractor
Usage: python midi_statistics.py <path_to_directory> [options]
This script traverses a directory and all subdirectories to find MID files,
extracts musical features from each file using multi-threading for speed,
and saves the results to CSV files.
"""
import argparse
import pathlib
import os
import csv
import json
from multiprocessing import Pool
from itertools import chain
from math import ceil
from functools import partial
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from symusic import Score
import pandas as pd
from tqdm import tqdm
from numba import njit, prange
@njit
def merge_intervals(intervals: list[tuple[int, int]], threshold: int):
"""Merge overlapping or close intervals."""
out = []
last_s, last_e = intervals[0]
for i in range(1, len(intervals)):
s, e = intervals[i]
if s - last_e <= threshold:
if e > last_e:
last_e = e
else:
out.append((last_s, last_e))
last_s, last_e = s, e
out.append((last_s, last_e))
return out
@njit(fastmath=True)
def note_distribution(events: list[tuple[float, int]], threshold: int = 2, segment_threshold: int = 0):
"""Calculate polyphony rate and sounding segments."""
try:
events.sort()
active_notes = 0
polyphonic_steps = 0
total_steps = 0
last_time = None
last_state = False
last_seg_start = 0
sounding_segments = []
for time, change in events:
if last_time is not None and time != last_time:
if active_notes >= threshold:
polyphonic_steps += (time - last_time)
if active_notes:
total_steps += (time - last_time)
    active_notes += change
    # apply the event first, then detect silence <-> sounding transitions
    if last_state != bool(active_notes):
        if active_notes:  # silence -> sounding: a segment starts here
            last_seg_start = time
        else:  # sounding -> silence: close the segment
            sounding_segments.append((last_seg_start, time))
    last_state = bool(active_notes)
    last_time = time
if last_state:  # close a segment still open at the end of the track
    sounding_segments.append((last_seg_start, last_time))
if segment_threshold != 0 and len(sounding_segments) > 0:
    sounding_segments = merge_intervals(sounding_segments, segment_threshold)
return polyphonic_steps / total_steps, total_steps, sounding_segments
except:
return None, None, None
@njit(fastmath=True)
def entropy(X: np.ndarray, base: float = 2.0) -> float:
"""Calculate entropy function optimized with numba."""
N, M = X.shape
out = np.empty(N, dtype=np.float64)
log_base = np.log(base) if base > 0.0 else 1.0
for i in prange(N):
row = X[i]
total = np.nansum(row)
if total <= 0.0:
out[i] = 0.0
continue
mask = (~np.isnan(row)) & (row > 0.0)
probs = row[mask] / total
if probs.size == 0:
out[i] = 0.0
else:
H = -np.sum(probs * np.log(probs))
if base > 0.0:
H /= log_base
out[i] = H
nz = out > 0.0
if not np.any(nz):
return 0.0
return float(np.exp(np.mean(np.log(out[nz]))))
@njit(fastmath=True)
def n_gram_co_occurence_entropy(seq: list[list[int]], N: int = 5):
"""Calculate n-gram co-occurrence entropy."""
counts = []
for seg in seq:
if len(seg) < 2:
continue
arr = np.asarray(seg, dtype=np.int64)
min_val = np.min(arr)
if min_val < 0:
arr = arr - min_val
vocabs = int(np.max(arr) + 1)
wlen = N if len(arr) >= N else len(arr)
nwin = len(arr) - wlen + 1
C = np.zeros((vocabs, vocabs), dtype=np.int64)
for start in range(nwin):
for i in range(wlen - 1):
a = int(arr[start + i])
for j in range(i + 1, wlen):
b = int(arr[start + j])
if a < vocabs and b < vocabs:
C[a, b] += 1
for i in range(vocabs):
counts.append(int(C[i, i]))
for j in range(i + 1, vocabs):
counts.append(int(C[i, j]))
total = 0
for v in counts:
total += v
if total <= 0:
return 0.0
H = 0.0
for v in counts:
if v > 0:
p = v / total
H -= p * np.log(p)
return H
def calc_pitch_distribution(pitches: np.ndarray, window_size: int = 32, hop_size: int = 16):
"""Calculate pitch distribution features."""
sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(pitches) > window_size else (lambda x: x.reshape(1, -1))
used_pitches = np.unique(pitches)
n_pitches_used = len(used_pitches)
pitch_entropy = entropy(sw(pitches))
pitch_range = [int(min(used_pitches)), int(max(used_pitches))]
pitch_classes = pitches % 12
n_pitch_classes_used = len(np.unique(pitch_classes))
pitch_class_entropy = entropy(sw(pitch_classes))
return n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range
def calc_rhythmic_entropy(ioi: np.ndarray, window_size: int = 32, hop_size: int = 16):
"""Calculate rhythmic entropy."""
sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(ioi) > window_size else (lambda x: x.reshape(1, -1))
if(len(ioi) == 0):
return None
return entropy(sw(ioi))
def extract_features(midi_path: pathlib.Path, tpq: int = 6):
"""Extract features from a single MIDI file."""
try:
seg_threshold = tpq * 8
midi_id = midi_path.parent.name + '/' + midi_path.stem
score = Score(midi_path).resample(tpq)
track_features = []
for i, t in enumerate(score.tracks):
if(not len(t.notes)):
track_features.append((
midi_id, # midi_id
i, # track_id
128 if t.is_drum else t.program, # instrument
0, # end_time
0, # note_num
None, # sounding_interval
None, # note_density
None, # polyphony_rate
None, # rhythmic_entropy
None, # rhythmic_token_co_occurrence_entropy
None, # n_pitch_classes_used
None, # n_pitches_used
None, # pitch_class_entropy
None, # pitch_entropy
None, # pitch_range
None # interval_token_co_occurrence_entropy
))
continue
t.sort()
features = t.notes.numpy()
ioi = np.diff(features['time'])
seg_points = np.where(ioi > tpq * seg_threshold)[0]
polyphony_rate, sounding_interval_length, sounding_segment = note_distribution(list(chain(*
[((note.start, 1), (note.end, -1)) for note in t.notes])))
rhythmic_entropy = calc_rhythmic_entropy(ioi)
rhythmic_token_co_occurrence_entropy = n_gram_co_occurence_entropy([i for i in np.split(ioi, seg_points) if np.all(i <= seg_threshold)])  # keep only segments whose IOIs are all within the threshold
if(t.is_drum or len(t.notes) < 2):
track_features.append((
midi_id, # midi_id
i, # track_id
128 if t.is_drum else t.program, # instrument
t.end(), # end_time
len(t.notes), # note_num
sounding_interval_length, # sounding_interval
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
polyphony_rate, # polyphony_rate
rhythmic_entropy, # rhythmic_entropy
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
None, # n_pitch_classes_used
None, # n_pitches_used
None, # pitch_class_entropy
None, # pitch_entropy
None, # pitch_range
None # interval_token_co_occurrence_entropy
))
else:
n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range = calc_pitch_distribution(features['pitch'])
intervals = np.diff(features['pitch'])
track_features.append((
midi_id, # midi_id
i, # track_id
t.program, # instrument
t.end(), # end_time
len(t.notes), # note_num
sounding_interval_length, # sounding_interval
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
polyphony_rate, # polyphony_rate
rhythmic_entropy, # rhythmic_entropy
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
n_pitch_classes_used, # n_pitch_classes_used
n_pitches_used, # n_pitches_used
pitch_class_entropy, # pitch_class_entropy
pitch_entropy, # pitch_entropy
json.dumps(pitch_range), # pitch_range
n_gram_co_occurence_entropy([p for i, p in zip(np.split(ioi, seg_points), np.split(intervals, seg_points)) if np.all(i <= seg_threshold)]) # interval_token_co_occurrence_entropy
))
score_features = (
midi_id, # midi_id
sum(tf[4] for tf in track_features) if track_features else 0, # note_num
max(tf[3] for tf in track_features) if track_features else 0, # end_time
json.dumps([[ks.time, ks.key, ks.tonality] for ks in score.key_signatures]), # key
json.dumps([[ts.time, ts.numerator, ts.denominator] for ts in score.time_signatures]), # time_signature
json.dumps([[t.time, t.qpm] for t in score.tempos]) # tempo
)
return score_features, track_features
except Exception as e:
print(f"Error processing {midi_path}: {e}")
return None, None
def find_midi_files(directory: pathlib.Path):
"""Find all MIDI files in directory and subdirectories."""
midi_extensions = {'.mid', '.midi', '.MID', '.MIDI'}
midi_files = []
# Use rglob to recursively find MIDI files
for file_path in directory.rglob('*'):
if file_path.is_file() and file_path.suffix in midi_extensions:
midi_files.append(file_path)
return midi_files
def process_midi_files(directory: pathlib.Path, output_prefix: str = "midi_features",
num_threads: int = 4, tpq: int = 6):
"""Process MIDI files with multi-threading and save to CSV."""
# Find all MIDI files
print(f"Searching for MIDI files in: {directory}")
midi_files = find_midi_files(directory)
if not midi_files:
print(f"No MIDI files found in {directory}")
return
print(f"Found {len(midi_files)} MIDI files")
# Create extractor function with fixed parameters
extractor = partial(extract_features, tpq=tpq)
# Feature column names
score_feat_cols = ['midi_id', 'note_num', 'end_time', 'key', 'time_signature', 'tempo']
track_feat_cols = ['midi_id', 'track_id', 'instrument', 'end_time', 'note_num',
'sounding_interval', 'note_density', 'polyphony_rate', 'rhythmic_entropy',
'rhythmic_token_co_occurrence_entropy', 'n_pitch_classes_used',
'n_pitches_used', 'pitch_class_entropy', 'pitch_entropy', 'pitch_range',
'interval_token_co_occurrence_entropy']
# Process files with multiprocessing
print(f"Processing files with {num_threads} threads...")
with Pool(num_threads) as pool:
# Open CSV files for writing
with open(f'{output_prefix}_score_features.csv', 'w', newline='', encoding='utf-8') as score_csvfile:
score_writer = csv.writer(score_csvfile)
score_writer.writerow(score_feat_cols)
with open(f'{output_prefix}_track_features.csv', 'w', newline='', encoding='utf-8') as track_csvfile:
track_writer = csv.writer(track_csvfile)
track_writer.writerow(track_feat_cols)
# Process files with progress bar
processed_count = 0
skipped_count = 0
for score_feat, track_feats in tqdm(pool.imap_unordered(extractor, midi_files),
total=len(midi_files),
desc="Processing MIDI files"):
if score_feat is None:  # extract_features returns (None, None) on error; a tuple is always truthy
    skipped_count += 1
    continue
processed_count += 1
# Write score features
score_writer.writerow(score_feat)
# Write track features
if track_feats:
track_writer.writerows(track_feats)
print(f"\nProcessing complete!")
print(f"Successfully processed: {processed_count} files")
print(f"Skipped due to errors: {skipped_count} files")
print(f"Score features saved to: {output_prefix}_score_features.csv")
print(f"Track features saved to: {output_prefix}_track_features.csv")
def main():
"""Main function with command line argument parsing."""
parser = argparse.ArgumentParser(
description="Extract musical features from MIDI files and save to CSV",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python midi_statistics.py /path/to/midi/files
python midi_statistics.py /path/to/midi/files --threads 8 --output my_features
python midi_statistics.py /path/to/midi/files --tpq 12 --threads 2
Features extracted:
- Score level: note count, end time, key signatures, time signatures, tempo
- Track level: instrument, note density, polyphony rate, rhythmic entropy,
pitch distribution, and more
"""
)
parser.add_argument('directory',
help='Path to directory containing MIDI files')
parser.add_argument('--threads', '-t',
type=int,
default=4,
help='Number of threads to use (default: 4)')
parser.add_argument('--output', '-o',
type=str,
default='midi_features',
help='Output file prefix (default: midi_features)')
parser.add_argument('--tpq',
type=int,
default=6,
help='Ticks per quarter note for resampling (default: 6)')
args = parser.parse_args()
# Validate directory
directory = pathlib.Path(args.directory)
if not directory.exists():
print(f"Error: Directory '{directory}' does not exist")
return 1
if not directory.is_dir():
print(f"Error: '{directory}' is not a directory")
return 1
# Validate threads
if args.threads < 1:
print("Error: Number of threads must be at least 1")
return 1
try:
process_midi_files(directory, args.output, args.threads, args.tpq)
return 0
except KeyboardInterrupt:
print("\nProcessing interrupted by user")
return 1
except Exception as e:
print(f"Error: {e}")
return 1
if __name__ == "__main__":
exit(main())

105
midi_sim.py Normal file
View File

@ -0,0 +1,105 @@
import os
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])
def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
if(not a.shape[1] or not b.shape[1]):
return np.inf
a_onset, a_pitch = a
b_onset, b_pitch = b
a_onset = a_onset.astype(np.float32)
b_onset = b_onset.astype(np.float32)
a_pitch = a_pitch.astype(np.uint8)
b_pitch = b_pitch.astype(np.uint8)
onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
if(oti):
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, 1, -1) + np.arange(12).reshape(-1, 1, 1) - b_pitch.reshape(-1, 1)) % 12]
dist_matrix = (weight[0] * np.expand_dims(onset_dist_matrix, 0) + weight[1] * pitch_dist_matrix) / sum(weight)
a2b = dist_matrix.min(2)
b2a = dist_matrix.min(1)
dist = np.concatenate([a2b, b2a], axis=1)
return float(dist.sum(axis=1).min() / dist.shape[1])  # best transposition, normalized by the number of matched points
else:
pitch_dist_matrix = semitone2degree[np.abs(a_pitch.reshape(1, -1) - b_pitch.reshape(-1, 1)) % 12]
dist_matrix = (weight[0] * onset_dist_matrix + weight[1] * pitch_dist_matrix) / sum(weight)
a2b = dist_matrix.min(1)
b2a = dist_matrix.min(0)
return float((a2b.sum() + b2a.sum()) / (a.shape[1] + b.shape[1]))
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
x = sorted(x)
end_time = x[-1][0]
out = [[] for _ in range(max(1, int(end_time // hop_size)))]
for i in x:  # x is sorted above
    segment = min(int(i[0] // hop_size), len(out) - 1)
    # a note at time t belongs to every window [seg * hop_size, seg * hop_size + window_size) covering t
    while segment >= 0 and i[0] < segment * hop_size + window_size:
        out[segment].append(i)
        segment -= 1
return out
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
a = midi_time_sliding_window(a, window_size, hop_size)
b = midi_time_sliding_window(b, window_size, hop_size)
dist = np.inf
for i in a:
for j in b:
cur_dist = hausdorff_dist(np.array(i, dtype=np.float32).T, np.array(j, dtype=np.float32).T)
if(cur_dist < dist):
dist = cur_dist
return dist
def extract_notes(filepath: str):
"""读取MIDI并返回 (time, pitch) 列表"""
try:
s = Score(filepath).to("quarter")
notes = []
for t in s.tracks:
notes.extend([(n.time, n.pitch) for n in t.notes])
return notes
except Exception as e:
print(f"读取 {filepath} 出错: {e}")
return []
def compare_pair(file_a: str, file_b: str):
notes_a = extract_notes(file_a)
notes_b = extract_notes(file_b)
if not notes_a or not notes_b:
return (file_a, file_b, np.inf)
dist = midi_dist(notes_a, notes_b)
return (file_a, file_b, dist)
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
results = []
with ProcessPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
for fut in as_completed(futures):
results.append(fut.result())
# sort by distance
results = sorted(results, key=lambda x: x[2])
# save results
df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
df.to_csv(out_csv, index=False)
print(f"已保存结果到 {out_csv}")
if __name__ == "__main__":
dir_a = "folder_a"
dir_b = "folder_b"
batch_compare(dir_a, dir_b, out_csv="midi_similarity.csv", max_workers=8)
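Note: a toy sanity check for the pipeline above (values invented for illustration): transposed copies of the same phrase should score near zero, since hausdorff_dist with oti=True minimizes over all 12 transpositions:

a = [(0.0, 60), (1.0, 62), (2.0, 64)]  # C-D-E
b = [(0.0, 67), (1.0, 69), (2.0, 71)]  # the same phrase a fifth up
print(midi_dist(a, b))  # ~0.0 once the optimal transposition index aligns the pitches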