#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script based on token-distribution distance.

Reads octuple_token_analysis_report.json, computes each sample's distance from
the overall token distribution, and samples with weights proportional to that
distance: the farther a sample is from the corpus distribution, the more likely
it is to be selected.
"""

import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from scipy.stats import entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp

# Column name definitions for Octuple tokens
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig"    # 7: TimeSignature
]


def load_distribution_from_json(json_path):
    """
    Load the overall token distribution from the JSON report.

    Args:
        json_path: path to the JSON file
    Returns:
        dict: {column_name: {token: probability}}
    """
    print(f"Reading distribution file: {json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        report = json.load(f)
    distributions = {}
    columns = report.get('columns', {})
    for col_name in COLUMN_NAMES:
        if col_name not in columns:
            print(f"Warning: column {col_name} is missing from the report")
            distributions[col_name] = {}
            continue
        col_data = columns[col_name]
        token_counts = col_data.get('token_counts', {})
        total_tokens = col_data.get('total_tokens', 1)
        # Convert counts to a probability distribution
        distribution = {}
        for token_str, count in token_counts.items():
            token = int(token_str)
            distribution[token] = count / total_tokens
        distributions[col_name] = distribution
        print(f"{col_name}: {len(distribution)} unique tokens, total tokens: {total_tokens:,}")
    return distributions


def compute_data_distribution(data, col_idx):
    """
    Compute the token distribution of a single sample for the given column.

    Args:
        data: numpy array of shape (num_tokens, num_columns)
        col_idx: column index
    Returns:
        dict: {token: probability}
    """
    if data.size == 0:
        return {}
    tokens = data[:, col_idx]
    unique, counts = np.unique(tokens, return_counts=True)
    total = len(tokens)
    distribution = {}
    for token, count in zip(unique, counts):
        distribution[int(token)] = count / total
    return distribution


def compute_emd_distance(dist1, dist2):
    """
    Compute the Earth Mover's Distance (Wasserstein distance) between two distributions.

    Args:
        dist1: distribution 1, dict {token: probability}, normalized
        dist2: distribution 2, dict {token: probability}, normalized
    Returns:
        float: EMD
    """
    # Union of all tokens, sorted
    all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
    if not all_tokens:
        return 0.0
    # Build the probability vectors and the vector of token values
    p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
    q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
    token_values = np.array(all_tokens, dtype=float)
    # Renormalize (guards against numerical error)
    p_sum = p_weights.sum()
    q_sum = q_weights.sum()
    if p_sum < 1e-10 or q_sum < 1e-10:
        return 0.0
    p_weights = p_weights / p_sum
    q_weights = q_weights / q_sum
    # The 1-Wasserstein distance is the Earth Mover's Distance.
    # wasserstein_distance takes sample values (positions) and weights for both
    # distributions; for a discrete distribution the token values serve as positions.
    emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
    return emd
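
# Sanity check (illustrative, not executed): for two point masses at tokens 0 and 1,
#   compute_emd_distance({0: 1.0}, {1: 1.0})  ->  1.0
# i.e. all probability mass has to move a distance of 1 in token-value space.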


def compute_distribution_distance(dist1, dist2, method='emd'):
    """
    Compute the distance between two distributions.

    Args:
        dist1: distribution 1, dict {token: probability}
        dist2: distribution 2, dict {token: probability}
        method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon)
                or 'kl' (KL divergence)
    Returns:
        float: distribution distance
    """
    if method == 'emd':
        return compute_emd_distance(dist1, dist2)
    # Union of all tokens
    all_tokens = set(dist1.keys()) | set(dist2.keys())
    if not all_tokens:
        return 0.0
    # Build the probability vectors
    p = np.array([dist1.get(token, 0.0) for token in all_tokens])
    q = np.array([dist2.get(token, 0.0) for token in all_tokens])
    # Renormalize (guards against numerical error)
    p = p / (p.sum() + 1e-10)
    q = q / (q.sum() + 1e-10)
    if method == 'js':
        # Jensen-Shannon distance: symmetric; SciPy returns the square root of the
        # JS divergence (bounded by sqrt(ln 2) with the default natural-log base)
        return jensenshannon(p, q)
    elif method == 'kl':
        # KL divergence: asymmetric; smooth with a small constant to avoid log(0)
        p = p + 1e-10
        q = q + 1e-10
        p = p / p.sum()
        q = q / q.sum()
        return entropy(p, q)
    else:
        raise ValueError(f"Unknown distance method: {method}")


def extract_subdistribution(global_dist, data_tokens):
    """
    Extract from the global distribution the sub-distribution restricted to the
    tokens that appear in the sample, and normalize it.

    Args:
        global_dist: global distribution, dict {token: probability}
        data_tokens: set (or list) of tokens that appear in the sample
    Returns:
        dict: normalized sub-distribution {token: probability}
    """
    if not data_tokens or not global_dist:
        return {}
    # Restrict the global distribution to the sample's tokens
    sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
    # Normalize
    total = sum(sub_dist.values())
    if total < 1e-10:
        return {}
    normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
    return normalized_sub_dist
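
# Worked example (illustrative): with global_dist = {1: 0.5, 2: 0.3, 3: 0.2} and
# data_tokens = {1, 3}, the restricted mass is {1: 0.5, 3: 0.2} (total 0.7), so the
# normalized sub-distribution is roughly {1: 0.714, 3: 0.286}.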


def compute_data_distance(data, global_distributions, method='emd'):
    """
    Compute the distance between a single sample and the overall distribution.

    For each song, take the sub-distribution of the corpus restricted to the tokens
    that the song actually contains, normalize both distributions, and compute the
    Earth Mover's Distance (or the chosen metric) between them.

    Args:
        data: numpy array of shape (num_tokens, num_columns), or a file path
              (when loading lazily)
        global_distributions: overall distribution, dict {column_name: {token: probability}}
        method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon)
                or 'kl' (KL divergence)
    Returns:
        float: average distance across all columns
    """
    # If data is a path, load it first
    if isinstance(data, (str, Path)):
        try:
            data = np.load(data)['arr_0']
        except Exception as e:
            # Do not print here; let the caller handle the error
            raise RuntimeError(f"Error loading file {data}: {e}")
    distances = []
    for col_idx, col_name in enumerate(COLUMN_NAMES):
        # Distribution of this sample for the current column
        data_dist = compute_data_distribution(data, col_idx)
        # Overall distribution for the current column
        global_dist = global_distributions.get(col_name, {})
        if not data_dist or not global_dist:
            continue
        # Sub-distribution of the corpus restricted to the tokens present in the sample
        data_tokens = set(data_dist.keys())
        sub_global_dist = extract_subdistribution(global_dist, data_tokens)
        if not sub_global_dist:
            continue
        # Normalize the sample distribution
        data_dist_sum = sum(data_dist.values())
        if data_dist_sum < 1e-10:
            continue
        normalized_data_dist = {token: prob / data_dist_sum
                                for token, prob in data_dist.items()}
        # Compute the distance (both distributions are normalized)
        dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
        distances.append(dist)
    # Average over columns
    return np.mean(distances) if distances else 0.0


def _load_single_file(npz_file):
    """
    Helper that loads a single npz file (used by the thread pool).

    Args:
        npz_file: path to the npz file
    Returns:
        tuple: (data, file_path), or None if loading failed
    """
    try:
        data = np.load(npz_file)['arr_0']
        if data.ndim == 2:
            return (data, npz_file)
        elif data.ndim == 1:
            print(f"Warning: {npz_file} is a 1-D array, skipping")
            return None
    except Exception as e:
        print(f"Error: failed to load {npz_file}: {e}")
    return None


def get_data_file_paths(data_dir):
    """
    Collect all data file paths (without loading the data).

    Args:
        data_dir: path to the data directory
    Returns:
        list: list of file paths
    """
    data_dir = Path(data_dir)
    npz_files = []
    if data_dir.exists():
        npz_files = sorted(list(data_dir.rglob("*.npz")))
    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return []
    print(f"Found {len(npz_files)} .npz files")
    return npz_files


def load_data_with_paths(data_dir, num_threads=None, lazy=False):
    """
    Load all data and return the list of file paths (multi-threaded).

    Args:
        data_dir: path to the data directory
        num_threads: number of threads; None means use the CPU core count
        lazy: if True, only return file paths without loading the data
    Returns:
        tuple: (data_list, file_paths_list), or (None, file_paths_list) when lazy=True
    """
    data_dir = Path(data_dir)
    npz_files = []
    if data_dir.exists():
        npz_files = sorted(list(data_dir.rglob("*.npz")))
    if not npz_files:
        print(f"Warning: no .npz files found in {data_dir}")
        return [], []
    if lazy:
        print(f"Found {len(npz_files)} .npz files (lazy-loading mode)")
        return None, npz_files
    print(f"Found {len(npz_files)} .npz files, loading...")
    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(npz_files))
    all_data = []
    file_paths = []
    # Load the files with a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {executor.submit(_load_single_file, npz_file): npz_file
                   for npz_file in npz_files}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading data"):
            result = future.result()
            if result is not None:
                data, file_path = result
                all_data.append(data)
                file_paths.append(file_path)
    # Restore a deterministic order (as_completed yields results out of order)
    if file_paths:
        sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
        file_paths, all_data = zip(*sorted_pairs)
        file_paths = list(file_paths)
        all_data = list(all_data)
    return all_data, file_paths


def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
    """
    Weighted resampling based on distance.

    Args:
        file_paths: list of file paths
        distances: list of distances
        sample_ratio: fraction of the dataset to sample
        method: distance metric (used to determine the weighting direction)
        lazy: if True, return file paths instead of loaded data
    Returns:
        tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
    """
    n_samples = int(len(file_paths) * sample_ratio)
    print(f"\nSampling {n_samples} items ({sample_ratio*100:.1f}% of the total)")
    # Convert distances to weights: the larger the distance, the larger the weight
    distances = np.array(distances)
    # Guard against zero distances and outliers
    min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
    distances = np.maximum(distances, min_dist * 0.1)
    # Normalize distances to [0, 1], then turn them into weights.
    # The exponential amplifies the differences: the farthest sample gets exp(3)
    # (about 20x) the unnormalized weight of the closest one.
    normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
    weights = np.exp(normalized_distances * 3)
    # Normalize the weights
    weights = weights / weights.sum()
    # Weighted random sampling without replacement
    indices = np.arange(len(file_paths))
    sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
    sampled_paths = [file_paths[i] for i in sampled_indices]
    # If lazy=True, return paths; otherwise load the sampled data
    if lazy:
        sampled_data = sampled_paths  # return paths, load later
    else:
        sampled_data = []
        for path in tqdm(sampled_paths, desc="Loading sampled data"):
            try:
                data = np.load(path)['arr_0']
                sampled_data.append(data)
            except Exception as e:
                print(f"Error: failed to load {path}: {e}")
                sampled_data.append(None)
    print("Sampling finished:")
    print(f"  Number of sampled items: {len(sampled_paths)}")
    print(f"  Mean distance: {distances[sampled_indices].mean():.6f}")
    print(f"  Min distance: {distances[sampled_indices].min():.6f}")
    print(f"  Max distance: {distances[sampled_indices].max():.6f}")
    return sampled_data, sampled_paths, sampled_indices


def _save_single_file(args_tuple):
    """
    Helper that saves a single file (used by the thread pool).
    Supports lazy loading: if data is a path, the array is loaded from that file first.

    Args:
        args_tuple: (data, original_path, output_dir)
    Returns:
        tuple: (success, original_path), or (False, original_path, error_msg)
    """
    data, original_path, output_dir = args_tuple
    try:
        # If data is a path, load it first
        if isinstance(data, (str, Path)):
            data = np.load(data)['arr_0']
        # Preserve the relative directory structure: drop the first two components
        # of the original path and re-root the remainder under output_dir
        relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
        output_path = output_dir / relative_path
        output_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(output_path, data)
        return (True, original_path)
    except Exception as e:
        error_msg = str(e)
        print(f"Error: failed to save {original_path}: {error_msg}")
        return (False, original_path, error_msg)


def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
    """
    Save the sampled data (multi-threaded).

    Args:
        sampled_data: list of sampled data (arrays, or paths when lazy)
        sampled_paths: list of sampled file paths
        output_dir: output directory
        num_threads: number of threads; None means use the CPU core count
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving sampled data to: {output_dir}")
    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(sampled_data))
    # Prepare the arguments
    save_args = [(data, original_path, output_dir)
                 for data, original_path in zip(sampled_data, sampled_paths)]
    # Save the files with a thread pool
    success_count = 0
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = [executor.submit(_save_single_file, args)
                   for args in save_args]
        # Collect the results
        for future in tqdm(as_completed(futures), total=len(futures), desc="Saving data"):
            try:
                result = future.result(timeout=300)  # timeout guards against hangs
                if isinstance(result, tuple) and len(result) >= 2:
                    success = result[0]
                    if success:
                        success_count += 1
            except Exception as e:
                print(f"Error: failed to fetch save result: {e}")
    print(f"Saving finished: {success_count}/{len(sampled_data)} files saved")


def main():
    """Main entry point."""
    import argparse
    parser = argparse.ArgumentParser(description="Resampling based on token-distribution distance")
    parser.add_argument("--json_path", type=str,
                        default="octuple_token_analysis_report.json",
                        help="Path to the token analysis report JSON file")
    parser.add_argument("--data_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
                        help="Path to the data directory")
    parser.add_argument("--output_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
                        help="Path to the output directory")
    parser.add_argument("--sample_ratio", type=float, default=0.3,
                        help="Sampling ratio (default: 0.3)")
    parser.add_argument("--distance_method", type=str, default="emd",
                        choices=["emd", "js", "kl"],
                        help="Distance metric: 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--num_threads", type=int, default=None,
                        help="Number of threads; None means use the CPU core count (default: None)")
    args = parser.parse_args()
    # Set the random seed
    np.random.seed(args.seed)
    # 1. Load the overall distribution
    global_distributions = load_distribution_from_json(args.json_path)
    # 2. Collect all data file paths (lazy mode, so the whole dataset is never held in memory)
    _, file_paths = load_data_with_paths(args.data_dir, lazy=True)
    if not file_paths:
        print("Error: no data files found")
        return
    print(f"\nFound {len(file_paths)} data files in total")
    # 3. Compute each sample's distance from the overall distribution (multi-threaded, lazy loading)
    print("\nComputing per-sample distances from the overall distribution (lazy-loading mode)...")

    def _compute_distance_wrapper(args_tuple):
        """Wrapper around the distance computation (multi-threaded, supports lazy loading)."""
        idx, file_path, global_dists, method = args_tuple
        try:
            distance = compute_data_distance(file_path, global_dists, method=method)
            return (idx, distance, None)
        except Exception as e:
            return (idx, 0.0, str(e))

    if args.num_threads is None:
        num_threads = min(mp.cpu_count(), len(file_paths))
    else:
        num_threads = args.num_threads
    # Prepare the arguments (file paths instead of data, with indices to preserve order)
    distance_args = [(i, file_path, global_distributions, args.distance_method)
                     for i, file_path in enumerate(file_paths)]
    # Compute the distances with a thread pool (data is loaded on demand)
    # Initialize the result list so the original order is kept
    distances = [0.0] * len(file_paths)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = [executor.submit(_compute_distance_wrapper, args)
                   for args in distance_args]
        # Collect the results, with a tqdm progress bar
        for future in tqdm(as_completed(futures), total=len(futures), desc="Computing distances"):
            try:
                idx, distance, error = future.result(timeout=300)  # timeout guards against hangs
                distances[idx] = distance
                if error:
                    print(f"Warning: error while computing distance (index {idx}): {error}")
            except Exception as e:
                print(f"Error: failed to fetch result: {e}")
                # If the result cannot be fetched, the default value 0.0 is kept
    distances = np.array(distances)
    print("\nDistance statistics:")
    print(f"  Mean: {distances.mean():.6f}")
    print(f"  Min: {distances.min():.6f}")
    print(f"  Max: {distances.max():.6f}")
    print(f"  Std: {distances.std():.6f}")
    # 4. Weighted sampling based on distance (lazy-loading mode)
    sampled_data, sampled_paths, sampled_indices = weighted_resample(
        file_paths, distances,
        sample_ratio=args.sample_ratio,
        method=args.distance_method,
        lazy=True  # lazy loading avoids loading the data twice
    )
    # 5. Save the sampled data (multi-threaded, lazy loading)
    save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
    # 6. Save the sampled indices (optional, for later analysis)
    indices_file = Path(args.output_dir) / "sampled_indices.npy"
    np.save(indices_file, sampled_indices)
    print(f"\nSampled indices saved to: {indices_file}")
    # Save the sampling metadata
    info = {
        "total_samples": len(file_paths),
        "sampled_samples": len(sampled_data),
        "sample_ratio": args.sample_ratio,
        "distance_method": args.distance_method,
        "distance_stats": {
            "mean": float(distances.mean()),
            "min": float(distances.min()),
            "max": float(distances.max()),
            "std": float(distances.std())
        },
        "sampled_distance_stats": {
            "mean": float(distances[sampled_indices].mean()),
            "min": float(distances[sampled_indices].min()),
            "max": float(distances[sampled_indices].max()),
            "std": float(distances[sampled_indices].std())
        }
    }
    info_file = Path(args.output_dir) / "resample_info.json"
    with open(info_file, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    print(f"Sampling metadata saved to: {info_file}")


if __name__ == "__main__":
    main()