Files
MIDIFoundationModel/SongEval/ebr.py
2025-09-08 14:49:28 +08:00

104 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import glob
import os
import pandas as pd
import muspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
def compute_midi_metrics(file_path, measure_resolution=12):
    """Compute music metrics for a single MIDI file.

    Args:
        file_path: Path to a MIDI file readable by ``muspy.read``.
        measure_resolution: Value passed as the second argument of
            ``muspy.groove_consistency``. Defaults to 12, the value this
            script has always used, so existing results are reproduced.

    Returns:
        A dict with keys ``scale_consistency``, ``pitch_entropy``,
        ``pitch_class_entropy``, ``empty_beat_rate``, ``groove_consistency``
        and ``filename`` (the file's basename only). Returns ``None`` if
        reading the file or computing any metric raises; the error is printed.
    """
    try:
        music = muspy.read(file_path)
        # NOTE(review): muspy documents groove_consistency's second argument
        # as "time steps per measure". Unless the files are read at a
        # matching resolution, 12 may not correspond to a real measure --
        # confirm the intended value (e.g. music.resolution * beats/measure).
        # The dict literal evaluates its values in the order listed, i.e. the
        # same order in which the metrics were computed before.
        return {
            'scale_consistency': muspy.scale_consistency(music),
            'pitch_entropy': muspy.pitch_entropy(music),
            'pitch_class_entropy': muspy.pitch_class_entropy(music),
            'empty_beat_rate': muspy.empty_beat_rate(music),
            'groove_consistency': muspy.groove_consistency(music, measure_resolution),
            'filename': os.path.basename(file_path),
        }
    except Exception as e:
        # Deliberate best-effort boundary: one unreadable file must not abort
        # the whole batch; the caller skips ``None`` results.
        print(f"处理文件 {os.path.basename(file_path)} 时出错: {str(e)}")
        return None
def _find_midi_files(directory_path):
    """Return a sorted list of paths to .mid/.midi files (any case) under directory_path."""
    # Sorting makes the processing order, and hence the CSV row order,
    # independent of the filesystem's directory-listing order.
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(directory_path)
        for name in names
        if name.lower().endswith(('.mid', '.midi'))
    )


def compute_directory_metrics(directory_path, num_workers=8):
    """Compute music metrics for every MIDI file under a directory tree.

    Recursively collects ``.mid``/``.midi`` files below ``directory_path``,
    evaluates each one with ``compute_midi_metrics`` on a thread pool, and
    writes the per-file results to
    ``<directory_path>/midi_metrics_report.csv``.

    Args:
        directory_path: Root directory to search recursively.
        num_workers: Maximum number of worker threads (default 8).

    Returns:
        ``(df, avg_metrics)``, where ``df`` is a DataFrame with one row per
        successfully processed file (rows in sorted path order) and
        ``avg_metrics`` is the column-wise mean of its numeric columns
        (pandas skips NaN by default). Returns ``None`` if no MIDI file was
        found or every file failed.
    """
    midi_files = _find_midi_files(directory_path)
    if not midi_files:
        print("目录及子文件夹中未找到MIDI文件")
        return None

    print(f"正在处理目录: {directory_path}")
    print(f"发现 {len(midi_files)} 个MIDI文件:")

    all_metrics = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # executor.map yields results in submission order, so rows follow the
        # sorted file list rather than nondeterministic completion order.
        # NOTE(review): MIDI parsing and metric computation are probably
        # CPU-bound, so threads may give limited speedup under the GIL.
        results = executor.map(compute_midi_metrics, midi_files)
        for metrics in tqdm(results, total=len(midi_files), desc="处理中"):
            if metrics is None:
                continue  # failure was already reported by compute_midi_metrics
            # tqdm.write prints the same text as before without breaking the bar.
            tqdm.write(f"current_metrics: {metrics}")
            all_metrics.append(metrics)

    if not all_metrics:
        print("所有文件处理失败")
        return None

    df = pd.DataFrame(all_metrics)
    output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
    df.to_csv(output_csv, index=False)
    avg_metrics = df.mean(numeric_only=True)
    return df, avg_metrics
if __name__ == "__main__":
    # Command-line entry point: compute metrics for one directory and print
    # the averaged values.
    parser = argparse.ArgumentParser(description="计算目录下所有MIDI文件的音乐指标")
    parser.add_argument("path", type=str, help="包含MIDI文件的目录路径")
    # %(default)s makes argparse render the real default (1), so the help
    # text cannot drift from the value actually used.
    parser.add_argument("--threads", type=int, default=1, help="线程数默认%(default)s")
    args = parser.parse_args()
    if not os.path.isdir(args.path):
        print(f"错误: 路径 '{args.path}' 不存在或不是目录")
    else:
        # compute_directory_metrics returns a bare None (not a tuple) when no
        # MIDI files are found or all of them fail, so unpack only on success.
        outcome = compute_directory_metrics(args.path, num_workers=args.threads)
        if outcome is not None:
            _, averages = outcome
            print("\n计算完成! 结果已保存到 midi_metrics_report.csv")
            print("\n平均指标值:")
            print(averages.to_string())