1013 update
This commit is contained in:
105
,idi_sim.py
Normal file
105
,idi_sim.py
Normal file
@ -0,0 +1,105 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from symusic import Score
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
# Cost assigned to a pitch interval, indexed by the interval size in semitones
# modulo 12 (index 0 = unison/octave). The table is symmetric (entry i equals
# entry 12 - i), so an interval and its inversion cost the same.
semitone2degree = np.array([0, 2, 2, 3, 3, 4, 4.5, 4, 3, 3, 2, 2])


def hausdorff_dist(a: np.ndarray, b: np.ndarray, weight: tuple[float, float] = (2., 1.5), oti: bool = True):
    """Return the mean nearest-neighbour distance between two note sets.

    Every note of ``a`` is matched to its closest note of ``b`` and vice
    versa; the result is the average of all those closest-match distances.
    The distance between two notes is the weighted mean of their absolute
    onset difference and the ``semitone2degree`` cost of their pitch interval.

    Args:
        a: Array of shape ``(2, n)``; row 0 holds onsets, row 1 holds pitches.
        b: Array of shape ``(2, m)`` with the same layout as ``a``.
        weight: ``(onset_weight, pitch_weight)`` used for the weighted mean.
        oti: When true, ``a`` is tried at all 12 semitone transpositions and
            the smallest resulting distance is returned.

    Returns:
        The distance as a ``float``; ``inf`` when either note set is empty.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # An empty note list converts to a 1-D array of shape (0,), so test
    # ``size`` instead of ``shape[1]``, which would raise IndexError.
    if a.size == 0 or b.size == 0:
        return np.inf

    a_onset, a_pitch = a
    b_onset, b_pitch = b
    a_onset = a_onset.astype(np.float32)
    b_onset = b_onset.astype(np.float32)
    # Signed integers are required here: with unsigned pitches a negative
    # difference wraps around (60 - 61 -> 255) and selects the wrong interval.
    a_pitch = a_pitch.astype(np.int64)
    b_pitch = b_pitch.astype(np.int64)

    # Axis layout of every array below: (transposition, note of b, note of a).
    # Without OTI only the identity transposition (shift 0) is evaluated.
    shifts = np.arange(12 if oti else 1).reshape(-1, 1, 1)
    onset_dist = np.abs(a_onset.reshape(1, 1, -1) - b_onset.reshape(1, -1, 1))
    interval = np.abs(a_pitch.reshape(1, 1, -1) + shifts - b_pitch.reshape(1, -1, 1)) % 12
    pitch_dist = semitone2degree[interval]
    dist = (weight[0] * onset_dist + weight[1] * pitch_dist) / sum(weight)

    closest_for_b = dist.min(axis=2)  # (transposition, note of b)
    closest_for_a = dist.min(axis=1)  # (transposition, note of a)
    total_per_shift = closest_for_b.sum(axis=1) + closest_for_a.sum(axis=1)
    # Normalise by the number of notes (not by the number of transpositions)
    # so that both modes return a per-note average.
    return float(total_per_shift.min() / (a_pitch.size + b_pitch.size))
|
||||
|
||||
|
||||
def midi_time_sliding_window(x: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4.):
    """Group notes into overlapping time windows.

    Window ``k`` starts at ``k * hop_size`` and spans ``window_size`` time
    units, so a note can belong to several consecutive windows. The number of
    windows is ``end_time // hop_size`` (at least one); notes that start after
    the last window's start are kept in that last window so no note is lost.

    Args:
        x: Notes as ``(time, pitch)`` tuples, in any order.
        window_size: Length of each window, in the same unit as the times.
        hop_size: Distance between the starts of consecutive windows.

    Returns:
        A list of windows; each window is a time-sorted list of notes and may
        be empty. An empty input yields an empty list.
    """
    notes = sorted(x)
    if not notes:
        return []

    end_time = notes[-1][0]
    # Always keep at least one window so pieces shorter than one hop work.
    n_windows = max(1, int(end_time // hop_size))
    out = [[] for _ in range(n_windows)]

    for note in notes:
        onset = note[0]
        # Latest window that starts at or before this note, clamped to range.
        segment = max(0, min(int(onset // hop_size), n_windows - 1))
        out[segment].append(note)
        segment -= 1
        # Walk back through earlier windows that still cover the note.
        while segment >= 0 and onset < segment * hop_size + window_size:
            out[segment].append(note)
            segment -= 1
    return out
|
||||
|
||||
|
||||
def midi_dist(a: list[tuple[float, int]], b: list[tuple[float, int]], window_size: float = 16., hop_size: float = 4):
    """Return the smallest window-to-window distance between two note lists.

    Both note lists are cut into sliding windows, every window of ``a`` is
    compared with every window of ``b`` using ``hausdorff_dist``, and the
    minimum distance found is returned.

    Args:
        a: Notes of the first piece as ``(time, pitch)`` tuples.
        b: Notes of the second piece as ``(time, pitch)`` tuples.
        window_size: Window length forwarded to ``midi_time_sliding_window``.
        hop_size: Hop length forwarded to ``midi_time_sliding_window``.

    Returns:
        The minimum distance, or ``inf`` when no pair of non-empty windows
        exists.
    """
    if not a or not b:
        return np.inf

    def _window_arrays(notes):
        # Convert each non-empty window once into a (2, n) array; empty
        # windows carry no information and cannot be transposed to (2, 0).
        return [
            np.array(window, dtype=np.float32).T
            for window in midi_time_sliding_window(notes, window_size, hop_size)
            if window
        ]

    windows_a = _window_arrays(a)
    windows_b = _window_arrays(b)

    dist = np.inf
    for win_a in windows_a:
        for win_b in windows_b:
            cur_dist = hausdorff_dist(win_a, win_b)
            if cur_dist < dist:
                dist = cur_dist
    return dist
|
||||
|
||||
|
||||
def extract_notes(filepath: str):
    """Read a MIDI file and return its notes as a list of (time, pitch) tuples.

    The score is loaded with ``Score`` and converted via ``.to("quarter")``
    (presumably quarter-note time units -- confirm against symusic). Notes of
    all tracks are gathered into one flat list. Any error while loading is
    printed and an empty list is returned instead.
    """
    try:
        score = Score(filepath).to("quarter")
        return [
            (note.time, note.pitch)
            for track in score.tracks
            for note in track.notes
        ]
    except Exception as e:
        # Best effort: report the failure and let the caller treat the file
        # as having no notes.
        print(f"读取 {filepath} 出错: {e}")
        return []
|
||||
|
||||
|
||||
def compare_pair(file_a: str, file_b: str):
    """Compute the distance between two MIDI files.

    Returns a ``(file_a, file_b, distance)`` tuple. The distance is ``inf``
    when either file yields no notes.
    """
    # Both files are always read, even when the first one yields nothing.
    note_lists = [extract_notes(path) for path in (file_a, file_b)]
    if all(note_lists):
        distance = midi_dist(*note_lists)
    else:
        distance = np.inf
    return (file_a, file_b, distance)
|
||||
|
||||
|
||||
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
    """Compare every MIDI file in ``dir_a`` with every MIDI file in ``dir_b``.

    Each pair is evaluated by ``compare_pair`` in a process pool. The results
    are sorted by ascending distance and written to ``out_csv`` with the
    columns ``file_a``, ``file_b`` and ``distance``.

    Args:
        dir_a: Directory holding the first set of MIDI files.
        dir_b: Directory holding the second set of MIDI files.
        out_csv: Path of the CSV file to write.
        max_workers: Number of worker processes.
    """

    def _midi_files(directory):
        # Accept ".mid" and ".midi" in any letter case; sort so the submitted
        # work does not depend on the arbitrary order of os.listdir.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith((".mid", ".midi"))
        )

    files_a = _midi_files(dir_a)
    files_b = _midi_files(dir_b)

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(compare_pair, fa, fb) for fa in files_a for fb in files_b]
        results = [fut.result() for fut in as_completed(futures)]

    # Sort by distance; file names break ties so the output is reproducible
    # regardless of the order in which workers finish.
    results.sort(key=lambda row: (row[2], row[0], row[1]))

    # Save
    df = pd.DataFrame(results, columns=["file_a", "file_b", "distance"])
    df.to_csv(out_csv, index=False)
    print(f"已保存结果到 {out_csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: compare all MIDI files of two hard-coded folders
    # and write the ranked distances to a CSV file.
    dir_a, dir_b = "folder_a", "folder_b"
    batch_compare(
        dir_a,
        dir_b,
        out_csv="midi_similarity.csv",
        max_workers=8,
    )
|
||||
Reference in New Issue
Block a user