#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于position和duration token的重采样脚本V2

对于每首歌:
1. 如果包含的position和duration不在总数据集前3个,则必定采样
2. 对于包含的,以某个固定的百分比采样
3. 两个条件满足一个即可
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from collections import Counter
|
||
from tqdm import tqdm
|
||
import json
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import multiprocessing as mp
|
||
|
||
# Column layout of the Octuple token representation: each row of a data
# array has these 8 columns, indexed in this order.
COLUMN_NAMES = [
    "pitch",     # 0: Pitch/PitchDrum
    "position",  # 1: Position
    "bar",       # 2: Bar
    "velocity",  # 3: Velocity
    "duration",  # 4: Duration
    "program",   # 5: Program
    "tempo",     # 6: Tempo
    "timesig"    # 7: TimeSignature
]
|
||
|
||
|
||
def load_top_tokens_from_json(json_path, column_name, top_k=3):
    """Load the ``top_k`` most frequent tokens of one column from a
    token-analysis JSON report.

    Args:
        json_path: Path to the JSON report file.
        column_name: Column to look up (e.g. 'position' or 'duration').
        top_k: How many of the most common tokens to return.

    Returns:
        set[int]: The ``top_k`` most common token values; empty set when
        the column is absent from the report.
    """
    print(f"读取分布文件: {json_path}")
    with open(json_path, 'r', encoding='utf-8') as f:
        report = json.load(f)

    columns = report.get('columns', {})
    if column_name not in columns:
        print(f"警告: 列 {column_name} 不在报告中")
        return set()

    counts = columns[column_name].get('token_counts', {})

    # Rank token keys by their frequency, highest first.  Counts may be
    # serialized as strings, so convert before comparing.
    ranked = sorted(counts, key=lambda tok: int(counts[tok]), reverse=True)
    top_tokens = {int(tok) for tok in ranked[:top_k]}

    print(f" 列 {column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}")

    return top_tokens
|
||
|
||
|
||
def get_data_tokens(data, col_idx):
    """Collect the unique token values found in one column of a song.

    Args:
        data: Either a (num_tokens, num_columns) numpy array, or a path to
            an .npz file whose 'arr_0' entry holds such an array.
        col_idx: Index of the column to inspect.

    Returns:
        set[int]: Unique token values of that column (empty for empty data).

    Raises:
        RuntimeError: When ``data`` is a path and loading the file fails.
    """
    # Lazily load the array when a path was supplied instead of data.
    if isinstance(data, (str, Path)):
        try:
            data = np.load(data)['arr_0']
        except Exception as e:
            raise RuntimeError(f"加载文件 {data} 时出错: {e}")

    if data.size == 0:
        return set()

    return {int(v) for v in np.unique(data[:, col_idx])}
