104 lines
4.3 KiB
Python
104 lines
4.3 KiB
Python
import argparse
|
||
import glob
|
||
import os
|
||
import pandas as pd
|
||
import muspy
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from tqdm import tqdm
|
||
|
||
def compute_midi_metrics(file_path, measure_resolution=12):
    """Compute the muspy metrics for a single MIDI file.

    Args:
        file_path: Path of the MIDI file to analyse.
        measure_resolution: Value forwarded to ``muspy.groove_consistency``
            as its measure resolution. Defaults to 12, the value that used
            to be hard-coded here, so existing callers are unaffected.
            NOTE(review): confirm 12 matches the time steps per measure of
            the files being analysed.

    Returns:
        A dict with the keys ``scale_consistency``, ``pitch_entropy``,
        ``pitch_class_entropy``, ``empty_beat_rate``, ``groove_consistency``
        and ``filename`` (the base name of ``file_path``), or ``None`` if
        reading the file or computing any metric raised an exception.
    """
    try:
        music = muspy.read(file_path)
        return {
            'scale_consistency': muspy.scale_consistency(music),
            'pitch_entropy': muspy.pitch_entropy(music),
            'pitch_class_entropy': muspy.pitch_class_entropy(music),
            'empty_beat_rate': muspy.empty_beat_rate(music),
            'groove_consistency': muspy.groove_consistency(
                music, measure_resolution
            ),
            'filename': os.path.basename(file_path),
        }
    except Exception as e:
        # Deliberately broad: one unreadable or malformed file must not abort
        # a whole batch run, so the failure is reported and signalled as None.
        print(f"处理文件 {os.path.basename(file_path)} 时出错: {e}")
        return None
|
||
|
||
def compute_directory_metrics(directory_path, num_workers=8):
    """Compute metrics for every MIDI file under a directory using threads.

    The directory is walked recursively; files whose lower-cased name ends in
    ``.mid`` or ``.midi`` are handed to ``compute_midi_metrics`` on a thread
    pool. Files for which that function returns ``None`` are skipped.

    Side effect: on success the per-file results are written to
    ``midi_metrics_report.csv`` inside ``directory_path``.

    Args:
        directory_path: Root directory to search for MIDI files.
        num_workers: ``max_workers`` passed to ``ThreadPoolExecutor``.

    Returns:
        ``(df, avg_metrics)`` where ``df`` is a DataFrame with one row per
        successfully processed file and ``avg_metrics`` is
        ``df.mean(numeric_only=True)``. Returns a bare ``None`` (not a pair)
        when no MIDI file is found or when every file failed.
    """
    midi_files = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            if file.lower().endswith(('.mid', '.midi')):
                midi_files.append(os.path.join(root, file))

    if not midi_files:
        print("目录及子文件夹中未找到MIDI文件")
        return None

    print(f"正在处理目录: {directory_path}")
    print(f"发现 {len(midi_files)} 个MIDI文件:")

    all_metrics = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(compute_midi_metrics, midi_file)
            for midi_file in midi_files
        ]
        # Rows are collected in completion order, not in discovery order.
        for future in tqdm(as_completed(futures), total=len(futures), desc="处理中"):
            metrics = future.result()
            if metrics is None:
                continue
            # tqdm.write emits the line without corrupting the progress bar,
            # which a plain print() inside the loop does.
            tqdm.write(f"current_metrics: {metrics}")
            all_metrics.append(metrics)

    if not all_metrics:
        print("所有文件处理失败")
        return None

    df = pd.DataFrame(all_metrics)
    output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
    df.to_csv(output_csv, index=False)

    # Averages are taken from the DataFrame; the non-numeric "filename"
    # column is excluded by numeric_only=True.
    avg_metrics = df.mean(numeric_only=True)
    return df, avg_metrics
|
||
|
||
# Command-line entry point: compute metrics for every MIDI file under a
# directory and print the averaged values.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="计算目录下所有MIDI文件的音乐指标")
    parser.add_argument("path", type=str, help="包含MIDI文件的目录路径")
    # %(default)s keeps the help text in sync with the real default; the
    # previous text claimed 8 while the actual default is 1.
    parser.add_argument("--threads", type=int, default=1, help="线程数(默认%(default)s)")
    args = parser.parse_args()

    # ThreadPoolExecutor rejects max_workers <= 0 with a ValueError traceback;
    # report it as a normal usage error instead.
    if args.threads < 1:
        parser.error("--threads 必须为正整数")

    if not os.path.isdir(args.path):
        print(f"错误: 路径 '{args.path}' 不存在或不是目录")
    else:
        outcome = compute_directory_metrics(args.path, num_workers=args.threads)
        # compute_directory_metrics returns a bare None (not a pair) when no
        # MIDI file is found or every file fails, so unpacking its result
        # directly would raise TypeError in exactly those cases.
        result, averages = outcome if outcome is not None else (None, None)
        if result is not None:
            print("\n计算完成! 结果已保存到 midi_metrics_report.csv")
            print("\n平均指标值:")
            print(averages.to_string())
|