1127 update to latest
@@ -1,14 +1,282 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Count how often each token appears in each column of octuple-tokenized data
and generate an analysis report.
"""

import os
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import json

# Load a sample npz file
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)

# Column names of the Octuple representation
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig"    # 7: TimeSignature
]
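
# Illustrative only (values are made up): each row of an octuple sequence is
# one note event whose eight integer fields line up with COLUMN_NAMES above:
#   [62, 4, 12, 90, 8, 0, 30, 5]
#   -> pitch=62, position=4, bar=12, velocity=90, duration=8, program=0,
#      tempo bin=30, time-signature bin=5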

# Inspect the stored keys
print(data.files)
# Output: ['filename', 'sequence']

# Access the data
sequence = data["sequence"]
print("Token sequence length:", len(sequence))
print("First 20 tokens:", sequence[:20])


def load_octuple_data(data_dir):
    """
    Load all octuple-tokenized .npz files.

    Args:
        data_dir: Path to the data directory; either a single directory or a
            root directory containing several subdirectories.

    Returns:
        list: All loaded numpy arrays.
    """
    data_dir = Path(data_dir)
    npz_files = []

    # If the directory exists, collect all .npz files recursively
    if data_dir.exists():
        npz_files = list(data_dir.rglob("*.npz"))

    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return []

    print(f"Found {len(npz_files)} .npz files, loading...")

    all_data = []
    for npz_file in tqdm(npz_files, desc="Loading data"):
        try:
            data = np.load(npz_file)['arr_0']
            # Make sure the data is a 2D array (num_tokens, num_columns)
            if data.ndim == 2:
                all_data.append(data)
            elif data.ndim == 1:
                # A 1D array might need reshaping, but octuple data should be 2D
                print(f"Warning: {npz_file} is a 1D array, skipping")
        except Exception as e:
            print(f"Error: failed to load {npz_file}: {e}")
            continue

    return all_data
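
# Note: load_octuple_data assumes each file stores its token matrix under the
# default key 'arr_0' (i.e. saved via np.savez(path, array) with no keyword).
# The demo file at the top instead uses named keys ('filename', 'sequence'),
# so the two formats are not interchangeable.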


def count_tokens_by_column(all_data):
    """
    Count the occurrences of each token in each column.

    Args:
        all_data: List of data arrays, each of shape (num_tokens, num_columns).

    Returns:
        dict: {column_index: Counter({token_value: count})}
    """
    column_counters = defaultdict(Counter)

    print("Counting token occurrences...")
    for data in tqdm(all_data, desc="Processing data"):
        if data.size == 0:
            continue

        num_columns = data.shape[1] if data.ndim == 2 else 1

        for col_idx in range(num_columns):
            if data.ndim == 2:
                tokens = data[:, col_idx]
            else:
                tokens = data

            # Count how often each token appears in this column
            unique, counts = np.unique(tokens, return_counts=True)
            for token, count in zip(unique, counts):
                column_counters[col_idx][int(token)] += int(count)

    return dict(column_counters)
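
# Sketch of the returned structure (counts are made up for illustration):
#   {0: Counter({60: 10421, 62: 9876, ...}),  # pitch column
#    1: Counter({0: 5012, 4: 4998, ...}),     # position column
#    ...}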


def generate_report(column_counters, output_file=None):
    """
    Generate the analysis report.

    Args:
        column_counters: Dictionary of per-column token counts.
        output_file: Output file path (optional).
    """
    report_lines = []
    report_lines.append("=" * 80)
    report_lines.append("OCTUPLE TOKENIZATION STATISTICS REPORT")
    report_lines.append("=" * 80)
    report_lines.append("")

    # Overall statistics
    total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
    report_lines.append(f"Total tokens: {total_tokens:,}")
    report_lines.append(f"Columns analyzed: {len(column_counters)}")
    report_lines.append("")

    # Detailed statistics for each column
    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"

        report_lines.append("-" * 80)
        report_lines.append(f"Column {col_idx}: {col_name}")
        report_lines.append("-" * 80)

        total_in_column = sum(counter.values())
        unique_tokens = len(counter)
        min_token = min(counter.keys()) if counter else 0
        max_token = max(counter.keys()) if counter else 0

        report_lines.append(f"  Total tokens: {total_in_column:,}")
        report_lines.append(f"  Unique tokens: {unique_tokens:,}")
        report_lines.append(f"  Token value range: [{min_token}, {max_token}]")
        report_lines.append(f"  Mean occurrences: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else "  Mean occurrences: N/A")
        report_lines.append("")

        # Top 20 most common tokens
        report_lines.append("  Top 20 most common tokens:")
        top_tokens = counter.most_common(20)
        for rank, (token, count) in enumerate(top_tokens, 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
        report_lines.append("")

        # Top 20 least common tokens (count > 0)
        report_lines.append("  Top 20 least common tokens (count > 0):")
        bottom_tokens = counter.most_common()[-20:]
        bottom_tokens.reverse()
        for rank, (token, count) in enumerate(bottom_tokens, 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
        report_lines.append("")

        # Distribution statistics
        counts_list = list(counter.values())
        if counts_list:
            report_lines.append("  Distribution statistics:")
            report_lines.append(f"    Min occurrences: {min(counts_list):,}")
            report_lines.append(f"    Max occurrences: {max(counts_list):,}")
            report_lines.append(f"    Median occurrences: {np.median(counts_list):,.0f}")
            report_lines.append(f"    Std deviation: {np.std(counts_list):,.2f}")
        report_lines.append("")

    # Cross-column analysis
    report_lines.append("=" * 80)
    report_lines.append("CROSS-COLUMN ANALYSIS")
    report_lines.append("=" * 80)
    report_lines.append("")

    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
        total_in_column = sum(counter.values())
        percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
        report_lines.append(f"  {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")

    report_lines.append("")
    report_lines.append("=" * 80)
    report_lines.append("REPORT COMPLETE")
    report_lines.append("=" * 80)

    # Print the report
    report_text = "\n".join(report_lines)
    print("\n" + report_text)

    # Save to file
    if output_file:
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(report_text)
        print(f"\nReport saved to: {output_path}")

        # Also save the detailed data as JSON
        json_output = output_path.with_suffix('.json')
        json_data = {
            'summary': {
                'total_tokens': total_tokens,
                'num_columns': len(column_counters)
            },
            'columns': {}
        }

        for col_idx in sorted(column_counters.keys()):
            counter = column_counters[col_idx]
            col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
            json_data['columns'][col_name] = {
                'total_tokens': sum(counter.values()),
                'unique_tokens': len(counter),
                'token_counts': dict(counter),
                'top_20': dict(counter.most_common(20)),
                'bottom_20': dict(counter.most_common()[-20:])
            }

        with open(json_output, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)
        print(f"Detailed data saved to: {json_output}")


def main():
    """Entry point."""
    # Default data directory - change as needed
    default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"

    # A specific data directory can be given, for example:
    data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
    # Or use the default directory and scan for all oct8 directories:

    # import sys
    # if len(sys.argv) > 1:
    #     data_dir = sys.argv[1]
    # else:
    #     # Automatically find all oct8 directories
    #     base_dir = Path(default_data_dir)
    #     oct8_dirs = list(base_dir.rglob("oct8"))
    #     if oct8_dirs:
    #         print("Found the following oct8 directories:")
    #         for i, d in enumerate(oct8_dirs, 1):
    #             print(f"  {i}. {d}")
    #         if len(oct8_dirs) == 1:
    #             data_dir = str(oct8_dirs[0])
    #             print(f"\nUsing directory: {data_dir}")
    #         else:
    #             # Use the first directory found, or merge all of them
    #             print(f"\nUsing the first directory: {oct8_dirs[0]}")
    #             print("To analyze another directory, pass its path as an argument")
    #             data_dir = str(oct8_dirs[0])
    #     else:
    #         data_dir = default_data_dir
    #         print(f"No oct8 directory found, using the default: {data_dir}")

    # Load the data
    all_data = load_octuple_data(data_dir)

    if not all_data:
        print("Error: no data was loaded")
        return

    # Check the data format
    sample = all_data[0]
    print("\nData format check:")
    print(f"  Sample shape: {sample.shape}")
    print(f"  Sample dtype: {sample.dtype}")
    print(f"  Number of columns: {sample.shape[1] if sample.ndim == 2 else 1}")
    print()

    # Count token occurrences
    column_counters = count_tokens_by_column(all_data)

    # Generate the report
    output_file = "octuple_token_analysis_report_part.txt"
    generate_report(column_counters, output_file)


if __name__ == "__main__":
    main()