1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@ -0,0 +1,472 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于position和duration token的重采样脚本V2
对于每首歌:
1. 如果包含的position和duration不在总数据集前3个则以较高概率采样
2. 对于包含的,以某个固定的百分比采样
3. 两个条件满足一个即可
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Octuple column layout: the index in this list is the column index of the
# (num_tokens, 8) token arrays this script processes.
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig"    # 7: TimeSignature
]
def load_top_tokens_from_json(json_path, column_name, top_k=3):
    """Load the ``top_k`` most frequent tokens of one column from a report JSON.

    Args:
        json_path: path to the token-analysis report JSON file.
        column_name: column to look up (e.g. 'position' or 'duration').
        top_k: how many of the most common tokens to return.

    Returns:
        set[int]: ids of the ``top_k`` most common tokens; empty set when the
        column is absent from the report.
    """
    print(f"读取分布文件: {json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        report = json.load(f)
    columns = report.get('columns', {})
    if column_name not in columns:
        print(f"警告: 列 {column_name} 不在报告中")
        return set()
    counts = columns[column_name].get('token_counts', {})
    # Rank token ids by descending occurrence count, then keep the first top_k.
    ranked = sorted(counts, key=lambda tok: int(counts[tok]), reverse=True)
    top_tokens = {int(tok) for tok in ranked[:top_k]}
    print(f"{column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}")
    return top_tokens
def get_data_tokens(data, col_idx):
    """Return the set of unique token ids found in one column of a song.

    Args:
        data: a (num_tokens, num_columns) numpy array, or a path to an .npz
            file holding such an array under key 'arr_0' (lazy loading).
        col_idx: index of the column to scan.

    Returns:
        set[int]: unique token ids in that column; empty set for empty data.

    Raises:
        RuntimeError: if the .npz file cannot be loaded.
    """
    if isinstance(data, (str, Path)):
        # Lazily load the array when a path was handed in.
        try:
            data = np.load(data)['arr_0']
        except Exception as e:
            raise RuntimeError(f"加载文件 {data} 时出错: {e}")
    if data.size == 0:
        return set()
    return {int(tok) for tok in np.unique(data[:, col_idx])}
def should_sample_song(data, top_position_tokens, top_duration_tokens,
                       contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
    """Decide whether a song is kept by the resampler.

    A song whose position AND duration columns both overlap the top token
    sets is kept with probability ``contain_sample_ratio``; every other song
    is kept with the (higher) probability ``not_contain_sample_ratio``.

    Args:
        data: (num_tokens, num_columns) array or path to an .npz file.
        top_position_tokens: dataset-wide most common position token ids.
        top_duration_tokens: dataset-wide most common duration token ids.
        contain_sample_ratio: keep probability when both columns overlap.
        not_contain_sample_ratio: keep probability otherwise.
        rng: random generator; ``None`` falls back to the global np.random.

    Returns:
        tuple[bool, bool]: (keep this song?, both columns overlap the top sets?)
    """
    pos_idx = COLUMN_NAMES.index("position")
    dur_idx = COLUMN_NAMES.index("duration")
    # Does the song use any of the dataset's most common tokens?
    pos_overlap = bool(get_data_tokens(data, pos_idx) & top_position_tokens)
    dur_overlap = bool(get_data_tokens(data, dur_idx) & top_duration_tokens)
    if rng is None:
        rng = np.random
    if pos_overlap and dur_overlap:
        # Over-represented songs: keep only a fixed fraction.
        return rng.random() < contain_sample_ratio, True
    # Rarer songs: keep with the higher probability.
    return rng.random() < not_contain_sample_ratio, False
def _load_single_file(npz_file):
    """Load one .npz song file (helper for the thread pool).

    Args:
        npz_file: path to an .npz file with the token array under key 'arr_0'.

    Returns:
        tuple | None: (data, npz_file) for a valid 2-D array; None when the
        array has any other rank or loading fails (a message is printed).
    """
    try:
        # Close the NpzFile explicitly — the original bare np.load() leaked
        # one open file handle per song, which adds up over large datasets.
        with np.load(npz_file) as npz:
            data = npz['arr_0']
        if data.ndim == 2:
            return (data, npz_file)
        if data.ndim == 1:
            print(f"警告: {npz_file} 是一维数组,跳过")
        else:
            # Previously any other rank fell through and returned None silently.
            print(f"警告: {npz_file} 维度为 {data.ndim},跳过")
        return None
    except Exception as e:
        print(f"错误: 加载 {npz_file} 时出错: {e}")
        return None
def get_data_file_paths(data_dir):
    """Collect every .npz file under ``data_dir`` (recursively), sorted.

    Only paths are gathered here — no array data is loaded.

    Args:
        data_dir: root directory to scan.

    Returns:
        list[pathlib.Path]: sorted paths; empty list when nothing is found
        (or the directory does not exist).
    """
    root = Path(data_dir)
    found = sorted(root.rglob("*.npz")) if root.exists() else []
    if not found:
        print(f"警告: 在 {root} 中未找到.npz文件")
        return []
    print(f"找到 {len(found)} 个.npz文件")
    return found
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
                   contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
                   num_threads=None, seed=42):
    """Decide, in parallel, which songs to keep under the contain/not-contain rule.

    Each file is handed (as a path, so loading happens inside the worker) to
    ``should_sample_song``; songs whose position and duration columns both
    overlap the top token sets are kept with ``contain_sample_ratio``, all
    others with ``not_contain_sample_ratio``.

    Args:
        file_paths: list of .npz file paths to consider.
        top_position_tokens: dataset-wide most common position token ids.
        top_duration_tokens: dataset-wide most common duration token ids.
        contain_sample_ratio: keep probability for the "in top-3" songs.
        not_contain_sample_ratio: keep probability for all other songs.
        num_threads: worker-thread count; None means min(CPU count, #files).
        seed: base random seed (see reproducibility note below).

    Returns:
        tuple: (sampled_paths, sampled_indices, stats) — kept paths, their
        indices in ``file_paths``, and a counter dict of the four outcomes.
    """
    import threading
    # One independent RNG per worker thread, held in thread-local storage.
    thread_local = threading.local()
    def get_thread_rng():
        """Return (creating on first use) the calling thread's private RNG."""
        if not hasattr(thread_local, 'rng'):
            # NOTE(review): the seed is offset by hash(thread_id), and thread
            # ids differ from run to run — so results are NOT reproducible
            # across runs even with a fixed `seed`. Confirm this is intended.
            thread_id = threading.current_thread().ident
            thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
        return thread_local.rng
    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(file_paths))
    print(f"\n开始重采样,使用 {num_threads} 个线程...")
    print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%")
    print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%")
    def _should_sample_wrapper(args_tuple):
        """Thread-pool wrapper: decide one file and never raise."""
        file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
        try:
            # Draw from this worker's own RNG (no cross-thread sharing).
            thread_rng = get_thread_rng()
            should_sample, in_top3 = should_sample_song(
                file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
            )
            return (file_path, should_sample, in_top3, None)
        except Exception as e:
            return (file_path, False, False, str(e))
    # One argument tuple per file.
    sample_args = [(file_path, top_position_tokens, top_duration_tokens,
                    contain_sample_ratio, not_contain_sample_ratio)
                   for file_path in file_paths]
    # Decide every song in parallel, then tally below.
    sampled_paths = []
    sampled_indices = []
    stats = {
        'not_in_top3_count': 0,    # songs outside the top-3 sets
        'not_in_top3_sampled': 0,  # ...of which kept
        'in_top3_count': 0,        # songs inside the top-3 sets
        'in_top3_sampled': 0       # ...of which kept
    }
    # Submit in batches to bound the number of in-flight futures.
    batch_size = min(1000, len(file_paths))
    results = {}
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        for batch_start in range(0, len(sample_args), batch_size):
            batch_end = min(batch_start + batch_size, len(sample_args))
            batch_args = sample_args[batch_start:batch_end]
            futures = {executor.submit(_should_sample_wrapper, args): args[0]
                       for args in batch_args}
            # Drain this batch as its futures complete.
            for future in tqdm(as_completed(futures), total=len(futures),
                               desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]",
                               leave=False):
                try:
                    file_path, should_sample, in_top3, error = future.result(timeout=60)
                    # NOTE(review): results is keyed by path — duplicate paths
                    # in file_paths would overwrite each other here.
                    results[file_path] = (should_sample, in_top3, error)
                    if error:
                        print(f"警告: 处理 {file_path} 时出错: {error}")
                except Exception as e:
                    print(f"错误: 获取结果时出错: {e}")
    # Walk the files in their original order, accumulating stats; files that
    # errored (or never produced a result) are skipped entirely.
    for idx, file_path in enumerate(file_paths):
        if file_path not in results:
            continue
        should_sample, in_top3, error = results[file_path]
        if error:
            continue
        if in_top3:
            stats['in_top3_count'] += 1
            if should_sample:
                stats['in_top3_sampled'] += 1
        else:
            stats['not_in_top3_count'] += 1
            if should_sample:
                stats['not_in_top3_sampled'] += 1
        if should_sample:
            sampled_paths.append(file_path)
            sampled_indices.append(idx)
    print(f"\n采样完成:")
    print(f" 总歌曲数: {len(file_paths)}")
    print(f" 采样歌曲数: {len(sampled_paths)}")
    print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%")
    print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}")
    print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}")
    if stats['not_in_top3_count'] > 0:
        print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
    print(f" 在前3个的歌曲数: {stats['in_top3_count']}")
    print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}")
    if stats['in_top3_count'] > 0:
        print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")
    return sampled_paths, sampled_indices, stats