|
||
|
||
|
||
def should_sample_song(data, top_position_tokens, top_duration_tokens,
                       contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
    """Decide whether a song should be kept by the resampler.

    A song whose position column AND duration column both intersect the
    respective top-token sets is kept with probability
    ``contain_sample_ratio``; every other song is kept with the (higher)
    probability ``not_contain_sample_ratio``.

    Args:
        data: A (num_tokens, num_columns) numpy array or a file path
            (forwarded to ``get_data_tokens``).
        top_position_tokens: Most common position tokens of the dataset.
        top_duration_tokens: Most common duration tokens of the dataset.
        contain_sample_ratio: Keep probability when both columns hit the
            top tokens.
        not_contain_sample_ratio: Keep probability otherwise.
        rng: Random generator; falls back to the global ``np.random``.

    Returns:
        tuple[bool, bool]: (keep?, both columns hit the top tokens?).
    """
    pos_col = COLUMN_NAMES.index("position")
    dur_col = COLUMN_NAMES.index("duration")

    hits_top_position = bool(get_data_tokens(data, pos_col) & top_position_tokens)
    hits_top_duration = bool(get_data_tokens(data, dur_col) & top_duration_tokens)

    if rng is None:
        rng = np.random

    if hits_top_position and hits_top_duration:
        # Both columns contain top tokens: keep with the lower ratio.
        return rng.random() < contain_sample_ratio, True

    # At least one column avoids the top tokens: keep with the higher ratio.
    return rng.random() < not_contain_sample_ratio, False
|
||
|
||
|
||
def _load_single_file(npz_file):
|
||
"""
|
||
加载单个npz文件的辅助函数(用于多线程)
|
||
|
||
Args:
|
||
npz_file: npz文件路径
|
||
|
||
Returns:
|
||
tuple: (data, file_path) 或 None(如果加载失败)
|
||
"""
|
||
try:
|
||
data = np.load(npz_file)['arr_0']
|
||
if data.ndim == 2:
|
||
return (data, npz_file)
|
||
elif data.ndim == 1:
|
||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||
return None
|
||
except Exception as e:
|
||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||
return None
|
||
|
||
|
||
def get_data_file_paths(data_dir):
    """Recursively collect every .npz file path under ``data_dir``
    without loading any data.

    Args:
        data_dir: Root directory to search.

    Returns:
        list[Path]: Sorted .npz file paths; empty list when the directory
        is missing or contains no .npz files.
    """
    data_dir = Path(data_dir)

    npz_files = sorted(data_dir.rglob("*.npz")) if data_dir.exists() else []

    if not npz_files:
        print(f"警告: 在 {data_dir} 中未找到.npz文件")
        return []

    print(f"找到 {len(npz_files)} 个.npz文件")
    return npz_files
|
||
|
||
|
||
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
                   contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
                   num_threads=None, seed=42):
    """
    Resample songs according to the top-token rules, in parallel.

    Each file is evaluated by ``should_sample_song`` on a thread pool;
    songs whose position AND duration columns contain top tokens are kept
    with ``contain_sample_ratio``, all others with the (higher)
    ``not_contain_sample_ratio``.

    Args:
        file_paths: List of data file paths to evaluate.
        top_position_tokens: Set of the most common position tokens.
        top_duration_tokens: Set of the most common duration tokens.
        contain_sample_ratio: Keep probability for songs containing the
            top tokens.
        not_contain_sample_ratio: Keep probability for songs that do not.
        num_threads: Thread count; None means min(CPU count, #files).
        seed: Base random seed for the per-thread generators.

    Returns:
        tuple: (sampled_paths, sampled_indices, stats) where
        ``sampled_indices`` are positions into ``file_paths`` (original
        order preserved) and ``stats`` counts kept/total per category.
    """
    import threading

    # One independent RNG per worker thread, created lazily.
    thread_local = threading.local()

    def get_thread_rng():
        """Return the RNG bound to the current thread, creating it on first use."""
        if not hasattr(thread_local, 'rng'):
            # Derive a per-thread seed from the base seed and the thread id.
            # NOTE(review): hash(ident) % 1e6 can collide between threads;
            # presumably acceptable since it only decorrelates draws.
            thread_id = threading.current_thread().ident
            thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
        return thread_local.rng

    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(file_paths))

    print(f"\n开始重采样,使用 {num_threads} 个线程...")
    print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%")
    print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%")

    def _should_sample_wrapper(args_tuple):
        """Thread-pool wrapper: decide whether one file should be kept."""
        file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
        try:
            # Draw from this worker thread's own RNG.
            thread_rng = get_thread_rng()
            should_sample, in_top3 = should_sample_song(
                file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
            )
            return (file_path, should_sample, in_top3, None)
        except Exception as e:
            # Report the error alongside the path; the file is then skipped.
            return (file_path, False, False, str(e))

    # Pre-build the argument tuples for every file.
    sample_args = [(file_path, top_position_tokens, top_duration_tokens,
                    contain_sample_ratio, not_contain_sample_ratio)
                   for file_path in file_paths]

    # Decide per-song sampling on the thread pool.
    sampled_paths = []
    sampled_indices = []
    stats = {
        'not_in_top3_count': 0,    # songs not containing the top tokens
        'not_in_top3_sampled': 0,  # ...of which kept
        'in_top3_count': 0,        # songs containing the top tokens
        'in_top3_sampled': 0       # ...of which kept
    }

    # Submit in batches to avoid flooding the executor with tasks at once.
    batch_size = min(1000, len(file_paths))
    results = {}

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        for batch_start in range(0, len(sample_args), batch_size):
            batch_end = min(batch_start + batch_size, len(sample_args))
            batch_args = sample_args[batch_start:batch_end]

            futures = {executor.submit(_should_sample_wrapper, args): args[0]
                       for args in batch_args}

            # Collect batch results as they complete (completion order).
            for future in tqdm(as_completed(futures), total=len(futures),
                               desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]",
                               leave=False):
                try:
                    file_path, should_sample, in_top3, error = future.result(timeout=60)
                    results[file_path] = (should_sample, in_top3, error)
                    if error:
                        print(f"警告: 处理 {file_path} 时出错: {error}")
                except Exception as e:
                    print(f"错误: 获取结果时出错: {e}")

    # Walk the files in their original order, accumulate statistics and
    # record the kept paths/indices (errored files are skipped).
    for idx, file_path in enumerate(file_paths):
        if file_path not in results:
            continue

        should_sample, in_top3, error = results[file_path]
        if error:
            continue

        if in_top3:
            stats['in_top3_count'] += 1
            if should_sample:
                stats['in_top3_sampled'] += 1
        else:
            stats['not_in_top3_count'] += 1
            if should_sample:
                stats['not_in_top3_sampled'] += 1

        if should_sample:
            sampled_paths.append(file_path)
            sampled_indices.append(idx)

    print(f"\n采样完成:")
    print(f" 总歌曲数: {len(file_paths)}")
    print(f" 采样歌曲数: {len(sampled_paths)}")
    print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%")
    print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}")
    print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}")
    if stats['not_in_top3_count'] > 0:
        print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
    print(f" 在前3个的歌曲数: {stats['in_top3_count']}")
    print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}")
    if stats['in_top3_count'] > 0:
        print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")

    return sampled_paths, sampled_indices, stats
