443 lines
16 KiB
Python
443 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
MIDI Statistics Extractor
|
|
|
|
Usage: python midi_statistics.py <path_to_directory> [options]
|
|
|
|
This script traverses a directory and all subdirectories to find MID files,
|
|
extracts musical features from each file using multi-threading for speed,
|
|
and saves the results to CSV files.
|
|
"""
|
|
|
|
import argparse
|
|
import pathlib
|
|
import os
|
|
import csv
|
|
import json
|
|
from multiprocessing import Pool
|
|
from itertools import chain
|
|
from math import ceil
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
from numpy.lib.stride_tricks import sliding_window_view
|
|
from symusic import Score
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
from numba import njit, prange
|
|
|
|
|
|
@njit
def merge_intervals(intervals: list[tuple[int, int]], threshold: int):
    """Merge intervals that overlap or lie within ``threshold`` of each other.

    Each interval is only compared with the running (already merged) interval
    that precedes it, so ``intervals`` is expected to be ordered by start.

    Args:
        intervals: ``(start, end)`` pairs.
        threshold: Largest gap between the end of the running interval and
            the start of the next one for which the two are still merged.

    Returns:
        A new list of merged ``(start, end)`` pairs. The list is empty when
        ``intervals`` is empty.
    """
    merged = []

    # Guard the empty case: reading intervals[0] below would otherwise raise
    # IndexError.
    if len(intervals) == 0:
        return merged

    cur_start, cur_end = intervals[0]

    for idx in range(1, len(intervals)):
        start, end = intervals[idx]

        if start - cur_end <= threshold:
            # Close enough to the running interval: extend it if needed.
            if end > cur_end:
                cur_end = end
        else:
            # Gap is too large: emit the running interval and start a new one.
            merged.append((cur_start, cur_end))
            cur_start, cur_end = start, end

    merged.append((cur_start, cur_end))
    return merged
|
|
|
|
|
|
@njit(fastmath=True)
def note_distribution(events: list[tuple[float, int]], threshold: int = 2, segment_threshold: int = 0):
    """Compute the polyphony rate and the sounding segments of a track.

    Args:
        events: ``(time, change)`` pairs where ``change`` is added to the
            running count of active notes (the caller in this file passes
            ``+1`` at note starts and ``-1`` at note ends). The list is
            sorted in place.
        threshold: Minimum number of simultaneously active notes for a time
            span to count as polyphonic.
        segment_threshold: When non-zero, sounding segments separated by a gap
            of at most this size are merged with ``merge_intervals``.

    Returns:
        ``(polyphony_rate, total_steps, sounding_segments)`` where
        ``total_steps`` is the total duration during which at least one note
        is active, ``polyphony_rate`` is the share of that duration with at
        least ``threshold`` active notes, and ``sounding_segments`` is the
        list of ``(start, end)`` spans during which something is sounding.
        ``(None, None, None)`` is returned when no time span has an active
        note or when any step raises.
    """
    # NOTE: a bare ``except`` is kept on purpose; numba's try/except support
    # is limited to the bare form and ``except Exception``.
    try:
        events.sort()
        active_notes = 0
        polyphonic_steps = 0
        total_steps = 0
        started = False        # becomes True once the first event was seen
        prev_time = 0          # time of the previously processed event
        was_sounding = False   # state of the span before the current one
        seg_start = 0
        sounding_segments = []

        for time, change in events:
            if started and time != prev_time:
                # ``active_notes`` describes the span [prev_time, time).
                span = time - prev_time
                is_sounding = active_notes != 0

                if active_notes >= threshold:
                    polyphonic_steps += span
                if is_sounding:
                    total_steps += span

                # Record segment boundaries when the sounding state flips.
                # Working on whole spans (rather than single events) keeps
                # a note that ends exactly where the next one starts from
                # splitting the segment.
                if is_sounding != was_sounding:
                    if is_sounding:
                        seg_start = prev_time
                    else:
                        sounding_segments.append((seg_start, prev_time))
                    was_sounding = is_sounding

            active_notes += change
            prev_time = time
            started = True

        # Close a segment that is still open at the last event.
        if was_sounding:
            sounding_segments.append((seg_start, prev_time))

        # Nothing ever sounded: the rate is undefined.
        if total_steps == 0:
            return None, None, None

        if segment_threshold != 0:
            sounding_segments = merge_intervals(sounding_segments, segment_threshold)

        return polyphonic_steps / total_steps, total_steps, sounding_segments
    except:
        return None, None, None
|
|
|
|
|
|
@njit(fastmath=True)
def entropy(X: np.ndarray, base: float = 2.0) -> float:
    """Return the geometric mean of the positive per-row entropies of ``X``.

    Every row of the 2-D array ``X`` is divided by its NaN-ignoring sum and
    the positive, non-NaN entries are treated as probabilities. The Shannon
    entropy of each row is expressed in ``base`` (natural logarithm when
    ``base`` is not positive). Rows whose sum is not positive contribute an
    entropy of ``0.0``; rows with zero entropy are left out of the geometric
    mean, and ``0.0`` is returned when no row has positive entropy.
    """
    n_rows, n_cols = X.shape
    row_entropy = np.empty(n_rows, dtype=np.float64)

    # Dividing by 1.0 leaves the natural-log entropy untouched, so the
    # division can be applied unconditionally.
    divisor = 1.0
    if base > 0.0:
        divisor = np.log(base)

    for r in prange(n_rows):
        values = X[r]
        row_sum = np.nansum(values)
        if row_sum <= 0.0:
            row_entropy[r] = 0.0
            continue

        usable = (values > 0.0) & (~np.isnan(values))
        p = values[usable] / row_sum
        if p.size == 0:
            row_entropy[r] = 0.0
            continue

        row_entropy[r] = -np.sum(p * np.log(p)) / divisor

    positive = row_entropy > 0.0
    if np.any(positive):
        return float(np.exp(np.mean(np.log(row_entropy[positive]))))
    return 0.0
|
|
|
|
|
|
@njit(fastmath=True)
def n_gram_co_occurence_entropy(seq: list[list[int]], N: int = 5):
    """Return the entropy (in nats) of token co-occurrence counts.

    For every segment of ``seq`` with at least two tokens, a window of up to
    ``N`` tokens slides over the segment and each ordered pair
    ``(earlier, later)`` inside a window increments a co-occurrence matrix.
    Segments containing negative tokens are shifted so their minimum is 0.
    Only the diagonal and upper triangle of each matrix are collected; the
    collected counts of all segments are normalised together and their
    Shannon entropy is returned. ``0.0`` is returned when nothing was counted.
    """
    pair_counts = []
    total = 0

    for segment in seq:
        if len(segment) < 2:
            continue

        tokens = np.asarray(segment, dtype=np.int64)

        # Shift so that tokens can be used directly as matrix indices.
        lowest = np.min(tokens)
        if lowest < 0:
            tokens = tokens - lowest

        vocab_size = int(np.max(tokens) + 1)
        n_tokens = len(tokens)
        window = min(N, n_tokens)
        n_windows = n_tokens - window + 1

        cooc = np.zeros((vocab_size, vocab_size), dtype=np.int64)

        for offset in range(n_windows):
            for first in range(window - 1):
                earlier = int(tokens[offset + first])
                for second in range(first + 1, window):
                    later = int(tokens[offset + second])
                    cooc[earlier, later] += 1

        # Diagonal first, then the rest of the row to its right.
        for row in range(vocab_size):
            for col in range(row, vocab_size):
                value = int(cooc[row, col])
                pair_counts.append(value)
                total += value

    if total <= 0:
        return 0.0

    H = 0.0
    for value in pair_counts:
        if value > 0:
            p = value / total
            H -= p * np.log(p)

    return H
|
|
|
|
|
|
def calc_pitch_distribution(pitches: np.ndarray, window_size: int = 32, hop_size: int = 16):
    """Summarise the pitch content of a note sequence.

    Entropies are computed with ``entropy`` over windows of ``window_size``
    values taken every ``hop_size`` steps; sequences that are not longer than
    ``window_size`` are treated as a single window.

    Returns:
        ``(n_pitch_classes_used, n_pitches_used, pitch_class_entropy,
        pitch_entropy, pitch_range)`` where ``pitch_range`` is
        ``[lowest, highest]`` as plain ints.
    """
    use_sliding = len(pitches) > window_size

    def windows(values):
        # Windowed 2-D view for long sequences, one row otherwise.
        if use_sliding:
            return sliding_window_view(values, window_size)[::hop_size, :]
        return values.reshape(1, -1)

    distinct_pitches = np.unique(pitches)
    n_pitches_used = len(distinct_pitches)
    pitch_entropy = entropy(windows(pitches))
    lowest = int(min(distinct_pitches))
    highest = int(max(distinct_pitches))

    # Fold pitches onto the 12 pitch classes.
    pitch_classes = pitches % 12
    n_pitch_classes_used = len(np.unique(pitch_classes))
    pitch_class_entropy = entropy(windows(pitch_classes))

    return (
        n_pitch_classes_used,
        n_pitches_used,
        pitch_class_entropy,
        pitch_entropy,
        [lowest, highest],
    )
|
|
|
|
|
|
def calc_rhythmic_entropy(ioi: np.ndarray, window_size: int = 32, hop_size: int = 16):
    """Return the windowed entropy of inter-onset intervals.

    ``None`` is returned for an empty ``ioi``. Otherwise ``entropy`` is
    applied to windows of ``window_size`` values taken every ``hop_size``
    steps, or to the whole sequence as a single row when it is not longer
    than ``window_size``.
    """
    n_values = len(ioi)
    if n_values == 0:
        return None

    if n_values > window_size:
        frames = sliding_window_view(ioi, window_size)[::hop_size, :]
    else:
        frames = ioi.reshape(1, -1)
    return entropy(frames)
|
|
|
|
|
|
def extract_features(midi_path: pathlib.Path, tpq: int = 6):
|
|
"""Extract features from a single MIDI file."""
|
|
try:
|
|
seg_threshold = tpq * 8
|
|
midi_id = midi_path.parent.name + '/' + midi_path.stem
|
|
score = Score(midi_path).resample(tpq)
|
|
|
|
track_features = []
|
|
for i, t in enumerate(score.tracks):
|
|
if(not len(t.notes)):
|
|
track_features.append((
|
|
midi_id, # midi_id
|
|
i, # track_id
|
|
128 if t.is_drum else t.program, # instrument
|
|
|
|
0, # end_time
|
|
0, # note_num
|
|
None, # sounding_interval
|
|
|
|
None, # note_density
|
|
None, # polyphony_rate
|
|
None, # rhythmic_entropy
|
|
None, # rhythmic_token_co_occurrence_entropy
|
|
|
|
None, # n_pitch_classes_used
|
|
None, # n_pitches_used
|
|
None, # pitch_class_entropy
|
|
None, # pitch_entropy
|
|
None, # pitch_range
|
|
None # interval_token_co_occurrence_entropy
|
|
))
|
|
continue
|
|
t.sort()
|
|
|
|
features = t.notes.numpy()
|
|
|
|
ioi = np.diff(features['time'])
|
|
seg_points = np.where(ioi > tpq * seg_threshold)[0]
|
|
|
|
polyphony_rate, sounding_interval_length, sounding_segment = note_distribution(list(chain(*
|
|
[((note.start, 1), (note.end, -1)) for note in t.notes])))
|
|
rhythmic_entropy = calc_rhythmic_entropy(ioi)
|
|
|
|
rhythmic_token_co_occurrence_entropy = n_gram_co_occurence_entropy([i for i in np.split(ioi, seg_points) if np.all(i) <= seg_threshold])
|
|
|
|
if(t.is_drum or len(t.notes) < 2):
|
|
track_features.append((
|
|
midi_id, # midi_id
|
|
i, # track_id
|
|
128 if t.is_drum else t.program, # instrument
|
|
|
|
t.end(), # end_time
|
|
len(t.notes), # note_num
|
|
sounding_interval_length, # sounding_interval
|
|
|
|
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
|
|
polyphony_rate, # polyphony_rate
|
|
rhythmic_entropy, # rhythmic_entropy
|
|
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
|
|
|
|
None, # n_pitch_classes_used
|
|
None, # n_pitches_used
|
|
None, # pitch_class_entropy
|
|
None, # pitch_entropy
|
|
None, # pitch_range
|
|
None # interval_token_co_occurrence_entropy
|
|
))
|
|
else:
|
|
n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range = calc_pitch_distribution(features['pitch'])
|
|
intervals = np.diff(features['pitch'])
|
|
track_features.append((
|
|
midi_id, # midi_id
|
|
i, # track_id
|
|
t.program, # instrument
|
|
|
|
t.end(), # end_time
|
|
len(t.notes), # note_num
|
|
sounding_interval_length, # sounding_interval
|
|
|
|
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
|
|
polyphony_rate, # polyphony_rate
|
|
rhythmic_entropy, # rhythmic_entropy
|
|
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
|
|
|
|
n_pitch_classes_used, # n_pitch_classes_used
|
|
n_pitches_used, # n_pitches_used
|
|
pitch_class_entropy, # pitch_class_entropy
|
|
pitch_entropy, # pitch_entropy
|
|
json.dumps(pitch_range), # pitch_range
|
|
n_gram_co_occurence_entropy([p for i, p in zip(np.split(ioi, seg_points), np.split(intervals, seg_points)) if np.all(i) <= seg_threshold]) # interval_token_co_occurrence_entropy
|
|
))
|
|
|
|
score_features = (
|
|
midi_id, # midi_id
|
|
sum(tf[4] for tf in track_features) if track_features else 0, # note_num
|
|
max(tf[3] for tf in track_features) if track_features else 0, # end_time
|
|
json.dumps([[ks.time, ks.key, ks.tonality] for ks in score.key_signatures]), # key
|
|
json.dumps([[ts.time, ts.numerator, ts.denominator] for ts in score.time_signatures]), # time_signature
|
|
json.dumps([[t.time, t.qpm] for t in score.tempos]) # tempo
|
|
)
|
|
|
|
return score_features, track_features
|
|
except Exception as e:
|
|
print(f"Error processing {midi_path}: {e}")
|
|
return None, None
|
|
|
|
|
|
def find_midi_files(directory: pathlib.Path):
    """Find all MIDI files in ``directory`` and its subdirectories.

    A file counts as MIDI when its suffix is ``.mid`` or ``.midi``; the
    comparison ignores case, so ``.MID``, ``.Midi`` etc. are found as well.

    Args:
        directory: Root directory to search recursively.

    Returns:
        A list of paths to the regular files that matched.
    """
    midi_extensions = {'.mid', '.midi'}

    # rglob walks the whole tree; directories that happen to end in ".mid"
    # are excluded by the is_file() check.
    return [
        file_path
        for file_path in directory.rglob('*')
        if file_path.is_file() and file_path.suffix.lower() in midi_extensions
    ]
|
|
|
|
|
|
def process_midi_files(directory: pathlib.Path, output_prefix: str = "midi_features",
                       num_threads: int = 4, tpq: int = 6):
    """Extract features from every MIDI file under ``directory`` and write CSVs.

    The files are processed by a ``multiprocessing.Pool`` and the results are
    written, in completion order, to ``<output_prefix>_score_features.csv``
    and ``<output_prefix>_track_features.csv``. Files for which
    ``extract_features`` reports a failure are counted as skipped.

    Args:
        directory: Root directory searched recursively for MIDI files.
        output_prefix: Prefix of the two CSV files that are created.
        num_threads: Number of worker processes in the pool.
        tpq: Ticks per quarter note passed on to ``extract_features``.
    """
    # Find all MIDI files
    print(f"Searching for MIDI files in: {directory}")
    midi_files = find_midi_files(directory)

    if not midi_files:
        print(f"No MIDI files found in {directory}")
        return

    print(f"Found {len(midi_files)} MIDI files")

    # Bind the fixed parameter so the pool only has to pass the path.
    extractor = partial(extract_features, tpq=tpq)

    # Feature column names
    score_feat_cols = ['midi_id', 'note_num', 'end_time', 'key', 'time_signature', 'tempo']
    track_feat_cols = ['midi_id', 'track_id', 'instrument', 'end_time', 'note_num',
                       'sounding_interval', 'note_density', 'polyphony_rate', 'rhythmic_entropy',
                       'rhythmic_token_co_occurrence_entropy', 'n_pitch_classes_used',
                       'n_pitches_used', 'pitch_class_entropy', 'pitch_entropy', 'pitch_range',
                       'interval_token_co_occurrence_entropy']

    score_csv_path = f'{output_prefix}_score_features.csv'
    track_csv_path = f'{output_prefix}_track_features.csv'

    print(f"Processing files with {num_threads} threads...")

    processed_count = 0
    skipped_count = 0

    with Pool(num_threads) as pool:
        with open(score_csv_path, 'w', newline='', encoding='utf-8') as score_csvfile, \
                open(track_csv_path, 'w', newline='', encoding='utf-8') as track_csvfile:
            score_writer = csv.writer(score_csvfile)
            score_writer.writerow(score_feat_cols)

            track_writer = csv.writer(track_csvfile)
            track_writer.writerow(track_feat_cols)

            results = tqdm(pool.imap_unordered(extractor, midi_files),
                           total=len(midi_files),
                           desc="Processing MIDI files")
            for score_feat, track_feats in results:
                # extract_features signals a failed file with (None, None).
                # The pair itself is a non-empty tuple and therefore always
                # truthy, so the score element has to be tested explicitly;
                # otherwise writerow(None) would abort the whole run.
                if score_feat is None:
                    skipped_count += 1
                    continue

                processed_count += 1

                # Write score features
                score_writer.writerow(score_feat)

                # Write track features
                if track_feats:
                    track_writer.writerows(track_feats)

    print(f"\nProcessing complete!")
    print(f"Successfully processed: {processed_count} files")
    print(f"Skipped due to errors: {skipped_count} files")
    print(f"Score features saved to: {output_prefix}_score_features.csv")
    print(f"Track features saved to: {output_prefix}_track_features.csv")
|
|
|
|
|
|
def main():
    """Parse the command line, validate it and run the feature extraction.

    Returns:
        ``0`` on success; ``1`` when an argument is invalid, the run is
        interrupted with Ctrl-C, or processing raises.
    """
    parser = argparse.ArgumentParser(
        description="Extract musical features from MIDI files and save to CSV",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python midi_statistics.py /path/to/midi/files
  python midi_statistics.py /path/to/midi/files --threads 8 --output my_features
  python midi_statistics.py /path/to/midi/files --tpq 12 --threads 2

Features extracted:
  - Score level: note count, end time, key signatures, time signatures, tempo
  - Track level: instrument, note density, polyphony rate, rhythmic entropy,
    pitch distribution, and more
        """
    )

    parser.add_argument('directory',
                        help='Path to directory containing MIDI files')

    parser.add_argument('--threads', '-t',
                        type=int,
                        default=4,
                        help='Number of threads to use (default: 4)')

    parser.add_argument('--output', '-o',
                        type=str,
                        default='midi_features',
                        help='Output file prefix (default: midi_features)')

    parser.add_argument('--tpq',
                        type=int,
                        default=6,
                        help='Ticks per quarter note for resampling (default: 6)')

    args = parser.parse_args()

    # Validate directory
    directory = pathlib.Path(args.directory)
    if not directory.exists():
        print(f"Error: Directory '{directory}' does not exist")
        return 1

    if not directory.is_dir():
        print(f"Error: '{directory}' is not a directory")
        return 1

    # Validate threads
    if args.threads < 1:
        print("Error: Number of threads must be at least 1")
        return 1

    # Validate tpq up front: a non-positive resolution cannot be resampled
    # to, and would otherwise only surface as one error per input file.
    if args.tpq < 1:
        print("Error: Ticks per quarter note must be at least 1")
        return 1

    try:
        process_midi_files(directory, args.output, args.threads, args.tpq)
        return 0
    except KeyboardInterrupt:
        print("\nProcessing interrupted by user")
        return 1
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"Error: {e}")
        return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module and is not
    # guaranteed to exist (e.g. ``python -S``); raising SystemExit directly
    # passes main()'s return value on as the process exit status.
    raise SystemExit(main())
|