1013 update

This commit is contained in:
FelixChan
2025-10-13 17:56:36 +08:00
parent d077e3210e
commit d6b68ef90b
17 changed files with 815 additions and 70 deletions

442
midi_stastic.py Normal file
View File

@ -0,0 +1,442 @@
#!/usr/bin/env python3
"""
MIDI Statistics Extractor
Usage: python midi_statistics.py <path_to_directory> [options]
This script traverses a directory and all subdirectories to find MID files,
extracts musical features from each file using multi-threading for speed,
and saves the results to CSV files.
"""
import argparse
import pathlib
import os
import csv
import json
from multiprocessing import Pool
from itertools import chain
from math import ceil
from functools import partial
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from symusic import Score
import pandas as pd
from tqdm import tqdm
from numba import njit, prange
@njit
def merge_intervals(intervals: list[tuple[int, int]], threshold: int):
"""Merge overlapping or close intervals."""
out = []
last_s, last_e = intervals[0]
for i in range(1, len(intervals)):
s, e = intervals[i]
if s - last_e <= threshold:
if e > last_e:
last_e = e
else:
out.append((last_s, last_e))
last_s, last_e = s, e
out.append((last_s, last_e))
return out
@njit(fastmath=True)
def note_distribution(events: list[tuple[float, int]], threshold: int = 2, segment_threshold: int = 0):
"""Calculate polyphony rate and sounding segments."""
try:
events.sort()
active_notes = 0
polyphonic_steps = 0
total_steps = 0
last_time = None
last_state = False
last_seg_start = 0
sounding_segments = []
for time, change in events:
if last_time is not None and time != last_time:
if active_notes >= threshold:
polyphonic_steps += (time - last_time)
if active_notes:
total_steps += (time - last_time)
if(last_state != bool(active_notes)):
if(last_state):
last_seg_start = time
else:
sounding_segments.append((last_seg_start, time))
active_notes += change
last_state = bool(active_notes)
last_time = time
if(segment_threshold != 0):
sounding_segments = merge_intervals(sounding_segments, segment_threshold)
return polyphonic_steps / total_steps, total_steps, sounding_segments
except:
return None, None, None
@njit(fastmath=True)
def entropy(X: np.ndarray, base: float = 2.0) -> float:
"""Calculate entropy function optimized with numba."""
N, M = X.shape
out = np.empty(N, dtype=np.float64)
log_base = np.log(base) if base > 0.0 else 1.0
for i in prange(N):
row = X[i]
total = np.nansum(row)
if total <= 0.0:
out[i] = 0.0
continue
mask = (~np.isnan(row)) & (row > 0.0)
probs = row[mask] / total
if probs.size == 0:
out[i] = 0.0
else:
H = -np.sum(probs * np.log(probs))
if base > 0.0:
H /= log_base
out[i] = H
nz = out > 0.0
if not np.any(nz):
return 0.0
return float(np.exp(np.mean(np.log(out[nz]))))
@njit(fastmath=True)
def n_gram_co_occurence_entropy(seq: list[list[int]], N: int = 5):
"""Calculate n-gram co-occurrence entropy."""
counts = []
for seg in seq:
if len(seg) < 2:
continue
arr = np.asarray(seg, dtype=np.int64)
min_val = np.min(arr)
if min_val < 0:
arr = arr - min_val
vocabs = int(np.max(arr) + 1)
wlen = N if len(arr) >= N else len(arr)
nwin = len(arr) - wlen + 1
C = np.zeros((vocabs, vocabs), dtype=np.int64)
for start in range(nwin):
for i in range(wlen - 1):
a = int(arr[start + i])
for j in range(i + 1, wlen):
b = int(arr[start + j])
if a < vocabs and b < vocabs:
C[a, b] += 1
for i in range(vocabs):
counts.append(int(C[i, i]))
for j in range(i + 1, vocabs):
counts.append(int(C[i, j]))
total = 0
for v in counts:
total += v
if total <= 0:
return 0.0
H = 0.0
for v in counts:
if v > 0:
p = v / total
H -= p * np.log(p)
return H
def calc_pitch_distribution(pitches: np.ndarray, window_size: int = 32, hop_size: int = 16):
"""Calculate pitch distribution features."""
sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(pitches) > window_size else (lambda x: x.reshape(1, -1))
used_pitches = np.unique(pitches)
n_pitches_used = len(used_pitches)
pitch_entropy = entropy(sw(pitches))
pitch_range = [int(min(used_pitches)), int(max(used_pitches))]
pitch_classes = pitches % 12
n_pitch_classes_used = len(np.unique(pitch_classes))
pitch_class_entropy = entropy(sw(pitch_classes))
return n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range
def calc_rhythmic_entropy(ioi: np.ndarray, window_size: int = 32, hop_size: int = 16):
"""Calculate rhythmic entropy."""
sw = (lambda x: sliding_window_view(x, window_size)[::hop_size, :]) if len(ioi) > window_size else (lambda x: x.reshape(1, -1))
if(len(ioi) == 0):
return None
return entropy(sw(ioi))
def extract_features(midi_path: pathlib.Path, tpq: int = 6):
"""Extract features from a single MIDI file."""
try:
seg_threshold = tpq * 8
midi_id = midi_path.parent.name + '/' + midi_path.stem
score = Score(midi_path).resample(tpq)
track_features = []
for i, t in enumerate(score.tracks):
if(not len(t.notes)):
track_features.append((
midi_id, # midi_id
i, # track_id
128 if t.is_drum else t.program, # instrument
0, # end_time
0, # note_num
None, # sounding_interval
None, # note_density
None, # polyphony_rate
None, # rhythmic_entropy
None, # rhythmic_token_co_occurrence_entropy
None, # n_pitch_classes_used
None, # n_pitches_used
None, # pitch_class_entropy
None, # pitch_entropy
None, # pitch_range
None # interval_token_co_occurrence_entropy
))
continue
t.sort()
features = t.notes.numpy()
ioi = np.diff(features['time'])
seg_points = np.where(ioi > tpq * seg_threshold)[0]
polyphony_rate, sounding_interval_length, sounding_segment = note_distribution(list(chain(*
[((note.start, 1), (note.end, -1)) for note in t.notes])))
rhythmic_entropy = calc_rhythmic_entropy(ioi)
rhythmic_token_co_occurrence_entropy = n_gram_co_occurence_entropy([i for i in np.split(ioi, seg_points) if np.all(i) <= seg_threshold])
if(t.is_drum or len(t.notes) < 2):
track_features.append((
midi_id, # midi_id
i, # track_id
128 if t.is_drum else t.program, # instrument
t.end(), # end_time
len(t.notes), # note_num
sounding_interval_length, # sounding_interval
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
polyphony_rate, # polyphony_rate
rhythmic_entropy, # rhythmic_entropy
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
None, # n_pitch_classes_used
None, # n_pitches_used
None, # pitch_class_entropy
None, # pitch_entropy
None, # pitch_range
None # interval_token_co_occurrence_entropy
))
else:
n_pitch_classes_used, n_pitches_used, pitch_class_entropy, pitch_entropy, pitch_range = calc_pitch_distribution(features['pitch'])
intervals = np.diff(features['pitch'])
track_features.append((
midi_id, # midi_id
i, # track_id
t.program, # instrument
t.end(), # end_time
len(t.notes), # note_num
sounding_interval_length, # sounding_interval
len(t.notes) / ceil(sounding_interval_length) if sounding_interval_length else None, # note_density
polyphony_rate, # polyphony_rate
rhythmic_entropy, # rhythmic_entropy
rhythmic_token_co_occurrence_entropy, # rhythmic_token_co_occurrence_entropy
n_pitch_classes_used, # n_pitch_classes_used
n_pitches_used, # n_pitches_used
pitch_class_entropy, # pitch_class_entropy
pitch_entropy, # pitch_entropy
json.dumps(pitch_range), # pitch_range
n_gram_co_occurence_entropy([p for i, p in zip(np.split(ioi, seg_points), np.split(intervals, seg_points)) if np.all(i) <= seg_threshold]) # interval_token_co_occurrence_entropy
))
score_features = (
midi_id, # midi_id
sum(tf[4] for tf in track_features) if track_features else 0, # note_num
max(tf[3] for tf in track_features) if track_features else 0, # end_time
json.dumps([[ks.time, ks.key, ks.tonality] for ks in score.key_signatures]), # key
json.dumps([[ts.time, ts.numerator, ts.denominator] for ts in score.time_signatures]), # time_signature
json.dumps([[t.time, t.qpm] for t in score.tempos]) # tempo
)
return score_features, track_features
except Exception as e:
print(f"Error processing {midi_path}: {e}")
return None, None
def find_midi_files(directory: pathlib.Path):
"""Find all MIDI files in directory and subdirectories."""
midi_extensions = {'.mid', '.midi', '.MID', '.MIDI'}
midi_files = []
# Use rglob to recursively find MIDI files
for file_path in directory.rglob('*'):
if file_path.is_file() and file_path.suffix in midi_extensions:
midi_files.append(file_path)
return midi_files
def process_midi_files(directory: pathlib.Path, output_prefix: str = "midi_features",
num_threads: int = 4, tpq: int = 6):
"""Process MIDI files with multi-threading and save to CSV."""
# Find all MIDI files
print(f"Searching for MIDI files in: {directory}")
midi_files = find_midi_files(directory)
if not midi_files:
print(f"No MIDI files found in {directory}")
return
print(f"Found {len(midi_files)} MIDI files")
# Create extractor function with fixed parameters
extractor = partial(extract_features, tpq=tpq)
# Feature column names
score_feat_cols = ['midi_id', 'note_num', 'end_time', 'key', 'time_signature', 'tempo']
track_feat_cols = ['midi_id', 'track_id', 'instrument', 'end_time', 'note_num',
'sounding_interval', 'note_density', 'polyphony_rate', 'rhythmic_entropy',
'rhythmic_token_co_occurrence_entropy', 'n_pitch_classes_used',
'n_pitches_used', 'pitch_class_entropy', 'pitch_entropy', 'pitch_range',
'interval_token_co_occurrence_entropy']
# Process files with multiprocessing
print(f"Processing files with {num_threads} threads...")
with Pool(num_threads) as pool:
# Open CSV files for writing
with open(f'{output_prefix}_score_features.csv', 'w', newline='', encoding='utf-8') as score_csvfile:
score_writer = csv.writer(score_csvfile)
score_writer.writerow(score_feat_cols)
with open(f'{output_prefix}_track_features.csv', 'w', newline='', encoding='utf-8') as track_csvfile:
track_writer = csv.writer(track_csvfile)
track_writer.writerow(track_feat_cols)
# Process files with progress bar
processed_count = 0
skipped_count = 0
for score_feat, track_feats in tqdm(pool.imap_unordered(extractor, midi_files),
total=len(midi_files),
desc="Processing MIDI files"):
if not (score_feat, track_feats):
skipped_count += 1
continue
processed_count += 1
# Write score features
score_writer.writerow(score_feat)
# Write track features
if track_feats:
track_writer.writerows(track_feats)
print(f"\nProcessing complete!")
print(f"Successfully processed: {processed_count} files")
print(f"Skipped due to errors: {skipped_count} files")
print(f"Score features saved to: {output_prefix}_score_features.csv")
print(f"Track features saved to: {output_prefix}_track_features.csv")
def main():
"""Main function with command line argument parsing."""
parser = argparse.ArgumentParser(
description="Extract musical features from MIDI files and save to CSV",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python midi_statistics.py /path/to/midi/files
python midi_statistics.py /path/to/midi/files --threads 8 --output my_features
python midi_statistics.py /path/to/midi/files --tpq 12 --threads 2
Features extracted:
- Score level: note count, end time, key signatures, time signatures, tempo
- Track level: instrument, note density, polyphony rate, rhythmic entropy,
pitch distribution, and more
"""
)
parser.add_argument('directory',
help='Path to directory containing MIDI files')
parser.add_argument('--threads', '-t',
type=int,
default=4,
help='Number of threads to use (default: 4)')
parser.add_argument('--output', '-o',
type=str,
default='midi_features',
help='Output file prefix (default: midi_features)')
parser.add_argument('--tpq',
type=int,
default=6,
help='Ticks per quarter note for resampling (default: 6)')
args = parser.parse_args()
# Validate directory
directory = pathlib.Path(args.directory)
if not directory.exists():
print(f"Error: Directory '{directory}' does not exist")
return 1
if not directory.is_dir():
print(f"Error: '{directory}' is not a directory")
return 1
# Validate threads
if args.threads < 1:
print("Error: Number of threads must be at least 1")
return 1
try:
process_midi_files(directory, args.output, args.threads, args.tpq)
return 0
except KeyboardInterrupt:
print("\nProcessing interrupted by user")
return 1
except Exception as e:
print(f"Error: {e}")
return 1
if __name__ == "__main__":
exit(main())