#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Count per-column token frequencies in octuple-tokenized data and emit a report.

Each ``.npz`` file is expected to hold an ``arr_0`` array of shape
(num_tokens, num_columns); the columns follow the octuple layout listed in
``COLUMN_NAMES``.  A human-readable text report plus a machine-readable JSON
dump of the raw counts are produced.
"""

import json
import os
import sys
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np

try:
    from tqdm import tqdm
except ImportError:  # tqdm is optional: degrade to a plain iterator, no progress bar
    def tqdm(iterable, **kwargs):
        return iterable

# Octuple column layout (index -> field name).
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig",   # 7: TimeSignature
]


def load_octuple_data(data_dir):
    """Recursively load every ``.npz`` file under *data_dir*.

    Args:
        data_dir: A single directory, or the root of a tree of sub-directories,
            containing octuple ``.npz`` files.

    Returns:
        list[np.ndarray]: The 2-D ``(num_tokens, num_columns)`` arrays that
        loaded successfully.  1-D arrays and unreadable files are skipped with
        a warning instead of aborting the whole run.
    """
    data_dir = Path(data_dir)
    npz_files = list(data_dir.rglob("*.npz")) if data_dir.exists() else []

    if not npz_files:
        print(f"警告: 在 {data_dir} 中未找到.npz文件")
        return []

    print(f"找到 {len(npz_files)} 个.npz文件,开始加载...")

    all_data = []
    for npz_file in tqdm(npz_files, desc="加载数据"):
        try:
            # Use the NpzFile as a context manager so its underlying file
            # handle is released promptly — with thousands of archives the
            # original un-closed np.load() calls can exhaust file descriptors.
            with np.load(npz_file) as npz:
                data = npz['arr_0']
            if data.ndim == 2:
                # Expected layout: (num_tokens, num_columns).
                all_data.append(data)
            elif data.ndim == 1:
                # Octuple data should always be 2-D; skip malformed files.
                print(f"警告: {npz_file} 是一维数组,跳过")
        except Exception as e:
            # Best-effort loading: report the failure and keep going.
            print(f"错误: 加载 {npz_file} 时出错: {e}")
            continue

    return all_data


def count_tokens_by_column(all_data):
    """Count how often each token value appears in each column.

    Args:
        all_data: List of arrays, each of shape ``(num_tokens, num_columns)``
            (a 1-D array is treated as a single column).

    Returns:
        dict[int, Counter]: ``{column_index: Counter({token_value: count})}``.
    """
    column_counters = defaultdict(Counter)

    print("统计token出现次数...")
    for data in tqdm(all_data, desc="处理数据"):
        if data.size == 0:
            continue

        num_columns = data.shape[1] if data.ndim == 2 else 1
        for col_idx in range(num_columns):
            tokens = data[:, col_idx] if data.ndim == 2 else data
            # np.unique yields one (value, count) pair per distinct token,
            # far cheaper than updating the Counter element-by-element.
            unique, counts = np.unique(tokens, return_counts=True)
            for token, count in zip(unique, counts):
                # Cast to plain int so keys/values are JSON-serializable.
                column_counters[col_idx][int(token)] += int(count)

    return dict(column_counters)


def generate_report(column_counters, output_file=None):
    """Render the per-column statistics as a text report (plus a JSON dump).

    Args:
        column_counters: Mapping ``{column_index: Counter({token: count})}``.
        output_file: Optional path for the text report.  When given, a sibling
            ``.json`` file with the raw counts is written next to it; parent
            directories are created as needed.
    """
    report_lines = []
    report_lines.append("=" * 80)
    report_lines.append("OCTUPLE分词结果统计分析报告")
    report_lines.append("=" * 80)
    report_lines.append("")

    # Overall totals across every column.
    total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
    report_lines.append(f"总token数: {total_tokens:,}")
    report_lines.append(f"分析的列数: {len(column_counters)}")
    report_lines.append("")

    # Detailed per-column statistics.
    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"

        report_lines.append("-" * 80)
        report_lines.append(f"列 {col_idx}: {col_name}")
        report_lines.append("-" * 80)

        total_in_column = sum(counter.values())
        unique_tokens = len(counter)
        min_token = min(counter.keys()) if counter else 0
        max_token = max(counter.keys()) if counter else 0

        report_lines.append(f"  总token数: {total_in_column:,}")
        report_lines.append(f"  唯一token数: {unique_tokens:,}")
        report_lines.append(f"  Token值范围: [{min_token}, {max_token}]")
        report_lines.append(
            f"  平均出现次数: {total_in_column / unique_tokens:.2f}"
            if unique_tokens > 0 else "  平均出现次数: N/A"
        )
        report_lines.append("")

        # Top 20 most common tokens.
        report_lines.append(f"  Top 20 最常见的token:")
        for rank, (token, count) in enumerate(counter.most_common(20), 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
        report_lines.append("")

        # Bottom 20 tokens (every counted token occurs at least once).
        report_lines.append(f"  Top 20 最不常见的token (出现次数>0):")
        bottom_tokens = counter.most_common()[-20:]
        bottom_tokens.reverse()
        for rank, (token, count) in enumerate(bottom_tokens, 1):
            percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
            report_lines.append(f"    {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
        report_lines.append("")

        # Distribution of the per-token occurrence counts.
        counts_list = list(counter.values())
        if counts_list:
            report_lines.append(f"  分布统计:")
            report_lines.append(f"    最小出现次数: {min(counts_list):,}")
            report_lines.append(f"    最大出现次数: {max(counts_list):,}")
            report_lines.append(f"    中位数出现次数: {np.median(counts_list):,.0f}")
            report_lines.append(f"    标准差: {np.std(counts_list):,.2f}")
            report_lines.append("")

    # Cross-column share of the total token count.
    report_lines.append("=" * 80)
    report_lines.append("跨列分析")
    report_lines.append("=" * 80)
    report_lines.append("")
    for col_idx in sorted(column_counters.keys()):
        counter = column_counters[col_idx]
        col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
        total_in_column = sum(counter.values())
        percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
        report_lines.append(f"  {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
    report_lines.append("")
    report_lines.append("=" * 80)
    report_lines.append("报告生成完成")
    report_lines.append("=" * 80)

    # Print the report, then optionally persist it.
    report_text = "\n".join(report_lines)
    print("\n" + report_text)

    if output_file:
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(report_text)
        print(f"\n报告已保存到: {output_path}")

        # Also dump the raw counts as JSON next to the text report.
        json_output = output_path.with_suffix('.json')
        json_data = {
            'summary': {
                'total_tokens': total_tokens,
                'num_columns': len(column_counters)
            },
            'columns': {}
        }
        for col_idx in sorted(column_counters.keys()):
            counter = column_counters[col_idx]
            col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
            json_data['columns'][col_name] = {
                'total_tokens': sum(counter.values()),
                'unique_tokens': len(counter),
                'token_counts': dict(counter),
                'top_20': dict(counter.most_common(20)),
                'bottom_20': dict(counter.most_common()[-20:])
            }
        with open(json_output, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2, ensure_ascii=False)
        print(f"详细数据已保存到: {json_output}")


def main():
    """Entry point: load data, count tokens, and write the analysis report."""
    # Default data directory; override by passing a path as the first CLI arg.
    data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
    if len(sys.argv) > 1:
        data_dir = sys.argv[1]

    all_data = load_octuple_data(data_dir)
    if not all_data:
        print("错误: 未加载到任何数据")
        return

    # Sanity-check the shape/dtype of the first loaded array.
    sample = all_data[0]
    print(f"\n数据格式检查:")
    print(f"  样本形状: {sample.shape}")
    print(f"  样本数据类型: {sample.dtype}")
    print(f"  列数: {sample.shape[1] if sample.ndim == 2 else 1}")
    print()

    column_counters = count_tokens_by_column(all_data)

    output_file = "octuple_token_analysis_report_part.txt"
    generate_report(column_counters, output_file)


if __name__ == "__main__":
    main()