|
||
|
||
|
||
def _save_single_file(args_tuple):
    """
    Save one sampled file (worker helper for multi-threaded saving).

    Supports lazy loading: when ``data`` is a path the array is first
    loaded from that file.

    Args:
        args_tuple: (data, original_path, output_dir), where ``data`` is a
            numpy array or a path to an .npz file with an 'arr_0' entry.

    Returns:
        tuple: (True, original_path) on success, or
        (False, original_path, error_msg) on failure.
    """
    data, original_path, output_dir = args_tuple
    try:
        # Lazily load the array when a path was passed instead of data.
        if isinstance(data, (str, Path)):
            data = np.load(data)['arr_0']

        # Re-root the file under output_dir while keeping the trailing part
        # of its directory structure.
        # NOTE(review): parents[len(parts) - 3] strips the leading path
        # components and keeps roughly the last few — this assumes paths are
        # at least 3 components deep (shallower paths raise IndexError);
        # verify against the expected dataset layout.
        relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
        output_path = output_dir / relative_path
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Saved under the default 'arr_0' key, matching how loaders read it.
        np.savez_compressed(output_path, data)
        return (True, original_path)
    except Exception as e:
        error_msg = str(e)
        print(f"错误: 保存 {original_path} 时出错: {error_msg}")
        return (False, original_path, error_msg)
