#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script based on token-distribution distance.

Reads octuple_token_analysis_report.json, computes the distance between each
piece of data and the overall distribution, then samples with distance-based
weights: the farther a piece is from the overall distribution, the more likely
it is to be sampled.
"""
import json
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy, wasserstein_distance
from tqdm import tqdm

# Octuple column definitions
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig",   # 7: TimeSignature
]


def load_distribution_from_json(json_path):
    """
    Load the overall token distribution from the JSON report.

    Args:
        json_path: path to the JSON file.

    Returns:
        dict: {column_name: {token: probability}}
    """
    print(f"Reading distribution file: {json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        report = json.load(f)

    distributions = {}
    columns = report.get('columns', {})

    for col_name in COLUMN_NAMES:
        if col_name not in columns:
            print(f"Warning: column {col_name} is not in the report")
            distributions[col_name] = {}
            continue

        col_data = columns[col_name]
        token_counts = col_data.get('token_counts', {})
        total_tokens = col_data.get('total_tokens', 1)

        # Convert counts to a probability distribution
        distribution = {}
        for token_str, count in token_counts.items():
            token = int(token_str)
            distribution[token] = count / total_tokens

        distributions[col_name] = distribution
        print(f"  Column {col_name}: {len(distribution)} unique tokens, total tokens: {total_tokens:,}")

    return distributions


def compute_data_distribution(data, col_idx):
    """
    Compute the token distribution of a single piece of data for one column.

    Args:
        data: numpy array of shape (num_tokens, num_columns).
        col_idx: column index.

    Returns:
        dict: {token: probability}
    """
    if data.size == 0:
        return {}

    tokens = data[:, col_idx]
    unique, counts = np.unique(tokens, return_counts=True)
    total = len(tokens)

    distribution = {}
    for token, count in zip(unique, counts):
        distribution[int(token)] = count / total

    return distribution


def compute_emd_distance(dist1, dist2):
    """
    Compute the Earth Mover's Distance (1-Wasserstein distance) between two distributions.

    Args:
        dist1: distribution 1, dict {token: probability}, already normalized.
        dist2: distribution 2, dict {token: probability}, already normalized.

    Returns:
        float: EMD.
    """
    # Union of all tokens, sorted
    all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
    if not all_tokens:
        return 0.0

    # Build probability vectors and the token-value vector
    p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
    q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
    token_values = np.array(all_tokens, dtype=float)

    # Renormalize (guards against numerical error)
    p_sum = p_weights.sum()
    q_sum = q_weights.sum()
    if p_sum < 1e-10 or q_sum < 1e-10:
        return 0.0
    p_weights = p_weights / p_sum
    q_weights = q_weights / q_sum

    # scipy's wasserstein_distance takes the sample positions and weights of both
    # distributions; for a discrete distribution the token values act as positions.
    emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
    return emd


def compute_distribution_distance(dist1, dist2, method='emd'):
    """
    Compute the distance between two distributions.

    Args:
        dist1: distribution 1, dict {token: probability}.
        dist2: distribution 2, dict {token: probability}.
        method: distance method, 'emd' (Earth Mover's Distance),
            'js' (Jensen-Shannon) or 'kl' (KL divergence).

    Returns:
        float: distribution distance.
    """
    if method == 'emd':
        return compute_emd_distance(dist1, dist2)

    # Union of all tokens
    all_tokens = set(dist1.keys()) | set(dist2.keys())
    if not all_tokens:
        return 0.0

    # Build probability vectors
    p = np.array([dist1.get(token, 0.0) for token in all_tokens])
    q = np.array([dist2.get(token, 0.0) for token in all_tokens])

    # Renormalize (guards against numerical error)
    p = p / (p.sum() + 1e-10)
    q = q / (q.sum() + 1e-10)

    if method == 'js':
        # Jensen-Shannon distance (symmetric; the square root of the JS divergence)
        return jensenshannon(p, q)
    elif method == 'kl':
        # KL divergence (asymmetric; zero values must be handled)
        # Add a small smoothing term to avoid log(0)
        p = p + 1e-10
        q = q + 1e-10
        p = p / p.sum()
        q = q / q.sum()
        return entropy(p, q)
    else:
        raise ValueError(f"Unknown distance method: {method}")


def extract_subdistribution(global_dist, data_tokens):
    """
    Restrict the global distribution to the tokens that appear in the data and normalize.

    Args:
        global_dist: global distribution, dict {token: probability}.
        data_tokens: set or list of tokens appearing in the data.

    Returns:
        dict: normalized sub-distribution {token: probability}.
    """
    if not data_tokens or not global_dist:
        return {}

    # Extract the sub-distribution
    sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}

    # Normalize
    total = sum(sub_dist.values())
    if total < 1e-10:
        return {}
    normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
    return normalized_sub_dist


def compute_data_distance(data, global_distributions, method='emd'):
    """
    Compute the distance between a single piece of data and the overall distribution.

    For each song, the dataset distribution is restricted to the tokens that occur
    in that song; both distributions are normalized and the Earth Mover's Distance
    (or the chosen method) between them is computed.

    Args:
        data: numpy array (num_tokens, num_columns), or a file path (lazy loading).
        global_distributions: overall distribution, dict {column_name: {token: probability}}.
        method: distance method, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon)
            or 'kl' (KL divergence).

    Returns:
        float: mean distance across all columns.
    """
    # If data is a path, load it first
    if isinstance(data, (str, Path)):
        try:
            data = np.load(data)['arr_0']
        except Exception as e:
            # Do not print here; let the caller handle the error
            raise RuntimeError(f"Error loading file {data}: {e}")

    distances = []
    for col_idx, col_name in enumerate(COLUMN_NAMES):
        # Distribution of this piece of data for this column
        data_dist = compute_data_distribution(data, col_idx)

        # Overall distribution for this column
        global_dist = global_distributions.get(col_name, {})

        if not data_dist or not global_dist:
            continue

        # Restrict the global distribution to the tokens occurring in the data
        data_tokens = set(data_dist.keys())
        sub_global_dist = extract_subdistribution(global_dist, data_tokens)
        if not sub_global_dist:
            continue

        # Normalize the data distribution
        data_dist_sum = sum(data_dist.values())
        if data_dist_sum < 1e-10:
            continue
        normalized_data_dist = {token: prob / data_dist_sum
                                for token, prob in data_dist.items()}

        # Compute the distance (both distributions are normalized)
        dist = compute_distribution_distance(normalized_data_dist, sub_global_dist,
                                             method=method)
        distances.append(dist)

    # Mean distance across columns
    return np.mean(distances) if distances else 0.0


def _load_single_file(npz_file):
    """
    Helper for loading a single .npz file (used by the thread pool).

    Args:
        npz_file: path to the .npz file.

    Returns:
        tuple: (data, file_path), or None if loading failed or the array is not 2-D.
    """
    try:
        data = np.load(npz_file)['arr_0']
        if data.ndim == 2:
            return (data, npz_file)
        if data.ndim == 1:
            print(f"Warning: {npz_file} is a 1-D array, skipping")
        return None
    except Exception as e:
        print(f"Error: failed to load {npz_file}: {e}")
        return None


def get_data_file_paths(data_dir):
    """
    Collect all data file paths without loading the data.

    Args:
        data_dir: data directory.

    Returns:
        list: list of file paths.
    """
    data_dir = Path(data_dir)
    npz_files = []
    if data_dir.exists():
        npz_files = sorted(list(data_dir.rglob("*.npz")))

    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return []

    print(f"Found {len(npz_files)} .npz files")
    return npz_files


def load_data_with_paths(data_dir, num_threads=None, lazy=False):
    """
    Load all data and return the list of file paths (multi-threaded).

    Args:
        data_dir: data directory.
        num_threads: number of threads; None means use the CPU core count.
        lazy: if True, only return file paths without loading the data.

    Returns:
        tuple: (data_list, file_paths_list), or (None, file_paths_list) if lazy=True.
    """
    data_dir = Path(data_dir)
    npz_files = []
    if data_dir.exists():
        npz_files = sorted(list(data_dir.rglob("*.npz")))

    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return [], []

    if lazy:
        print(f"Found {len(npz_files)} .npz files (lazy-loading mode)")
        return None, npz_files

    print(f"Found {len(npz_files)} .npz files, loading...")

    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(npz_files))

    all_data = []
    file_paths = []

    # Load files with a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {executor.submit(_load_single_file, npz_file): npz_file
                   for npz_file in npz_files}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading data"):
            result = future.result()
            if result is not None:
                data, file_path = result
                all_data.append(data)
                file_paths.append(file_path)

    # Restore the original (sorted) order
    if file_paths:
        sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
        file_paths, all_data = zip(*sorted_pairs)
        file_paths = list(file_paths)
        all_data = list(all_data)

    return all_data, file_paths


def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
    """
    Weighted resampling based on distance.

    Args:
        file_paths: list of file paths.
        distances: list of distances.
        sample_ratio: sampling ratio.
        method: distance method (used to decide the weighting direction; not used at the moment).
        lazy: if True, return file paths instead of loaded data.

    Returns:
        tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
    """
    n_samples = int(len(file_paths) * sample_ratio)
    print(f"\nSampling {n_samples} pieces of data ({sample_ratio*100:.1f}% of the total)")

    # Convert distances to weights: the larger the distance, the larger the weight
    distances = np.array(distances)

    # Guard against zero distances and outliers
    min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
    distances = np.maximum(distances, min_dist * 0.1)

    # Normalize distances to [0, 1], then convert to weights.
    # The exponential amplifies the differences so that distant pieces are more likely to be sampled.
    normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
    weights = np.exp(normalized_distances * 3)

    # Normalize weights
    weights = weights / weights.sum()

    # Weighted random sampling without replacement
    indices = np.arange(len(file_paths))
    sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)

    sampled_paths = [file_paths[i] for i in sampled_indices]

    # If lazy=True, return the paths; otherwise load the data now
    if lazy:
        sampled_data = sampled_paths  # return paths, load later
    else:
        # Load the sampled data
        sampled_data = []
        for path in tqdm(sampled_paths, desc="Loading sampled data"):
            try:
                data = np.load(path)['arr_0']
                sampled_data.append(data)
            except Exception as e:
                print(f"Error: failed to load {path}: {e}")
                sampled_data.append(None)

    print("Sampling finished:")
    print(f"  Number of sampled pieces: {len(sampled_paths)}")
    print(f"  Mean distance: {distances[sampled_indices].mean():.6f}")
    print(f"  Min distance: {distances[sampled_indices].min():.6f}")
    print(f"  Max distance: {distances[sampled_indices].max():.6f}")

    return sampled_data, sampled_paths, sampled_indices


def _save_single_file(args_tuple):
    """
    Helper for saving a single file (used by the thread pool).
    Supports lazy loading: if data is a path, the array is loaded from that file first.

    Args:
        args_tuple: (data, original_path, output_dir)

    Returns:
        tuple: (True, original_path) on success, or (False, original_path, error_msg).
    """
    data, original_path, output_dir = args_tuple
    try:
        # If data is a path, load it
        if isinstance(data, (str, Path)):
            data = np.load(data)['arr_0']

        # Preserve the relative path structure (drop the first two path components)
        relative_path = original_path.relative_to(
            original_path.parents[len(original_path.parts) - 3])
        output_path = output_dir / relative_path
        output_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(output_path, data)
        return (True, original_path)
    except Exception as e:
        error_msg = str(e)
        print(f"Error: failed to save {original_path}: {error_msg}")
        return (False, original_path, error_msg)


def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
    """
    Save the sampled data (multi-threaded).

    Args:
        sampled_data: list of sampled data (arrays or paths).
        sampled_paths: list of sampled file paths.
        output_dir: output directory.
        num_threads: number of threads; None means use the CPU core count.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving sampled data to: {output_dir}")

    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(sampled_data))

    # Prepare arguments
    save_args = [(data, original_path, output_dir)
                 for data, original_path in zip(sampled_data, sampled_paths)]

    # Save files with a thread pool
    success_count = 0
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = [executor.submit(_save_single_file, args) for args in save_args]

        # Collect results
        for future in tqdm(as_completed(futures), total=len(futures), desc="Saving data"):
            try:
                result = future.result(timeout=300)  # timeout to avoid hanging
                if isinstance(result, tuple) and len(result) >= 2:
                    success = result[0]
                    if success:
                        success_count += 1
            except Exception as e:
                print(f"Error: failed to retrieve a save result: {e}")

    print(f"Saving finished: {success_count}/{len(sampled_data)} files saved")


def main():
    """Entry point."""
    import argparse
    parser = argparse.ArgumentParser(description="Resampling based on token-distribution distance")
    parser.add_argument("--json_path", type=str,
                        default="octuple_token_analysis_report.json",
                        help="Path to the token analysis report JSON file")
    parser.add_argument("--data_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
                        help="Data directory")
    parser.add_argument("--output_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
                        help="Output directory")
    parser.add_argument("--sample_ratio", type=float, default=0.3,
                        help="Sampling ratio (default: 0.3)")
    parser.add_argument("--distance_method", type=str, default="emd",
                        choices=["emd", "js", "kl"],
                        help="Distance method: 'emd' (Earth Mover's Distance), "
                             "'js' (Jensen-Shannon) or 'kl' (KL divergence)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--num_threads", type=int, default=None,
                        help="Number of threads; None (the default) uses the CPU core count")

    args = parser.parse_args()

    # Set the random seed
    np.random.seed(args.seed)

    # 1. Load the overall distribution
    global_distributions = load_distribution_from_json(args.json_path)

    # 2. Collect all data file paths (lazy mode, to avoid loading everything at once)
    _, file_paths = load_data_with_paths(args.data_dir, lazy=True)
    if not file_paths:
        print("Error: no data files found")
        return

    print(f"\nFound {len(file_paths)} data files in total")

    # 3. Compute the distance of each piece of data to the overall distribution
    #    (multi-threaded, lazy loading)
    print("\nComputing the distance of each piece of data to the overall distribution (lazy mode)...")

    def _compute_distance_wrapper(args_tuple):
        """Wrapper for the distance computation (multi-threaded, lazy loading)."""
        idx, file_path, global_dists, method = args_tuple
        try:
            distance = compute_data_distance(file_path, global_dists, method=method)
            return (idx, distance, None)
        except Exception as e:
            return (idx, 0.0, str(e))

    if args.num_threads is None:
        num_threads = min(mp.cpu_count(), len(file_paths))
    else:
        num_threads = args.num_threads

    # Prepare arguments (pass file paths instead of data, and include the index)
    distance_args = [(i, file_path, global_distributions, args.distance_method)
                     for i, file_path in enumerate(file_paths)]

    # Compute distances with a thread pool (data is loaded on demand).
    # Initialize the result list to preserve order.
    distances = [0.0] * len(file_paths)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = [executor.submit(_compute_distance_wrapper, args) for args in distance_args]

        # Collect results, showing progress with tqdm
        for future in tqdm(as_completed(futures), total=len(futures), desc="Computing distances"):
            try:
                idx, distance, error = future.result(timeout=300)  # timeout to avoid hanging
                distances[idx] = distance
                if error:
                    print(f"Warning: error while computing distance (index {idx}): {error}")
            except Exception as e:
                print(f"Error: failed to retrieve a result: {e}")
                # If the result cannot be retrieved, keep the default value 0.0

    distances = np.array(distances)

    print("\nDistance statistics:")
    print(f"  Mean distance: {distances.mean():.6f}")
    print(f"  Min distance: {distances.min():.6f}")
    print(f"  Max distance: {distances.max():.6f}")
    print(f"  Std: {distances.std():.6f}")

    # 4. Weighted sampling based on distance (lazy mode)
    sampled_data, sampled_paths, sampled_indices = weighted_resample(
        file_paths,
        distances,
        sample_ratio=args.sample_ratio,
        method=args.distance_method,
        lazy=True  # lazy loading, avoids loading the data twice
    )

    # 5. Save the sampled data (multi-threaded, lazy loading)
    save_sampled_data(sampled_data, sampled_paths, args.output_dir,
                      num_threads=args.num_threads)
    # 6. Save the sampled indices (optional, for later analysis)
    indices_file = Path(args.output_dir) / "sampled_indices.npy"
    np.save(indices_file, sampled_indices)
    print(f"\nSampled indices saved to: {indices_file}")

    # Save sampling info
    info = {
        "total_samples": len(file_paths),
        "sampled_samples": len(sampled_data),
        "sample_ratio": args.sample_ratio,
        "distance_method": args.distance_method,
        "distance_stats": {
            "mean": float(distances.mean()),
            "min": float(distances.min()),
            "max": float(distances.max()),
            "std": float(distances.std())
        },
        "sampled_distance_stats": {
            "mean": float(distances[sampled_indices].mean()),
            "min": float(distances[sampled_indices].min()),
            "max": float(distances[sampled_indices].max()),
            "std": float(distances[sampled_indices].std())
        }
    }

    info_file = Path(args.output_dir) / "resample_info.json"
    with open(info_file, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    print(f"Sampling info saved to: {info_file}")


if __name__ == "__main__":
    main()
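
# ---------------------------------------------------------------------------
# Usage sketch (not part of the pipeline). The script file name below is an
# assumption; the flags and default paths are the ones defined in main():
#
#   python resample_by_distribution_distance.py \
#       --json_path octuple_token_analysis_report.json \
#       --data_dir dataset/represented_data/tuneidx/tuneidx_msmidi/oct8 \
#       --output_dir dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled \
#       --sample_ratio 0.3 --distance_method emd --seed 42
#
# The distance helpers can also be called directly; for two toy pitch
# distributions, moving 0.25 of the probability mass across a pitch gap of 4
# gives an EMD of 1.0:
#
#   compute_emd_distance({60: 0.5, 64: 0.5}, {60: 0.25, 64: 0.75})  # -> 1.0
# ---------------------------------------------------------------------------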