#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
统计octuple分词结果中每一列每个token的出现次数并生成分析报告
"""
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import json

# Octuple column name definitions
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig",   # 7: TimeSignature
]


def load_octuple_data(data_dir):
    """
    Load all Octuple-tokenized .npz files.

    Args:
        data_dir: Path to the data directory; either a single directory or a
            root directory containing multiple subdirectories.

    Returns:
        list: All loaded numpy arrays.
    """
    data_dir = Path(data_dir)
    npz_files = []
    # If the directory exists, find all .npz files recursively
    if data_dir.exists():
        npz_files = list(data_dir.rglob("*.npz"))
    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return []
    print(f"Found {len(npz_files)} .npz files, loading...")
    all_data = []
    for npz_file in tqdm(npz_files, desc="Loading data"):
        try:
            data = np.load(npz_file)['arr_0']
            # Make sure the data is a 2-D array (num_tokens, num_columns)
            if data.ndim == 2:
                all_data.append(data)
            elif data.ndim == 1:
                # A 1-D array might need reshaping, but Octuple data should be 2-D
                print(f"Warning: {npz_file} is a 1-D array, skipping")
        except Exception as e:
            print(f"Error: failed to load {npz_file}: {e}")
            continue
    return all_data
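

# Hedged sketch (not part of the original pipeline): how a compatible .npz file
# could be produced. np.savez stores its first positional array under the
# default key 'arr_0', which is exactly what load_octuple_data reads back.
# The token values here are made up.
def _demo_write_sample_npz(path="sample_octuple.npz"):
    tokens = np.array([
        [60, 0, 0, 90, 8, 0, 30, 2],   # one hypothetical octuple event
        [64, 4, 0, 90, 8, 0, 30, 2],
    ], dtype=np.int64)
    np.savez(path, tokens)  # saved under the default key 'arr_0'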


def count_tokens_by_column(all_data):
    """
    Count the occurrences of each token in every column.

    Args:
        all_data: List of arrays, each a numpy array of shape
            (num_tokens, num_columns).

    Returns:
        dict: {column_index: Counter({token_value: count})}
    """
    column_counters = defaultdict(Counter)
    print("Counting token occurrences...")
    for data in tqdm(all_data, desc="Processing data"):
        if data.size == 0:
            continue
        num_columns = data.shape[1] if data.ndim == 2 else 1
        for col_idx in range(num_columns):
            if data.ndim == 2:
                tokens = data[:, col_idx]
            else:
                tokens = data
            # Count the occurrences of each token in this column
            unique, counts = np.unique(tokens, return_counts=True)
            for token, count in zip(unique, counts):
                column_counters[col_idx][int(token)] += int(count)
    return dict(column_counters)
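

# Hedged sketch (illustration only, made-up data): what count_tokens_by_column
# returns for a tiny two-column input.
def _demo_count_tokens():
    data = np.array([[60, 0], [60, 4], [62, 0]])
    counters = count_tokens_by_column([data])
    # counters == {0: Counter({60: 2, 62: 1}), 1: Counter({0: 2, 4: 1})}
    return counters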


def generate_report(column_counters, output_file=None):
    """
    Generate the analysis report.

    Args:
        column_counters: Dictionary of per-column counters.
        output_file: Output file path (optional).
    """
    report_lines = []
    report_lines.append("=" * 80)
    report_lines.append("OCTUPLE TOKENIZATION STATISTICS REPORT")
    report_lines.append("=" * 80)
    report_lines.append("")
    # Overall statistics
    total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
    report_lines.append(f"Total tokens: {total_tokens:,}")
    report_lines.append(f"Columns analyzed: {len(column_counters)}")
    report_lines.append("")
    # Detailed statistics for each column
    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
        report_lines.append("-" * 80)
        report_lines.append(f"{col_idx}: {col_name}")
        report_lines.append("-" * 80)
        total_in_column = sum(counter.values())
        unique_tokens = len(counter)
        min_token = min(counter.keys()) if counter else 0
        max_token = max(counter.keys()) if counter else 0
        report_lines.append(f"  Total tokens: {total_in_column:,}")
        report_lines.append(f"  Unique tokens: {unique_tokens:,}")
        report_lines.append(f"  Token value range: [{min_token}, {max_token}]")
        if unique_tokens > 0:
            report_lines.append(f"  Mean occurrences: {total_in_column / unique_tokens:.2f}")
        else:
            report_lines.append("  Mean occurrences: N/A")
        report_lines.append("")
        # Top 20 most common tokens
        report_lines.append("  Top 20 most common tokens:")
        top_tokens = counter.most_common(20)
        for rank, (token, count) in enumerate(top_tokens, 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
        report_lines.append("")
        # Top 20 least common tokens (with count > 0)
        report_lines.append("  Top 20 least common tokens (count > 0):")
        bottom_tokens = counter.most_common()[-20:]
        bottom_tokens.reverse()
        for rank, (token, count) in enumerate(bottom_tokens, 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
        report_lines.append("")
        # Distribution statistics
        counts_list = list(counter.values())
        if counts_list:
            report_lines.append("  Distribution statistics:")
            report_lines.append(f"    Min occurrences: {min(counts_list):,}")
            report_lines.append(f"    Max occurrences: {max(counts_list):,}")
            report_lines.append(f"    Median occurrences: {np.median(counts_list):,.0f}")
            report_lines.append(f"    Std deviation: {np.std(counts_list):,.2f}")
        report_lines.append("")
    # Cross-column analysis
    report_lines.append("=" * 80)
    report_lines.append("CROSS-COLUMN ANALYSIS")
    report_lines.append("=" * 80)
    report_lines.append("")
    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
        total_in_column = sum(counter.values())
        percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
        report_lines.append(f"  {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
    report_lines.append("")
    report_lines.append("=" * 80)
    report_lines.append("REPORT COMPLETE")
    report_lines.append("=" * 80)
    # Print the report
    report_text = "\n".join(report_lines)
    print("\n" + report_text)
    # Save to file
    if output_file:
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(report_text)
        print(f"\nReport saved to: {output_path}")
        # Also save the detailed data in JSON format
        json_output = output_path.with_suffix('.json')
        json_data = {
            'summary': {
                'total_tokens': total_tokens,
                'num_columns': len(column_counters)
            },
            'columns': {}
        }
        for col_idx in sorted(column_counters.keys()):
            counter = column_counters[col_idx]
            col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
            json_data['columns'][col_name] = {
                'total_tokens': sum(counter.values()),
                'unique_tokens': len(counter),
                # json.dump converts the integer token keys to strings
                'token_counts': dict(counter),
                'top_20': dict(counter.most_common(20)),
                'bottom_20': dict(counter.most_common()[-20:])
            }
        with open(json_output, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)
        print(f"Detailed data saved to: {json_output}")


def main():
    """Main function"""
    # Default data directory - change as needed
    default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"
    # A specific data directory can be given, for example:
    data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
    # Or scan the default directory for all oct8 subdirectories:
    # import sys
    # if len(sys.argv) > 1:
    #     data_dir = sys.argv[1]
    # else:
    #     # Automatically find all oct8 directories
    #     base_dir = Path(default_data_dir)
    #     oct8_dirs = list(base_dir.rglob("oct8"))
    #     if oct8_dirs:
    #         print("Found the following oct8 directories:")
    #         for i, d in enumerate(oct8_dirs, 1):
    #             print(f"  {i}. {d}")
    #         if len(oct8_dirs) == 1:
    #             data_dir = str(oct8_dirs[0])
    #             print(f"\nUsing directory: {data_dir}")
    #         else:
    #             # Use the first directory found (or merge all directories)
    #             print(f"\nUsing the first directory: {oct8_dirs[0]}")
    #             print("To analyze another directory, pass its path as an argument")
    #             data_dir = str(oct8_dirs[0])
    #     else:
    #         data_dir = default_data_dir
    #         print(f"No oct8 directory found, using the default: {data_dir}")
    # Load the data
    all_data = load_octuple_data(data_dir)
    if not all_data:
        print("Error: no data was loaded")
        return
    # Check the data format
    sample = all_data[0]
    print("\nData format check:")
    print(f"  Sample shape: {sample.shape}")
    print(f"  Sample dtype: {sample.dtype}")
    print(f"  Number of columns: {sample.shape[1] if sample.ndim == 2 else 1}")
    print()
    # Count token occurrences
    column_counters = count_tokens_by_column(all_data)
    # Generate the report
    output_file = "octuple_token_analysis_report_part.txt"
    generate_report(column_counters, output_file)


if __name__ == "__main__":
    main()