|
||
|
||
|
||
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
    """
    Save the sampled data with a thread pool.

    Args:
        sampled_data: List of arrays, or file paths for lazy loading by
            the worker.
        sampled_paths: Original file paths, parallel to ``sampled_data``.
        output_dir: Destination directory (created if missing).
        num_threads: Thread count; None means min(CPU count, #files).
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n保存采样数据到: {output_dir}")

    if num_threads is None:
        num_threads = min(mp.cpu_count(), len(sampled_data))

    # Pair each payload with its source path and the destination root.
    tasks = [(payload, source_path, output_dir)
             for payload, source_path in zip(sampled_data, sampled_paths)]

    saved = 0
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        pending = [pool.submit(_save_single_file, task) for task in tasks]

        # Count successes as workers finish; the timeout guards against a
        # stuck write hanging the loop forever.
        for fut in tqdm(as_completed(pending), total=len(pending), desc="保存数据"):
            try:
                outcome = fut.result(timeout=300)
                if isinstance(outcome, tuple) and len(outcome) >= 2 and outcome[0]:
                    saved += 1
            except Exception as e:
                print(f"错误: 获取保存结果时出错: {e}")

    print(f"保存完成,共保存 {saved}/{len(sampled_data)} 个文件")
|
||
|
||
|
||
def main():
    """Script entry point: load top tokens, resample, and save results."""
    import argparse

    parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2")
    parser.add_argument("--json_path", type=str,
                        default="octuple_token_analysis_report.json",
                        help="token分析报告JSON文件路径")
    parser.add_argument("--data_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
                        help="数据目录路径")
    parser.add_argument("--output_dir", type=str,
                        default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
                        help="输出目录路径")
    parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
                        help="对于包含前3个token的歌曲,采样比例 (默认: 0.1)")
    parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
                        help="对于不包含前3个token的歌曲,采样比例 (默认: 0.9)")
    parser.add_argument("--top_k", type=int, default=3,
                        help="使用前k个最常见的token (默认: 3)")
    parser.add_argument("--seed", type=int, default=42,
                        help="随机种子")
    parser.add_argument("--num_threads", type=int, default=None,
                        help="线程数,None表示使用CPU核心数 (默认: None)")

    args = parser.parse_args()

    # 1. Load the top_k most common position/duration tokens from the report.
    top_position_tokens = load_top_tokens_from_json(
        args.json_path, "position", top_k=args.top_k
    )
    top_duration_tokens = load_top_tokens_from_json(
        args.json_path, "duration", top_k=args.top_k
    )

    if not top_position_tokens or not top_duration_tokens:
        print("错误: 无法加载前top_k个token")
        return

    # 2. Collect all data file paths (no data is loaded yet).
    file_paths = get_data_file_paths(args.data_dir)

    if not file_paths:
        print("错误: 未找到任何数据文件")
        return

    print(f"\n共找到 {len(file_paths)} 个数据文件")

    # 3. Resample according to the token-based rules.
    sampled_paths, sampled_indices, stats = resample_songs(
        file_paths,
        top_position_tokens,
        top_duration_tokens,
        contain_sample_ratio=args.contain_sample_ratio,
        not_contain_sample_ratio=args.not_contain_sample_ratio,
        num_threads=args.num_threads,
        seed=args.seed
    )

    # 4. Save the sampled files; paths are passed through so workers load lazily.
    sampled_data = sampled_paths  # lazy loading: workers read from these paths
    save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)

    # 5. Save the sampled indices (optional, for later analysis).
    indices_file = Path(args.output_dir) / "sampled_indices.npy"
    np.save(indices_file, np.array(sampled_indices))
    print(f"\n采样索引已保存到: {indices_file}")

    # 6. Save a JSON summary of this resampling run.
    info = {
        "total_samples": len(file_paths),
        "sampled_samples": len(sampled_paths),
        "contain_sample_ratio": args.contain_sample_ratio,
        "not_contain_sample_ratio": args.not_contain_sample_ratio,
        "top_k": args.top_k,
        "top_position_tokens": sorted(list(top_position_tokens)),
        "top_duration_tokens": sorted(list(top_duration_tokens)),
        "stats": stats
    }

    info_file = Path(args.output_dir) / "resample_info.json"
    with open(info_file, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2, ensure_ascii=False)
    print(f"采样信息已保存到: {info_file}")
|
||
|
||
|
||
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||
|