Files
MIDIFoundationModel/midi_sim.py
2025-10-21 15:27:03 +08:00

134 lines
5.5 KiB
Python

import os
from math import ceil
#CUDA_VISIBLE_DEVICES= "0"
import numpy as np
import pandas as pd
from symusic import Score
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
# Cost of a melodic interval, indexed by its size in semitones within one
# octave (index 0 = unison ... 11 = major seventh). The table is symmetric
# around the tritone (index 6): an interval and its octave complement cost
# the same.
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])


def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (0., 1.5)):
    """Return a transposition-invariant melodic distance between two note sets.

    Each input is a ``(2, N)`` array: row 0 holds onsets, row 1 holds MIDI
    pitches. Every note is matched to the note of the other set with the
    closest onset, in both directions.

    Pitch distance: after shifting ``a`` so the pitch medians roughly
    coincide, ``a`` is transposed by -7..+6 semitones. For each transposition
    the mean interval cost of the matches is computed (see
    ``semitone2degree``), and the cheapest transposition is kept.

    Onset distance: the mean absolute onset mismatch of the matches.

    Despite the name, this is an averaged (not max) Hausdorff-style distance.

    Args:
        a: ``(2, N)`` array of onsets / pitches.
        b: ``(2, M)`` array of onsets / pitches.
        weight: ``(onset_weight, pitch_weight)`` used to blend the two
            distances into a weighted average.

    Returns:
        The weighted distance, or ``np.inf`` if either set is empty.
    """
    # `.size` also covers 1-D empty arrays (e.g. np.array([]).T), on which
    # `a.shape[1]` would raise IndexError.
    if a.size == 0 or b.size == 0:
        return np.inf

    a_onset, a_pitch = a
    b_onset, b_pitch = b
    a_onset = a_onset.astype(np.float32)
    b_onset = b_onset.astype(np.float32)
    a_pitch = a_pitch.astype(np.int16)
    b_pitch = b_pitch.astype(np.int16)

    # onset_dist_matrix[j, i] = |a_onset[i] - b_onset[j]|, shape (len(b), len(a)).
    onset_dist_matrix = np.abs(a_onset.reshape(1, -1) - b_onset.reshape(-1, 1))
    a2b_idx = onset_dist_matrix.argmin(1)  # per note of b: index of nearest-onset note in a
    b2a_idx = onset_dist_matrix.argmin(0)  # per note of a: index of nearest-onset note in b

    # Align pitch medians, then add one row per candidate transposition:
    # a_pitch becomes shape (14, len(a)).
    a_pitch -= (np.median(a_pitch) - np.median(b_pitch)).astype(np.int16)
    a_pitch = a_pitch + np.arange(-7, 7).reshape(-1, 1)

    interval_diff = np.concatenate([
        a_pitch[:, a2b_idx] - b_pitch,
        b_pitch[b2a_idx] - a_pitch], axis=1)

    # Interval cost = table cost of the interval within a 12-semitone octave
    # plus one unit per whole octave. Using |interval| keeps the cost
    # symmetric in direction. The previous `% 8` / signed-octave form left
    # table entries 8..11 unreachable and scored e.g. -31 semitones as 1 but
    # +31 as 7.
    abs_diff = np.abs(interval_diff)
    interval_cost = semitone2degree[abs_diff % 12] + abs_diff // 12
    pitch_dist = interval_cost.mean(1).min()

    onset_dist = np.abs(np.concatenate([
        a_onset[a2b_idx] - b_onset,
        b_onset[b2a_idx] - a_onset], axis=0)).mean()
    return (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 8., hop_size: float = 4.):
    """Split ``(time, pitch)`` notes into overlapping time windows.

    Notes are sorted first. The window grid starts at the largest multiple of
    ``hop_size`` not after the first onset. Window ``k`` covers
    ``[k * hop_size, k * hop_size + window_size]`` (inclusive on both ends)
    relative to that start. Enough windows are created to reach the last onset.

    Args:
        x: notes as ``(time, pitch)`` tuples, in any order.
        window_size: window length, in the same time unit as ``x``.
        hop_size: distance between consecutive window starts.

    Returns:
        A list of non-empty, contiguous slices of the sorted notes, one per
        window that contains at least one note. Empty input gives ``[]``.
        Non-empty input shorter than one window yields a single window rather
        than nothing (or an exception).
    """
    if not x:
        return []
    x = sorted(x)
    trim_offset = (x[0][0] // hop_size) * hop_size
    end_time = x[-1][0]
    # Clamp to at least one window: for short inputs the formula is <= 0.
    num_segment = max(ceil((end_time - window_size - trim_offset) / hop_size) + 1, 1)

    rel_times = np.fromiter((time for time, _ in x), dtype=float, count=len(x)) - trim_offset
    seg_starts = np.arange(num_segment) * hop_size
    # rel_times is non-decreasing, so each window is a contiguous slice:
    # first index with time >= start, up to (excl.) first with time > end.
    start_idxs = np.searchsorted(rel_times, seg_starts, side="left").tolist()
    end_idxs = np.searchsorted(rel_times, seg_starts + window_size, side="right").tolist()
    return [x[start:end] for start, end in zip(start_idxs, end_idxs) if start < end]
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
    """Return the smallest window-level distance between two note lists.

    Both inputs are cut into overlapping windows with
    ``midi_time_sliding_window``. Every window of ``a`` is compared with every
    window of ``b`` using ``hausdorff_dist``, and the minimum is returned.
    Two pieces are therefore "close" if any passage of one resembles any
    passage of the other.

    Args:
        a: notes of the first piece as ``(time, pitch)`` tuples.
        b: notes of the second piece as ``(time, pitch)`` tuples.
        window_size: window length passed to ``midi_time_sliding_window``.
        hop_size: window hop passed to ``midi_time_sliding_window``.

    Returns:
        The minimum distance as a Python float; ``inf`` if either input is
        empty or yields no windows.
    """
    if not a or not b:
        return float(np.inf)
    # Convert each window to a (2, n) array once, not once per pair.
    segs_a = [np.array(seg, dtype=np.float32).T
              for seg in midi_time_sliding_window(a, window_size=window_size, hop_size=hop_size)]
    segs_b = [np.array(seg, dtype=np.float32).T
              for seg in midi_time_sliding_window(b, window_size=window_size, hop_size=hop_size)]

    dist = np.inf
    for seg_a in segs_a:
        for seg_b in segs_b:
            cur_dist = hausdorff_dist(seg_a, seg_b)
            if cur_dist < dist:
                dist = cur_dist
                # hausdorff_dist with its default non-negative weights never
                # goes below 0, so an exact match cannot be improved upon.
                if dist == 0:
                    return float(dist)
    return float(dist)
def extract_notes(filepath: str, *, all_tracks: bool = False):
    """Read a MIDI file and return its notes as ``(time, pitch)`` tuples.

    The score is converted with ``.to("quarter")`` before notes are read.

    Args:
        filepath: path of the MIDI file.
        all_tracks: if False (default), only the first track is used.
            If True, notes of every track are concatenated.

    Returns:
        The list of notes. An empty list if the score has no tracks, or if
        reading fails (the error is printed).
    """
    try:
        s = Score(filepath).to("quarter")
        # A track-less score is valid input, not a read error.
        if len(s.tracks) == 0:
            return []
        tracks = s.tracks if all_tracks else [s.tracks[0]]
        return [(n.time, n.pitch) for track in tracks for n in track.notes]
    except Exception as e:
        print(f"读取 {filepath} 出错: {e}")
        return []
def compare_pair(file_a: str, file_b: str):
    """Compare two MIDI files and return ``(file_a, file_b, distance)``.

    ``distance`` comes from ``midi_dist`` on the notes returned by
    ``extract_notes``. It is ``np.inf`` when either file yields no notes, or
    when the comparison raises (the traceback is printed instead of
    propagating).
    """
    distance = np.inf
    try:
        notes_a = extract_notes(file_a)
        notes_b = extract_notes(file_b)
        if notes_a and notes_b:
            distance = midi_dist(notes_a, notes_b)
    except Exception:
        import traceback
        print(f"⚠️ compare_pair 出错: {file_a} vs {file_b}")
        traceback.print_exc()
        distance = np.inf
    return (file_a, file_b, distance)
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8,
                  max_files_a: "int | None" = 100):
    """Compare every ``.mid`` file of ``dir_a`` with every one of ``dir_b``.

    Pairs are scored in parallel with ``compare_pair`` in a process pool. The
    results are written to ``out_csv`` with columns
    ``file_a, file_b, distance``, sorted by ascending distance.

    Args:
        dir_a: directory of MIDI files to evaluate.
        dir_b: directory of MIDI files to compare against.
        out_csv: output CSV path.
        max_workers: number of worker processes.
        max_files_a: only the first ``max_files_a`` files of ``dir_a`` (in
            sorted name order) are used, to bound run time. ``None`` uses all
            files.
    """
    # Sort so the truncated subset is deterministic (os.listdir order is arbitrary).
    files_a = sorted(os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid"))
    if max_files_a is not None:
        files_a = files_a[:max_files_a]
    files_b = sorted(os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid"))

    results = []
    with tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files") as pbar:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_pair = {
                executor.submit(compare_pair, fa, fb): (fa, fb)
                for fa in files_a
                for fb in files_b
            }
            for fut in as_completed(future_to_pair):
                pbar.update(1)
                try:
                    results.append(fut.result())
                except Exception as e:
                    # compare_pair handles its own errors, so this is a pool-level
                    # failure (e.g. a dead worker). Do not call fut.result() again
                    # here: it would re-raise and abort the whole batch.
                    print(f"Error comparing pair: {e}")
                    fa, fb = future_to_pair[fut]
                    results.append((fa, fb, np.inf))

    results.sort(key=lambda row: row[2])
    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
    df.to_csv(out_csv, index=False)
    print(f"已保存结果到 {out_csv}")
if __name__ == "__main__":
    # Directory of MIDI files to score (the path looks like a sampling-run
    # output; confirm) and the reference melody collection to compare against.
    eval_dir = "wandb/run-20251015_154556-f0pj3ys3/cond_4m_top_p_t0.99_temp1.25/process_2_batch_23"
    reference_dir = "dataset/Melody"
    batch_compare(
        eval_dir,
        reference_dir,
        out_csv="midi_similarity_v2.csv",
        max_workers=6,
    )