def _save_single_file(args_tuple):
    """Save one sampled song into the output tree (thread-pool helper).

    Supports lazy loading: when ``data`` is a path, the array is loaded from
    that .npz file first.

    Args:
        args_tuple: (data, original_path, output_dir) where ``data`` is a
            numpy array or a path to an .npz file (key 'arr_0').

    Returns:
        tuple: (True, original_path) on success, otherwise
        (False, original_path, error_msg).
    """
    data, original_path, output_dir = args_tuple
    try:
        if isinstance(data, (str, Path)):
            # Lazy load; the context manager closes the NpzFile — the original
            # bare np.load() leaked one file handle per saved song.
            with np.load(data) as npz:
                data = npz['arr_0']
        # Mirror the source tree under output_dir, dropping the first two
        # path components of original_path.
        # NOTE(review): parents[len(parts) - 3] assumes original_path has at
        # least three components — confirm against get_data_file_paths output.
        relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
        output_path = output_dir / relative_path
        output_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(output_path, data)
        return (True, original_path)
    except Exception as e:
        error_msg = str(e)
        print(f"错误: 保存 {original_path} 时出错: {error_msg}")
        return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
    """Save the sampled songs in parallel.

    Args:
        sampled_data: list of arrays or file paths (lazy loading) to save.
        sampled_paths: matching original file paths (used to mirror the tree).
        output_dir: destination root directory (created if missing).
        num_threads: worker-thread count; None means min(CPU count, #files).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n保存采样数据到: {output_dir}")
    if not sampled_data:
        # Guard: with an empty sample set min(cpu_count(), 0) == 0, and
        # ThreadPoolExecutor(max_workers=0) raises ValueError in the original.
        print("保存完成,共保存 0/0 个文件")
        return
    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(sampled_data))
    # One argument tuple per file to save.
    save_args = [(data, original_path, output_dir)
                 for data, original_path in zip(sampled_data, sampled_paths)]
    success_count = 0
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(_save_single_file, args)
                   for args in save_args]
        # Collect results; the generous timeout avoids hanging on a stuck worker.
        for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
            try:
                result = future.result(timeout=300)
                if isinstance(result, tuple) and len(result) >= 2:
                    success = result[0]
                    if success:
                        success_count += 1
            except Exception as e:
                print(f"错误: 获取保存结果时出错: {e}")
    print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
def main():
    """Command-line entry point: load top tokens, resample, save results."""
    import argparse
    parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2")
    parser.add_argument("--json_path", type=str,
                        default="octuple_token_analysis_report.json",
                        help="token分析报告JSON文件路径")
    parser.add_argument("--data_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
                        help="数据目录路径")
    parser.add_argument("--output_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
                        help="输出目录路径")
    parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
                        help="对于包含前3个token的歌曲采样比例 (默认: 0.1)")
    parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
                        help="对于不包含前3个token的歌曲采样比例 (默认: 0.9)")
    parser.add_argument("--top_k", type=int, default=3,
                        help="使用前k个最常见的token (默认: 3)")
    parser.add_argument("--seed", type=int, default=42,
                        help="随机种子")
    parser.add_argument("--num_threads", type=int, default=None,
                        help="线程数None表示使用CPU核心数 (默认: None)")
    args = parser.parse_args()
    # Step 1: dataset-wide most common tokens for both columns.
    pos_top = load_top_tokens_from_json(args.json_path, "position", top_k=args.top_k)
    dur_top = load_top_tokens_from_json(args.json_path, "duration", top_k=args.top_k)
    if not (pos_top and dur_top):
        print("错误: 无法加载前top_k个token")
        return
    # Step 2: every candidate .npz file.
    file_paths = get_data_file_paths(args.data_dir)
    if not file_paths:
        print("错误: 未找到任何数据文件")
        return
    print(f"\n共找到 {len(file_paths)} 个数据文件")
    # Step 3: decide which songs to keep.
    sampled_paths, sampled_indices, stats = resample_songs(
        file_paths, pos_top, dur_top,
        contain_sample_ratio=args.contain_sample_ratio,
        not_contain_sample_ratio=args.not_contain_sample_ratio,
        num_threads=args.num_threads,
        seed=args.seed
    )
    # Step 4: write kept songs (paths are passed through, loaded lazily).
    save_sampled_data(sampled_paths, sampled_paths, args.output_dir,
                      num_threads=args.num_threads)
    # Step 5: persist the kept indices for later analysis.
    out_root = Path(args.output_dir)
    indices_file = out_root / "sampled_indices.npy"
    np.save(indices_file, np.array(sampled_indices))
    print(f"\n采样索引已保存到: {indices_file}")
    # Step 6: persist a summary of this run's configuration and outcome.
    info = {
        "total_samples": len(file_paths),
        "sampled_samples": len(sampled_paths),
        "contain_sample_ratio": args.contain_sample_ratio,
        "not_contain_sample_ratio": args.not_contain_sample_ratio,
        "top_k": args.top_k,
        "top_position_tokens": sorted(list(pos_top)),
        "top_duration_tokens": sorted(list(dur_top)),
        "stats": stats
    }
    info_file = out_root / "resample_info.json"
    with open(info_file, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    print(f"采样信息已保存到: {info_file}")
# Run the resampling pipeline only when executed as a script.
if __name__ == "__main__":
    main()