first commit
This commit is contained in:
103
SongEval/ebr.py
Normal file
103
SongEval/ebr.py
Normal file
@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import pandas as pd
|
||||
import muspy
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
|
||||
def compute_midi_metrics(file_path):
|
||||
"""计算单个MIDI文件的音乐指标"""
|
||||
try:
|
||||
music = muspy.read(file_path)
|
||||
scale_consistency = muspy.scale_consistency(music)
|
||||
pitch_entropy = muspy.pitch_entropy(music)
|
||||
pitch_class_entropy = muspy.pitch_class_entropy(music)
|
||||
empty_beat_rate = muspy.empty_beat_rate(music)
|
||||
groove_consistency = muspy.groove_consistency(music, 12)
|
||||
metrics = {
|
||||
'scale_consistency': scale_consistency,
|
||||
'pitch_entropy': pitch_entropy,
|
||||
'pitch_class_entropy': pitch_class_entropy,
|
||||
'empty_beat_rate': empty_beat_rate,
|
||||
'groove_consistency': groove_consistency,
|
||||
'filename': os.path.basename(file_path)
|
||||
}
|
||||
return metrics
|
||||
except Exception as e:
|
||||
print(f"处理文件 {os.path.basename(file_path)} 时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
def compute_directory_metrics(directory_path, num_workers=8):
|
||||
"""计算目录下所有MIDI文件的音乐指标(多线程加速)"""
|
||||
midi_files = []
|
||||
for root, _, files in os.walk(directory_path):
|
||||
for file in files:
|
||||
if file.lower().endswith(('.mid', '.midi')):
|
||||
midi_files.append(os.path.join(root, file))
|
||||
if not midi_files:
|
||||
print("目录及子文件夹中未找到MIDI文件")
|
||||
return None
|
||||
|
||||
all_metrics = []
|
||||
average_metrics = {
|
||||
'scale_consistency': 0,
|
||||
'pitch_entropy': 0,
|
||||
'pitch_class_entropy': 0,
|
||||
'empty_beat_rate': 0,
|
||||
'groove_consistency': 0
|
||||
}
|
||||
current_num = 0
|
||||
total_scale_consistency = 0
|
||||
total_pitch_entropy = 0
|
||||
total_pitch_class_entropy = 0
|
||||
total_empty_beat_rate = 0
|
||||
total_groove_consistency = 0
|
||||
print(f"正在处理目录: {directory_path}")
|
||||
print(f"发现 {len(midi_files)} 个MIDI文件:")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {executor.submit(compute_midi_metrics, midi_file): midi_file for midi_file in midi_files}
|
||||
for future in tqdm(as_completed(futures), total=len(midi_files), desc="处理中"):
|
||||
metrics = future.result()
|
||||
|
||||
if metrics is not None:
|
||||
current_num += 1
|
||||
total_scale_consistency += metrics['scale_consistency']
|
||||
total_pitch_entropy += metrics['pitch_entropy']
|
||||
total_pitch_class_entropy += metrics['pitch_class_entropy']
|
||||
total_empty_beat_rate += metrics['empty_beat_rate']
|
||||
total_groove_consistency += metrics['groove_consistency']
|
||||
average_metrics['scale_consistency'] = total_scale_consistency / current_num
|
||||
average_metrics['pitch_entropy'] = total_pitch_entropy / current_num
|
||||
average_metrics['pitch_class_entropy'] = total_pitch_class_entropy / current_num
|
||||
average_metrics['empty_beat_rate'] = total_empty_beat_rate / current_num
|
||||
average_metrics['groove_consistency'] = total_groove_consistency / current_num
|
||||
print("current_metrics:", metrics)
|
||||
|
||||
all_metrics.append(metrics)
|
||||
|
||||
if not all_metrics:
|
||||
print("所有文件处理失败")
|
||||
return None
|
||||
|
||||
df = pd.DataFrame(all_metrics)
|
||||
output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
|
||||
df.to_csv(output_csv, index=False)
|
||||
avg_metrics = df.mean(numeric_only=True)
|
||||
return df, avg_metrics
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="计算目录下所有MIDI文件的音乐指标")
|
||||
parser.add_argument("path", type=str, help="包含MIDI文件的目录路径")
|
||||
parser.add_argument("--threads", type=int, default=1, help="线程数(默认8)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.isdir(args.path):
|
||||
print(f"错误: 路径 '{args.path}' 不存在或不是目录")
|
||||
else:
|
||||
result, averages = compute_directory_metrics(args.path, num_workers=args.threads)
|
||||
if result is not None:
|
||||
print("\n计算完成! 结果已保存到 midi_metrics_report.csv")
|
||||
print("\n平均指标值:")
|
||||
print(averages.to_string())
|
||||
Reference in New Issue
Block a user