1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions


@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于token分布距离的重采样脚本
读取octuple_token_analysis_report.json计算每个数据与整体分布的距离
按照距离加权采样,距离越远的越容易被采样
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from scipy.stats import entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Column names of the Octuple token representation
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
def load_distribution_from_json(json_path):
"""
从JSON文件中加载整体token分布
Args:
json_path: JSON文件路径
Returns:
dict: {column_name: {token: probability}}
"""
print(f"读取分布文件: {json_path}")
with open(json_path, 'r', encoding='utf-8') as f:
report = json.load(f)
distributions = {}
columns = report.get('columns', {})
for col_name in COLUMN_NAMES:
if col_name not in columns:
print(f"警告: 列 {col_name} 不在报告中")
distributions[col_name] = {}
continue
col_data = columns[col_name]
token_counts = col_data.get('token_counts', {})
total_tokens = col_data.get('total_tokens', 1)
# Convert counts to a probability distribution
distribution = {}
for token_str, count in token_counts.items():
token = int(token_str)
distribution[token] = count / total_tokens
distributions[col_name] = distribution
print(f"{col_name}: {len(distribution)} 个唯一token, 总token数: {total_tokens:,}")
return distributions
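# A minimal sketch (not the full schema) of the report fields this script reads from
# octuple_token_analysis_report.json; the real report may contain additional keys:
# {
#   "columns": {
#     "pitch": {
#       "token_counts": {"60": 1234, "62": 987},   # token id (as string) -> occurrence count
#       "total_tokens": 2221
#     },
#     ...
#   }
# }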
def compute_data_distribution(data, col_idx):
"""
计算单个数据在指定列的token分布
Args:
data: numpy数组 (num_tokens, num_columns)
col_idx: 列索引
Returns:
dict: {token: probability}
"""
if data.size == 0:
return {}
tokens = data[:, col_idx]
unique, counts = np.unique(tokens, return_counts=True)
total = len(tokens)
distribution = {}
for token, count in zip(unique, counts):
distribution[int(token)] = count / total
return distribution
def compute_emd_distance(dist1, dist2):
"""
使用推土机距离Earth Mover's Distance / Wasserstein距离计算两个分布之间的距离
Args:
dist1: 分布1dict {token: probability},已归一化
dist2: 分布2dict {token: probability},已归一化
Returns:
float: EMD距离
"""
# 获取所有token的并集并排序
all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
if not all_tokens:
return 0.0
# Build the probability vectors and the vector of token values
p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
token_values = np.array(all_tokens, dtype=float)
# Re-normalize (guards against numerical error)
p_sum = p_weights.sum()
q_sum = q_weights.sum()
if p_sum < 1e-10 or q_sum < 1e-10:
return 0.0
p_weights = p_weights / p_sum
q_weights = q_weights / q_sum
# 1-Wasserstein distance, i.e. the Earth Mover's Distance
# wasserstein_distance takes the sample values (positions) and weights of both distributions;
# for these discrete distributions the token values serve as positions
emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
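# Worked example with hypothetical values: dist1 = {60: 1.0}, dist2 = {62: 1.0} give
# weights [1, 0] and [0, 1] over positions [60, 62], so the EMD is |60 - 62| = 2.0;
# the cost grows with how far probability mass must move along the token-value axis.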
return emd
def compute_distribution_distance(dist1, dist2, method='emd'):
"""
计算两个分布之间的距离
Args:
dist1: 分布1dict {token: probability}
dist2: 分布2dict {token: probability}
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
Returns:
float: 分布距离
"""
if method == 'emd':
return compute_emd_distance(dist1, dist2)
# Union of all tokens
all_tokens = set(dist1.keys()) | set(dist2.keys())
if not all_tokens:
return 0.0
# Build the probability vectors
p = np.array([dist1.get(token, 0.0) for token in all_tokens])
q = np.array([dist2.get(token, 0.0) for token in all_tokens])
# Normalize (guards against numerical error)
p = p / (p.sum() + 1e-10)
q = q / (q.sum() + 1e-10)
if method == 'js':
# Jensen-Shannon distance: symmetric (scipy's jensenshannon returns the square root of the JS divergence)
return jensenshannon(p, q)
elif method == 'kl':
# KL divergence: asymmetric, zero values must be handled
# Add a small smoothing term to avoid log(0)
p = p + 1e-10
q = q + 1e-10
p = p / p.sum()
q = q / q.sum()
return entropy(p, q)
else:
raise ValueError(f"未知的距离方法: {method}")
def extract_subdistribution(global_dist, data_tokens):
"""
从全局分布中提取只包含数据中出现的token的子分布并归一化
Args:
global_dist: 全局分布dict {token: probability}
data_tokens: 数据中出现的token集合set或list
Returns:
dict: 子分布dict {token: probability},已归一化
"""
if not data_tokens or not global_dist:
return {}
# Extract the sub-distribution
sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
# Normalize
total = sum(sub_dist.values())
if total < 1e-10:
return {}
normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
return normalized_sub_dist
def compute_data_distance(data, global_distributions, method='emd'):
"""
计算单个数据与整体分布的距离
对每首歌,从数据集分布中找出和这首歌的分布包含的数据相同的子分布,
都进行归一化然后计算推土机距离
Args:
data: numpy数组 (num_tokens, num_columns) 或文件路径(如果是延迟加载)
global_distributions: 整体分布dict {column_name: {token: probability}}
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
Returns:
float: 平均距离(跨所有列)
"""
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
try:
data = np.load(data)['arr_0']
except Exception as e:
# Do not print here; let the caller handle the error
raise RuntimeError(f"Failed to load file {data}: {e}")
distances = []
for col_idx, col_name in enumerate(COLUMN_NAMES):
# Distribution of this data for the current column
data_dist = compute_data_distribution(data, col_idx)
# Overall distribution for this column
global_dist = global_distributions.get(col_name, {})
if not data_dist or not global_dist:
continue
# Restrict the global distribution to the tokens that appear in the data
data_tokens = set(data_dist.keys())
sub_global_dist = extract_subdistribution(global_dist, data_tokens)
if not sub_global_dist:
continue
# Normalize the data distribution
data_dist_sum = sum(data_dist.values())
if data_dist_sum < 1e-10:
continue
normalized_data_dist = {token: prob / data_dist_sum
for token, prob in data_dist.items()}
# Compute the distance (both distributions are normalized)
dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
distances.append(dist)
# Return the mean distance across columns
return np.mean(distances) if distances else 0.0
def _load_single_file(npz_file):
"""
加载单个npz文件的辅助函数用于多线程
Args:
npz_file: npz文件路径
Returns:
tuple: (data, file_path) 或 None如果加载失败
"""
try:
data = np.load(npz_file)['arr_0']
if data.ndim == 2:
return (data, npz_file)
elif data.ndim == 1:
print(f"警告: {npz_file} 是一维数组,跳过")
return None
except Exception as e:
print(f"错误: 加载 {npz_file} 时出错: {e}")
return None
def get_data_file_paths(data_dir):
"""
获取所有数据文件路径(不加载数据)
Args:
data_dir: 数据目录路径
Returns:
list: 文件路径列表
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return []
print(f"找到 {len(npz_files)} 个.npz文件")
return npz_files
def load_data_with_paths(data_dir, num_threads=None, lazy=False):
"""
加载所有数据并返回数据路径列表(多线程版本)
Args:
data_dir: 数据目录路径
num_threads: 线程数None表示使用CPU核心数
lazy: 如果为True只返回文件路径不加载数据
Returns:
tuple: (data_list, file_paths_list) 或 (None, file_paths_list) 如果lazy=True
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"警告: 在 {data_dir} 中未找到.npz文件")
return [], []
if lazy:
print(f"找到 {len(npz_files)} 个.npz文件延迟加载模式")
return None, npz_files
print(f"找到 {len(npz_files)} 个.npz文件开始加载...")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(npz_files))
all_data = []
file_paths = []
# Load files with a thread pool
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(_load_single_file, npz_file): npz_file
for npz_file in npz_files}
for future in tqdm(as_completed(futures), total=len(futures), desc="Loading data"):
result = future.result()
if result is not None:
data, file_path = result
all_data.append(data)
file_paths.append(file_path)
# Restore a deterministic order (sorted by file path)
if file_paths:
sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
file_paths, all_data = zip(*sorted_pairs)
file_paths = list(file_paths)
all_data = list(all_data)
return all_data, file_paths
def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
"""
根据距离进行加权重采样
Args:
file_paths: 文件路径列表
distances: 距离列表
sample_ratio: 采样比例
method: 距离计算方法(用于确定权重方向)
lazy: 如果为True返回文件路径而不是数据
Returns:
tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
"""
n_samples = int(len(file_paths) * sample_ratio)
print(f"\n准备采样 {n_samples} 个数据 (占总数的 {sample_ratio*100:.1f}%)")
# Convert distances to weights:
# the larger the distance, the larger the weight
distances = np.array(distances)
# Guard against zero distances and outliers
min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
distances = np.maximum(distances, min_dist * 0.1)
# Normalize distances to [0, 1], then convert them to weights;
# an exponential amplifies the differences between near and far samples
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
weights = np.exp(normalized_distances * 3) # exponential scaling: distant items become much more likely to be drawn
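# For example, a normalized distance of 0.0 gets weight exp(0) = 1.0 while a normalized
# distance of 1.0 gets exp(3) ≈ 20.1, so (before weight normalization) the farthest item
# is roughly 20x more likely to be drawn than the closest one.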
# Normalize the weights
weights = weights / weights.sum()
# Weighted random sampling without replacement
indices = np.arange(len(file_paths))
sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
sampled_paths = [file_paths[i] for i in sampled_indices]
# If lazy=True, return paths; otherwise load the data
if lazy:
sampled_data = sampled_paths # return paths; loading is deferred
else:
# Load the sampled data
sampled_data = []
for path in tqdm(sampled_paths, desc="Loading sampled data"):
try:
data = np.load(path)['arr_0']
sampled_data.append(data)
except Exception as e:
print(f"错误: 加载 {path} 时出错: {e}")
sampled_data.append(None)
print(f"采样完成:")
print(f" 采样数据数量: {len(sampled_paths)}")
print(f" 平均距离: {distances[sampled_indices].mean():.6f}")
print(f" 最小距离: {distances[sampled_indices].min():.6f}")
print(f" 最大距离: {distances[sampled_indices].max():.6f}")
return sampled_data, sampled_paths, sampled_indices
def _save_single_file(args_tuple):
"""
保存单个文件的辅助函数(用于多线程)
支持延迟加载如果data是路径则从文件加载
Args:
args_tuple: (data, original_path, output_dir)
Returns:
tuple: (success, original_path) 或 (False, original_path, error_msg)
"""
data, original_path, output_dir = args_tuple
try:
# 如果data是路径则加载它
if isinstance(data, (str, Path)):
data = np.load(data)['arr_0']
# Preserve the relative directory structure under the output directory
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
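# With a hypothetical path 'dataset/oct8/genre/song.npz', parents[len(parts) - 3] is
# 'dataset/oct8', so relative_path becomes 'genre/song.npz': the first two path
# components are dropped and replaced by output_dir.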
output_path = output_dir / relative_path
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, data)
return (True, original_path)
except Exception as e:
error_msg = str(e)
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
"""
保存采样后的数据(多线程版本)
Args:
sampled_data: 采样后的数据列表
sampled_paths: 采样后的文件路径列表
output_dir: 输出目录
num_threads: 线程数None表示使用CPU核心数
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n保存采样数据到: {output_dir}")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(sampled_data))
# Prepare the worker arguments
save_args = [(data, original_path, output_dir)
for data, original_path in zip(sampled_data, sampled_paths)]
# Save files with a thread pool
success_count = 0
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = [executor.submit(_save_single_file, args)
for args in save_args]
# Collect results
for future in tqdm(as_completed(futures), total=len(futures), desc="Saving data"):
try:
result = future.result(timeout=300) # timeout so a stuck worker cannot hang the script
if isinstance(result, tuple) and len(result) >= 2:
success = result[0]
if success:
success_count += 1
except Exception as e:
print(f"错误: 获取保存结果时出错: {e}")
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description="Resampling based on token-distribution distance")
parser.add_argument("--json_path", type=str,
default="octuple_token_analysis_report.json",
help="token分析报告JSON文件路径")
parser.add_argument("--data_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
help="数据目录路径")
parser.add_argument("--output_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
help="输出目录路径")
parser.add_argument("--sample_ratio", type=float, default=0.3,
help="采样比例 (默认: 0.3)")
parser.add_argument("--distance_method", type=str, default="emd",
choices=["emd", "js", "kl"],
help="距离计算方法: 'emd' (推土机距离/EMD), 'js' (Jensen-Shannon) 或 'kl' (KL散度)")
parser.add_argument("--seed", type=int, default=42,
help="随机种子")
parser.add_argument("--num_threads", type=int, default=1,
help="线程数None表示使用CPU核心数 (默认: None)")
args = parser.parse_args()
# Set the random seed
np.random.seed(args.seed)
# 1. Load the overall distribution
global_distributions = load_distribution_from_json(args.json_path)
# 2. Collect all data file paths (lazy mode: avoid loading every file up front)
_, file_paths = load_data_with_paths(args.data_dir, lazy=True)
if not file_paths:
print("错误: 未找到任何数据文件")
return
print(f"\n共找到 {len(file_paths)} 个数据文件")
# 3. 计算每个数据与整体分布的距离(多线程版本,延迟加载)
print("\n计算每个数据与整体分布的距离(延迟加载模式)...")
def _compute_distance_wrapper(args_tuple):
"""计算距离的包装函数(用于多线程,支持延迟加载)"""
idx, file_path, global_dists, method = args_tuple
try:
distance = compute_data_distance(file_path, global_dists, method=method)
return (idx, distance, None)
except Exception as e:
return (idx, 0.0, str(e))
if args.num_threads is None:
num_threads = min(mp.cpu_count(), len(file_paths))
else:
num_threads = args.num_threads
# Prepare the arguments (file paths instead of loaded data, plus the index)
distance_args = [(i, file_path, global_distributions, args.distance_method)
for i, file_path in enumerate(file_paths)]
# Compute the distances with a thread pool (data is loaded on demand)
# Pre-fill the result list so the original order is preserved
distances = [0.0] * len(file_paths)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = [executor.submit(_compute_distance_wrapper, args)
for args in distance_args]
# Collect results, showing progress with tqdm
for future in tqdm(as_completed(futures), total=len(futures), desc="Computing distances"):
try:
idx, distance, error = future.result(timeout=300) # timeout so a stuck worker cannot hang the script
distances[idx] = distance
if error:
print(f"警告: 计算距离时出错 (索引 {idx}): {error}")
except Exception as e:
print(f"错误: 获取结果时出错: {e}")
# 如果无法获取结果,保持默认值 0.0
distances = np.array(distances)
print(f"\n距离统计:")
print(f" 平均距离: {distances.mean():.6f}")
print(f" 最小距离: {distances.min():.6f}")
print(f" 最大距离: {distances.max():.6f}")
print(f" 标准差: {distances.std():.6f}")
# 4. Distance-weighted sampling (lazy-loading mode)
sampled_data, sampled_paths, sampled_indices = weighted_resample(
file_paths, distances,
sample_ratio=args.sample_ratio,
method=args.distance_method,
lazy=True # lazy loading avoids reading the data a second time
)
# 5. Save the sampled data (multi-threaded, lazy loading)
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
# 6. Save the sampled indices (optional, for later analysis)
indices_file = Path(args.output_dir) / "sampled_indices.npy"
np.save(indices_file, sampled_indices)
print(f"\n采样索引已保存到: {indices_file}")
# 保存采样信息
info = {
"total_samples": len(file_paths),
"sampled_samples": len(sampled_data),
"sample_ratio": args.sample_ratio,
"distance_method": args.distance_method,
"distance_stats": {
"mean": float(distances.mean()),
"min": float(distances.min()),
"max": float(distances.max()),
"std": float(distances.std())
},
"sampled_distance_stats": {
"mean": float(distances[sampled_indices].mean()),
"min": float(distances[sampled_indices].min()),
"max": float(distances[sampled_indices].max()),
"std": float(distances[sampled_indices].std())
}
}
info_file = Path(args.output_dir) / "resample_info.json"
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(info, f, indent=2, ensure_ascii=False)
print(f"采样信息已保存到: {info_file}")
if __name__ == "__main__":
main()
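# Example invocation; the script name is illustrative, and the paths simply echo the
# argparse defaults above:
#   python resample_by_token_distance.py \
#       --json_path octuple_token_analysis_report.json \
#       --data_dir dataset/represented_data/tuneidx/tuneidx_msmidi/oct8 \
#       --output_dir dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled \
#       --sample_ratio 0.3 --distance_method emd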