1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
from itertools import groupby, chain
from random import shuffle, seed
lock = RLock()
def shuffled(seq):
shuffle(seq)
return seq
def permute_inside_and_across_tracks(seq):
seq_sorted = sorted(seq, key=lambda x: x[5])
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
def convert_event_dicts(dict_list):
"""
Convert the list of event vocab dicts into structured output, in order
@@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
return result
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
try:
score = Score(midi_path, ttype="tick")
@@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
# Tokenize
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
if whether_shuffle:
token_ids = permute_inside_and_across_tracks(token_ids)
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
@@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
print(f" × Error while processing file: {midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
# === 1. Initialize the tokenizer and save the vocabulary ===
print("Initializing Octuple tokenizer...")
config = TokenizerConfig(
@@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
os.makedirs(output_dir, exist_ok=True)
# === 3. Collect MIDI files ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
# midi_paths = list(midi_paths)
midi_paths = []
for root, _, files in os.walk(midi_dir):
for file in files:
if file.lower().endswith(('.mid', '.midi')):
midi_paths.append(os.path.join(root, file))
print(f"Found {len(midi_paths)} MIDI files; processing with {num_threads} workers.\n")
# === 4. Parallel processing ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
@@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
if __name__ == "__main__":
import argparse
midi_directory = "dataset/Melody" # change this to your MIDI directory
parser = argparse.ArgumentParser(description="MIDI preprocessing script (parallel version)")
parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI file directory")
parser.add_argument("--shuffle", action="store_true", help="shuffle token order within and across tracks during preprocessing")
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)
args = parser.parse_args()
preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)

View File

@@ -0,0 +1,39 @@
from itertools import groupby, chain
from random import shuffle, seed
def shuffled(seq):
shuffle(seq)
return seq
def permute_inside_and_across_tracks(seq):
seq_sorted = sorted(seq, key=lambda x: x[5])
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
# (PitchDrum, Position, Bar, Velocity, Duration, Program, Tempo, TimeSignature)
seq = [
# Program 0
(60, 0, 0, 90, 96, 0, 120, 16),
(64, 48,0, 88, 96, 0, 120, 16),
(67, 96,0, 92, 96, 0, 120, 16),
# Program 32
(40, 0, 0, 80, 192, 32, 120, 16),
(43, 0, 1, 78, 192, 32, 120, 16),
# Program 40
(72, 24,0, 85, 72, 40, 120, 16),
(74, 72,0, 83, 72, 40, 120, 16),
(76, 24,1, 86, 72, 40, 120, 16),
]
# seed(42)
inside_track_permuted_and_track_permuted = permute_inside_and_across_tracks(seq)
print("Original seq")
for e in seq:
print(e)
print("\nShuffled seq")
for e in inside_track_permuted_and_track_permuted:
print(e)

View File

@@ -0,0 +1,634 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script based on token-distribution distance.
Reads octuple_token_analysis_report.json, computes the distance between each sample and the
overall distribution, and then samples with distance-based weights: the farther a sample is
from the overall distribution, the more likely it is to be selected.
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from scipy.stats import entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Column-name definitions for Octuple
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
def load_distribution_from_json(json_path):
"""
Load the overall token distribution from a JSON file.
Args:
json_path: path to the JSON report file
Returns:
dict: {column_name: {token: probability}}
"""
print(f"Reading distribution file: {json_path}")
with open(json_path, 'r', encoding='utf-8') as f:
report = json.load(f)
distributions = {}
columns = report.get('columns', {})
for col_name in COLUMN_NAMES:
if col_name not in columns:
print(f"Warning: column {col_name} not found in the report")
distributions[col_name] = {}
continue
col_data = columns[col_name]
token_counts = col_data.get('token_counts', {})
total_tokens = col_data.get('total_tokens', 1)
# Convert counts to a probability distribution
distribution = {}
for token_str, count in token_counts.items():
token = int(token_str)
distribution[token] = count / total_tokens
distributions[col_name] = distribution
print(f"{col_name}: {len(distribution)} unique tokens, total tokens: {total_tokens:,}")
return distributions
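# Illustration only (hypothetical numbers, not taken from a real report): the loader above
# expects report["columns"][<col>]["token_counts"] mapping token strings to counts, plus a
# "total_tokens" field, and converts each column into a {token: probability} dict.
def _demo_report_format():
    example_report = {
        "columns": {
            "pitch": {"total_tokens": 1000, "token_counts": {"60": 400, "64": 350, "67": 250}}
        }
    }
    col = example_report["columns"]["pitch"]
    # Equivalent to what load_distribution_from_json builds: {60: 0.4, 64: 0.35, 67: 0.25}
    return {int(t): c / col["total_tokens"] for t, c in col["token_counts"].items()}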
def compute_data_distribution(data, col_idx):
"""
Compute the token distribution of a single sample for the given column.
Args:
data: numpy array (num_tokens, num_columns)
col_idx: column index
Returns:
dict: {token: probability}
"""
if data.size == 0:
return {}
tokens = data[:, col_idx]
unique, counts = np.unique(tokens, return_counts=True)
total = len(tokens)
distribution = {}
for token, count in zip(unique, counts):
distribution[int(token)] = count / total
return distribution
def compute_emd_distance(dist1, dist2):
"""
Compute the distance between two distributions using the Earth Mover's Distance
(the 1-Wasserstein distance).
Args:
dist1: distribution 1, dict {token: probability}, already normalized
dist2: distribution 2, dict {token: probability}, already normalized
Returns:
float: EMD distance
"""
# Take the union of all tokens and sort it
all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
if not all_tokens:
return 0.0
# Build the probability vectors and the token-value vector
p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
token_values = np.array(all_tokens, dtype=float)
# Normalize (guards against numerical error)
p_sum = p_weights.sum()
q_sum = q_weights.sum()
if p_sum < 1e-10 or q_sum < 1e-10:
return 0.0
p_weights = p_weights / p_sum
q_weights = q_weights / q_sum
# Use the 1-Wasserstein distance, which is the Earth Mover's Distance.
# wasserstein_distance takes the sample values (positions) and weights of both distributions;
# for a discrete distribution we use the token values as positions.
emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
return emd
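# Quick sanity illustration for compute_emd_distance (hypothetical distributions): moving all
# probability mass from token 0 to token 4 costs 4 token units, and splitting the mass between
# tokens 0 and 4 against a point mass at 2 costs 2.
def _demo_emd():
    assert abs(compute_emd_distance({0: 1.0}, {4: 1.0}) - 4.0) < 1e-9
    assert abs(compute_emd_distance({0: 0.5, 4: 0.5}, {2: 1.0}) - 2.0) < 1e-9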
def compute_distribution_distance(dist1, dist2, method='emd'):
"""
Compute the distance between two distributions.
Args:
dist1: distribution 1, dict {token: probability}
dist2: distribution 2, dict {token: probability}
method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)
Returns:
float: distribution distance
"""
if method == 'emd':
return compute_emd_distance(dist1, dist2)
# Take the union of all tokens
all_tokens = set(dist1.keys()) | set(dist2.keys())
if not all_tokens:
return 0.0
# Build the probability vectors
p = np.array([dist1.get(token, 0.0) for token in all_tokens])
q = np.array([dist2.get(token, 0.0) for token in all_tokens])
# Normalize (guards against numerical error)
p = p / (p.sum() + 1e-10)
q = q / (q.sum() + 1e-10)
if method == 'js':
# Jensen-Shannon distance: symmetric and bounded
return jensenshannon(p, q)
elif method == 'kl':
# KL divergence: asymmetric; zero values must be handled
# Add a small smoothing term to avoid log(0)
p = p + 1e-10
q = q + 1e-10
p = p / p.sum()
q = q / q.sum()
return entropy(p, q)
else:
raise ValueError(f"Unknown distance method: {method}")
def extract_subdistribution(global_dist, data_tokens):
"""
Extract, from the global distribution, the sub-distribution over only the tokens that occur in the data, and renormalize it.
Args:
global_dist: global distribution, dict {token: probability}
data_tokens: set (or list) of tokens occurring in the data
Returns:
dict: sub-distribution {token: probability}, normalized
"""
if not data_tokens or not global_dist:
return {}
# Extract the sub-distribution
sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
# Normalize
total = sum(sub_dist.values())
if total < 1e-10:
return {}
normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
return normalized_sub_dist
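# Illustration only (hypothetical numbers): restricting a global distribution to the tokens a
# song actually uses and renormalizing, so the comparison in compute_data_distance below is
# between two distributions over the same support.
def _demo_extract_subdistribution():
    global_dist = {0: 0.5, 1: 0.3, 2: 0.2}
    song_tokens = {0, 2}
    # The kept mass is 0.5 + 0.2 = 0.7, so the result is {0: 5/7, 2: 2/7}.
    return extract_subdistribution(global_dist, song_tokens)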
def compute_data_distance(data, global_distributions, method='emd'):
"""
Compute the distance between a single sample and the overall distribution.
For each song, the global distribution is restricted to the tokens that occur in that song,
both distributions are normalized, and the Earth Mover's Distance between them is computed.
Args:
data: numpy array (num_tokens, num_columns), or a file path (when lazy loading)
global_distributions: overall distribution, dict {column_name: {token: probability}}
method: distance metric, 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)
Returns:
float: average distance across all columns
"""
# If data is a path, load it
if isinstance(data, (str, Path)):
try:
data = np.load(data)['arr_0']
except Exception as e:
# Do not print the error here; let the caller handle it
raise RuntimeError(f"Error loading file {data}: {e}")
distances = []
for col_idx, col_name in enumerate(COLUMN_NAMES):
# Compute this sample's distribution for the column
data_dist = compute_data_distribution(data, col_idx)
# Get the overall distribution for the column
global_dist = global_distributions.get(col_name, {})
if not data_dist or not global_dist:
continue
# Restrict the global distribution to the tokens present in this sample
data_tokens = set(data_dist.keys())
sub_global_dist = extract_subdistribution(global_dist, data_tokens)
if not sub_global_dist:
continue
# Normalize the sample's distribution
data_dist_sum = sum(data_dist.values())
if data_dist_sum < 1e-10:
continue
normalized_data_dist = {token: prob / data_dist_sum
for token, prob in data_dist.items()}
# Compute the distance (both distributions are normalized)
dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
distances.append(dist)
# Return the average distance
return np.mean(distances) if distances else 0.0
def _load_single_file(npz_file):
"""
Helper that loads a single .npz file (used by the worker pool).
Args:
npz_file: path to the .npz file
Returns:
tuple: (data, file_path), or None if loading failed
"""
try:
data = np.load(npz_file)['arr_0']
if data.ndim == 2:
return (data, npz_file)
elif data.ndim == 1:
print(f"Warning: {npz_file} is a 1-D array, skipping")
return None
except Exception as e:
print(f"Error loading {npz_file}: {e}")
return None
def get_data_file_paths(data_dir):
"""
Collect all data file paths (without loading the data).
Args:
data_dir: path to the data directory
Returns:
list: list of file paths
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"Warning: no .npz files found in {data_dir}")
return []
print(f"Found {len(npz_files)} .npz files")
return npz_files
def load_data_with_paths(data_dir, num_threads=None, lazy=False):
"""
Load all data and return the list of file paths (multi-threaded).
Args:
data_dir: path to the data directory
num_threads: number of threads; None means use the CPU core count
lazy: if True, only return the file paths and do not load any data
Returns:
tuple: (data_list, file_paths_list), or (None, file_paths_list) if lazy=True
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"Warning: no .npz files found in {data_dir}")
return [], []
if lazy:
print(f"Found {len(npz_files)} .npz files (lazy-loading mode)")
return None, npz_files
print(f"Found {len(npz_files)} .npz files, loading...")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(npz_files))
all_data = []
file_paths = []
# Load files with a thread pool
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(_load_single_file, npz_file): npz_file
for npz_file in npz_files}
for future in tqdm(as_completed(futures), total=len(futures), desc="Loading data"):
result = future.result()
if result is not None:
data, file_path = result
all_data.append(data)
file_paths.append(file_path)
# Restore a deterministic order (sorted by file path)
if file_paths:
sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
file_paths, all_data = zip(*sorted_pairs)
file_paths = list(file_paths)
all_data = list(all_data)
return all_data, file_paths
def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
"""
Perform distance-weighted resampling.
Args:
file_paths: list of file paths
distances: list of distances
sample_ratio: fraction of the files to sample
method: distance metric (used to decide the weighting direction)
lazy: if True, return file paths instead of loaded data
Returns:
tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
"""
n_samples = int(len(file_paths) * sample_ratio)
print(f"\nSampling {n_samples} items ({sample_ratio*100:.1f}% of the total)")
# Convert distances to weights:
# the larger the distance, the larger the weight
distances = np.array(distances)
# Handle zero distances and outliers
min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
distances = np.maximum(distances, min_dist * 0.1)
# Normalize distances to [0, 1], then convert them to weights
# An exponential is used to amplify differences in distance
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
weights = np.exp(normalized_distances * 3) # exponential boost: distant samples become much more likely to be drawn
# Normalize the weights
weights = weights / weights.sum()
# Weighted random sampling without replacement
indices = np.arange(len(file_paths))
sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
sampled_paths = [file_paths[i] for i in sampled_indices]
# If lazy=True return the paths, otherwise load the data
if lazy:
sampled_data = sampled_paths # return the paths; load lazily later
else:
# Load the sampled data
sampled_data = []
for path in tqdm(sampled_paths, desc="Loading sampled data"):
try:
data = np.load(path)['arr_0']
sampled_data.append(data)
except Exception as e:
print(f"Error loading {path}: {e}")
sampled_data.append(None)
print("Sampling finished:")
print(f" Number of sampled items: {len(sampled_paths)}")
print(f" Mean distance: {distances[sampled_indices].mean():.6f}")
print(f" Min distance: {distances[sampled_indices].min():.6f}")
print(f" Max distance: {distances[sampled_indices].max():.6f}")
return sampled_data, sampled_paths, sampled_indices
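# Illustration of the weighting used above (hypothetical distances): after min-max
# normalization, weights grow as exp(3 * normalized_distance), so the farthest sample is
# roughly exp(3) ~ 20x more likely to be drawn than the closest one.
def _demo_distance_weights():
    normalized = np.array([0.0, 0.5, 1.0])
    weights = np.exp(normalized * 3)   # ~[1.00, 4.48, 20.09]
    return weights / weights.sum()     # ~[0.04, 0.18, 0.79]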
def _save_single_file(args_tuple):
"""
Helper that saves a single file (used by the worker pool).
Supports lazy loading: if data is a path, the array is first loaded from that file.
Args:
args_tuple: (data, original_path, output_dir)
Returns:
tuple: (success, original_path) or (False, original_path, error_msg)
"""
data, original_path, output_dir = args_tuple
try:
# If data is a path, load it
if isinstance(data, (str, Path)):
data = np.load(data)['arr_0']
# Preserve the relative path structure
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
output_path = output_dir / relative_path
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, data)
return (True, original_path)
except Exception as e:
error_msg = str(e)
print(f"Error saving {original_path}: {error_msg}")
return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
"""
Save the sampled data (multi-threaded).
Args:
sampled_data: list of sampled data arrays (or file paths, when lazy)
sampled_paths: list of sampled file paths
output_dir: output directory
num_threads: number of threads; None means use the CPU core count
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nSaving sampled data to: {output_dir}")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(sampled_data))
# Prepare the arguments
save_args = [(data, original_path, output_dir)
for data, original_path in zip(sampled_data, sampled_paths)]
# Save the files with a thread pool
success_count = 0
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = [executor.submit(_save_single_file, args)
for args in save_args]
# Collect the results
for future in tqdm(as_completed(futures), total=len(futures), desc="Saving data"):
try:
result = future.result(timeout=300) # timeout so a stuck worker cannot hang the run
if isinstance(result, tuple) and len(result) >= 2:
success = result[0]
if success:
success_count += 1
except Exception as e:
print(f"Error retrieving save result: {e}")
print(f"Saving finished: {success_count}/{len(sampled_data)} files saved")
def main():
"""Main entry point."""
import argparse
parser = argparse.ArgumentParser(description="Resampling based on token-distribution distance")
parser.add_argument("--json_path", type=str,
default="octuple_token_analysis_report.json",
help="path to the token analysis report JSON file")
parser.add_argument("--data_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
help="path to the data directory")
parser.add_argument("--output_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
help="path to the output directory")
parser.add_argument("--sample_ratio", type=float, default=0.3,
help="sampling ratio (default: 0.3)")
parser.add_argument("--distance_method", type=str, default="emd",
choices=["emd", "js", "kl"],
help="distance metric: 'emd' (Earth Mover's Distance), 'js' (Jensen-Shannon) or 'kl' (KL divergence)")
parser.add_argument("--seed", type=int, default=42,
help="random seed")
parser.add_argument("--num_threads", type=int, default=1,
help="number of threads; None means use the CPU core count (default: 1)")
args = parser.parse_args()
# Set the random seed
np.random.seed(args.seed)
# 1. Load the overall distribution
global_distributions = load_distribution_from_json(args.json_path)
# 2. Collect all data file paths (lazy mode, to avoid loading everything at once)
_, file_paths = load_data_with_paths(args.data_dir, lazy=True)
if not file_paths:
print("Error: no data files found")
return
print(f"\nFound {len(file_paths)} data files")
# 3. Compute each sample's distance to the overall distribution (multi-threaded, lazy loading)
print("\nComputing each sample's distance to the overall distribution (lazy-loading mode)...")
def _compute_distance_wrapper(args_tuple):
"""Distance-computation wrapper (for the worker pool; supports lazy loading)."""
idx, file_path, global_dists, method = args_tuple
try:
distance = compute_data_distance(file_path, global_dists, method=method)
return (idx, distance, None)
except Exception as e:
return (idx, 0.0, str(e))
if args.num_threads is None:
num_threads = min(mp.cpu_count(), len(file_paths))
else:
num_threads = args.num_threads
# Prepare the arguments (file paths instead of data, with indices)
distance_args = [(i, file_path, global_distributions, args.distance_method)
for i, file_path in enumerate(file_paths)]
# Compute distances with a thread pool (data is loaded on demand)
# Initialize the result list so the original order is preserved
distances = [0.0] * len(file_paths)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = [executor.submit(_compute_distance_wrapper, args)
for args in distance_args]
# Collect the results, showing progress with tqdm
for future in tqdm(as_completed(futures), total=len(futures), desc="Computing distances"):
try:
idx, distance, error = future.result(timeout=300) # timeout so a stuck worker cannot hang the run
distances[idx] = distance
if error:
print(f"Warning: error while computing distance (index {idx}): {error}")
except Exception as e:
print(f"Error retrieving result: {e}")
# If the result cannot be retrieved, keep the default value 0.0
distances = np.array(distances)
print("\nDistance statistics:")
print(f" Mean distance: {distances.mean():.6f}")
print(f" Min distance: {distances.min():.6f}")
print(f" Max distance: {distances.max():.6f}")
print(f" Std deviation: {distances.std():.6f}")
# 4. Distance-weighted sampling (lazy-loading mode)
sampled_data, sampled_paths, sampled_indices = weighted_resample(
file_paths, distances,
sample_ratio=args.sample_ratio,
method=args.distance_method,
lazy=True # lazy loading, avoids loading the data twice
)
# 5. Save the sampled results (multi-threaded, lazy loading)
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
# 6. Save the sampled indices (optional, for later analysis)
indices_file = Path(args.output_dir) / "sampled_indices.npy"
np.save(indices_file, sampled_indices)
print(f"\nSampled indices saved to: {indices_file}")
# Save the sampling metadata
info = {
"total_samples": len(file_paths),
"sampled_samples": len(sampled_data),
"sample_ratio": args.sample_ratio,
"distance_method": args.distance_method,
"distance_stats": {
"mean": float(distances.mean()),
"min": float(distances.min()),
"max": float(distances.max()),
"std": float(distances.std())
},
"sampled_distance_stats": {
"mean": float(distances[sampled_indices].mean()),
"min": float(distances[sampled_indices].min()),
"max": float(distances[sampled_indices].max()),
"std": float(distances[sampled_indices].std())
}
}
info_file = Path(args.output_dir) / "resample_info.json"
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(info, f, indent=2, ensure_ascii=False)
print(f"Sampling info saved to: {info_file}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,472 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Resampling script based on position and duration tokens (V2).
For each song:
1. If the song's position and duration tokens do not overlap the dataset-wide top-3 tokens, it is sampled with a high probability.
2. If they do, it is sampled at a fixed, lower percentage.
3. A song is kept as soon as either rule selects it.
"""
import os
import numpy as np
from pathlib import Path
from collections import Counter
from tqdm import tqdm
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
# Column-name definitions for Octuple
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
def load_top_tokens_from_json(json_path, column_name, top_k=3):
"""
Load the top_k most frequent tokens of the given column from the JSON report.
Args:
json_path: path to the JSON report file
column_name: column name (e.g. 'position' or 'duration')
top_k: number of most frequent tokens to return
Returns:
set: the top_k most frequent tokens
"""
print(f"Reading distribution file: {json_path}")
with open(json_path, 'r', encoding='utf-8') as f:
report = json.load(f)
columns = report.get('columns', {})
if column_name not in columns:
print(f"Warning: column {column_name} not found in the report")
return set()
col_data = columns[column_name]
token_counts = col_data.get('token_counts', {})
# Sort by count and take the top_k tokens
sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True)
top_tokens = {int(token_str) for token_str, _ in sorted_tokens[:top_k]}
print(f"Top {top_k} most frequent {column_name} tokens: {sorted(top_tokens)}")
return top_tokens
def get_data_tokens(data, col_idx):
"""
Get all unique tokens of a single sample in the given column.
Args:
data: numpy array (num_tokens, num_columns) or a file path
col_idx: column index
Returns:
set: set of unique tokens
"""
# If data is a path, load it
if isinstance(data, (str, Path)):
try:
data = np.load(data)['arr_0']
except Exception as e:
raise RuntimeError(f"Error loading file {data}: {e}")
if data.size == 0:
return set()
tokens = data[:, col_idx]
unique_tokens = set(int(token) for token in np.unique(tokens))
return unique_tokens
def should_sample_song(data, top_position_tokens, top_duration_tokens,
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
"""
Decide whether a song should be sampled.
Args:
data: numpy array (num_tokens, num_columns) or a file path
top_position_tokens: set of the top-3 most frequent position tokens
top_duration_tokens: set of the top-3 most frequent duration tokens
contain_sample_ratio: sampling probability for songs that contain the top-3 tokens
not_contain_sample_ratio: sampling probability (higher) for songs that do not
rng: random number generator; if None, the global np.random is used
Returns:
tuple: (should_sample, in_top3) - in_top3 means both position and duration overlap the top-3 sets
"""
# Collect the unique position and duration tokens
position_idx = COLUMN_NAMES.index("position")
duration_idx = COLUMN_NAMES.index("duration")
position_tokens = get_data_tokens(data, position_idx)
duration_tokens = get_data_tokens(data, duration_idx)
# Check overlap with the top-3 sets
position_in_top3 = bool(position_tokens & top_position_tokens)
duration_in_top3 = bool(duration_tokens & top_duration_tokens)
in_top3 = position_in_top3 and duration_in_top3
if rng is None:
rng = np.random
# Rule 1: if position or duration has no overlap with the top-3 tokens, sample with the higher probability
if not position_in_top3 or not duration_in_top3:
should_sample = rng.random() < not_contain_sample_ratio
return should_sample, False
# Rule 2: if both overlap the top-3 tokens, sample at the fixed (lower) percentage
should_sample = rng.random() < contain_sample_ratio
return should_sample, True
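# Illustration of the rule above (hypothetical top-3 sets and song): when both the position
# and the duration tokens overlap the top-3 sets, the song is kept with probability
# contain_sample_ratio; otherwise it is kept with probability not_contain_sample_ratio.
def _demo_sampling_rule():
    rng = np.random.RandomState(0)
    top_pos, top_dur = {0, 12, 24}, {24, 48, 96}
    song = np.array([[60, 0, 0, 90, 96, 0, 120, 16]])  # position 0 and duration 96 both overlap
    # Returns (should_sample, True): in_top3 is True, so the 0.1 branch decides should_sample.
    return should_sample_song(song, top_pos, top_dur, 0.1, 0.9, rng)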
def _load_single_file(npz_file):
"""
Helper that loads a single .npz file (used by the worker pool).
Args:
npz_file: path to the .npz file
Returns:
tuple: (data, file_path), or None if loading failed
"""
try:
data = np.load(npz_file)['arr_0']
if data.ndim == 2:
return (data, npz_file)
elif data.ndim == 1:
print(f"Warning: {npz_file} is a 1-D array, skipping")
return None
except Exception as e:
print(f"Error loading {npz_file}: {e}")
return None
def get_data_file_paths(data_dir):
"""
Collect all data file paths (without loading the data).
Args:
data_dir: path to the data directory
Returns:
list: list of file paths
"""
data_dir = Path(data_dir)
npz_files = []
if data_dir.exists():
npz_files = sorted(list(data_dir.rglob("*.npz")))
if not npz_files:
print(f"Warning: no .npz files found in {data_dir}")
return []
print(f"Found {len(npz_files)} .npz files")
return npz_files
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
num_threads=None, seed=42):
"""
Resample the songs according to the rules described in the module docstring.
Args:
file_paths: list of file paths
top_position_tokens: set of the top-3 most frequent position tokens
top_duration_tokens: set of the top-3 most frequent duration tokens
contain_sample_ratio: sampling probability for songs that contain the top-3 tokens
not_contain_sample_ratio: sampling probability (higher) for songs that do not
num_threads: number of threads; None means use the CPU core count
seed: random seed
Returns:
tuple: (sampled_paths, sampled_indices, stats)
"""
import threading
# Give each worker thread its own random number generator
thread_local = threading.local()
def get_thread_rng():
"""Return the current thread's random number generator."""
if not hasattr(thread_local, 'rng'):
# Seed an independent generator from the base seed and the thread id
thread_id = threading.current_thread().ident
thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
return thread_local.rng
if num_threads is None:
num_threads = min(mp.cpu_count(), len(file_paths))
print(f"\nStarting resampling with {num_threads} threads...")
print(f" Sampling ratio for songs containing the top-3 tokens: {contain_sample_ratio*100:.1f}%")
print(f" Sampling ratio for songs not containing the top-3 tokens: {not_contain_sample_ratio*100:.1f}%")
def _should_sample_wrapper(args_tuple):
"""Wrapper around the per-song sampling decision (for the worker pool)."""
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
try:
# Use the thread-local random number generator
thread_rng = get_thread_rng()
should_sample, in_top3 = should_sample_song(
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
)
return (file_path, should_sample, in_top3, None)
except Exception as e:
return (file_path, False, False, str(e))
# Prepare the arguments
sample_args = [(file_path, top_position_tokens, top_duration_tokens,
contain_sample_ratio, not_contain_sample_ratio)
for file_path in file_paths]
# Decide in parallel whether each song should be sampled
sampled_paths = []
sampled_indices = []
stats = {
'not_in_top3_count': 0, # songs not in the top 3
'not_in_top3_sampled': 0, # songs not in the top 3 that were sampled
'in_top3_count': 0, # songs in the top 3
'in_top3_sampled': 0 # songs in the top 3 that were sampled
}
# Limit the number of in-flight tasks so that not everything is submitted at once
batch_size = min(1000, len(file_paths))
results = {}
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit the tasks in batches
for batch_start in range(0, len(sample_args), batch_size):
batch_end = min(batch_start + batch_size, len(sample_args))
batch_args = sample_args[batch_start:batch_end]
futures = {executor.submit(_should_sample_wrapper, args): args[0]
for args in batch_args}
# Collect the results
for future in tqdm(as_completed(futures), total=len(futures),
desc=f"Sampling decision [{batch_start+1}-{batch_end}/{len(file_paths)}]",
leave=False):
try:
file_path, should_sample, in_top3, error = future.result(timeout=60)
results[file_path] = (should_sample, in_top3, error)
if error:
print(f"Warning: error while processing {file_path}: {error}")
except Exception as e:
print(f"Error retrieving result: {e}")
# Walk the results in the original order and collect statistics
for idx, file_path in enumerate(file_paths):
if file_path not in results:
continue
should_sample, in_top3, error = results[file_path]
if error:
continue
# Update the statistics
if in_top3:
stats['in_top3_count'] += 1
if should_sample:
stats['in_top3_sampled'] += 1
else:
stats['not_in_top3_count'] += 1
if should_sample:
stats['not_in_top3_sampled'] += 1
if should_sample:
sampled_paths.append(file_path)
sampled_indices.append(idx)
print("\nSampling finished:")
print(f" Total songs: {len(file_paths)}")
print(f" Sampled songs: {len(sampled_paths)}")
print(f" Sampling ratio: {len(sampled_paths)/len(file_paths)*100:.2f}%")
print(f" Songs not in the top 3: {stats['not_in_top3_count']}")
print(f" Songs not in the top 3 that were sampled: {stats['not_in_top3_sampled']}")
if stats['not_in_top3_count'] > 0:
print(f" Sampling ratio among songs not in the top 3: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
print(f" Songs in the top 3: {stats['in_top3_count']}")
print(f" Songs in the top 3 that were sampled: {stats['in_top3_sampled']}")
if stats['in_top3_count'] > 0:
print(f" Sampling ratio among songs in the top 3: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")
return sampled_paths, sampled_indices, stats
def _save_single_file(args_tuple):
"""
Helper that saves a single file (used by the worker pool).
Supports lazy loading: if data is a path, the array is first loaded from that file.
Args:
args_tuple: (data, original_path, output_dir)
Returns:
tuple: (success, original_path) or (False, original_path, error_msg)
"""
data, original_path, output_dir = args_tuple
try:
# If data is a path, load it
if isinstance(data, (str, Path)):
data = np.load(data)['arr_0']
# Preserve the relative path structure
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
output_path = output_dir / relative_path
output_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(output_path, data)
return (True, original_path)
except Exception as e:
error_msg = str(e)
print(f"Error saving {original_path}: {error_msg}")
return (False, original_path, error_msg)
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
"""
Save the sampled data (multi-threaded).
Args:
sampled_data: list of sampled data arrays (or file paths, when lazy)
sampled_paths: list of sampled file paths
output_dir: output directory
num_threads: number of threads; None means use the CPU core count
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\nSaving sampled data to: {output_dir}")
if num_threads is None:
num_threads = min(mp.cpu_count(), len(sampled_data))
# Prepare the arguments
save_args = [(data, original_path, output_dir)
for data, original_path in zip(sampled_data, sampled_paths)]
# Save the files with a thread pool
success_count = 0
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Submit all tasks
futures = [executor.submit(_save_single_file, args)
for args in save_args]
# Collect the results
for future in tqdm(as_completed(futures), total=len(futures), desc="Saving data"):
try:
result = future.result(timeout=300) # timeout so a stuck worker cannot hang the run
if isinstance(result, tuple) and len(result) >= 2:
success = result[0]
if success:
success_count += 1
except Exception as e:
print(f"Error retrieving save result: {e}")
print(f"Saving finished: {success_count}/{len(sampled_data)} files saved")
def main():
"""Main entry point."""
import argparse
parser = argparse.ArgumentParser(description="Resampling based on position and duration tokens (V2)")
parser.add_argument("--json_path", type=str,
default="octuple_token_analysis_report.json",
help="path to the token analysis report JSON file")
parser.add_argument("--data_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
help="path to the data directory")
parser.add_argument("--output_dir", type=str,
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
help="path to the output directory")
parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
help="sampling ratio for songs containing the top-3 tokens (default: 0.1)")
parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
help="sampling ratio for songs not containing the top-3 tokens (default: 0.9)")
parser.add_argument("--top_k", type=int, default=3,
help="use the top k most frequent tokens (default: 3)")
parser.add_argument("--seed", type=int, default=42,
help="random seed")
parser.add_argument("--num_threads", type=int, default=None,
help="number of threads; None means use the CPU core count (default: None)")
args = parser.parse_args()
# 1. Load the top_k most frequent position and duration tokens
top_position_tokens = load_top_tokens_from_json(
args.json_path, "position", top_k=args.top_k
)
top_duration_tokens = load_top_tokens_from_json(
args.json_path, "duration", top_k=args.top_k
)
if not top_position_tokens or not top_duration_tokens:
print("Error: could not load the top_k tokens")
return
# 2. Collect all data file paths
file_paths = get_data_file_paths(args.data_dir)
if not file_paths:
print("Error: no data files found")
return
print(f"\nFound {len(file_paths)} data files")
# 3. Resample according to the rules above
sampled_paths, sampled_indices, stats = resample_songs(
file_paths,
top_position_tokens,
top_duration_tokens,
contain_sample_ratio=args.contain_sample_ratio,
not_contain_sample_ratio=args.not_contain_sample_ratio,
num_threads=args.num_threads,
seed=args.seed
)
# 4. Save the sampled results (lazy loading)
sampled_data = sampled_paths # use the paths; load lazily during saving
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
# 5. Save the sampled indices (optional, for later analysis)
indices_file = Path(args.output_dir) / "sampled_indices.npy"
np.save(indices_file, np.array(sampled_indices))
print(f"\nSampled indices saved to: {indices_file}")
# 6. Save the sampling metadata
info = {
"total_samples": len(file_paths),
"sampled_samples": len(sampled_paths),
"contain_sample_ratio": args.contain_sample_ratio,
"not_contain_sample_ratio": args.not_contain_sample_ratio,
"top_k": args.top_k,
"top_position_tokens": sorted(list(top_position_tokens)),
"top_duration_tokens": sorted(list(top_duration_tokens)),
"stats": stats
}
info_file = Path(args.output_dir) / "resample_info.json"
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(info, f, indent=2, ensure_ascii=False)
print(f"Sampling info saved to: {info_file}")
if __name__ == "__main__":
main()

View File

@@ -1,14 +1,282 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Count how often each token appears in every column of the octuple tokenization results and generate an analysis report.
"""
import os
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import json
# Read an npz file
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
# Column-name definitions for Octuple
COLUMN_NAMES = [
"pitch", # 0: Pitch/PitchDrum
"position", # 1: Position
"bar", # 2: Bar
"velocity", # 3: Velocity
"duration", # 4: Duration
"program", # 5: Program
"tempo", # 6: Tempo
"timesig" # 7: TimeSignature
]
# Inspect the stored keys
print(data.files)
# Output: ['filename', 'sequence']
# Access the data
sequence = data["arr_0"]
def load_octuple_data(data_dir):
"""
Load all octuple-tokenized .npz files.
Args:
data_dir: data directory; either a single directory or a root containing several subdirectories
Returns:
list: list of all loaded numpy arrays
"""
data_dir = Path(data_dir)
npz_files = []
# If the directory exists, find all .npz files
if data_dir.exists():
npz_files = list(data_dir.rglob("*.npz"))
if not npz_files:
print(f"Warning: no .npz files found in {data_dir}")
return []
print(f"Found {len(npz_files)} .npz files, loading...")
all_data = []
for npz_file in tqdm(npz_files, desc="Loading data"):
try:
data = np.load(npz_file)['arr_0']
# Make sure the data is a 2-D array (num_tokens, num_columns)
if data.ndim == 2:
all_data.append(data)
elif data.ndim == 1:
# A 1-D array might need reshaping, but octuple data should be 2-D
print(f"Warning: {npz_file} is a 1-D array, skipping")
except Exception as e:
print(f"Error loading {npz_file}: {e}")
continue
return all_data
def count_tokens_by_column(all_data):
"""
Count the occurrences of each token in every column.
Args:
all_data: list of all data; each element is a numpy array (num_tokens, num_columns)
Returns:
dict: {column_index: Counter({token_value: count})}
"""
column_counters = defaultdict(Counter)
print("Counting token occurrences...")
for data in tqdm(all_data, desc="Processing data"):
if data.size == 0:
continue
num_columns = data.shape[1] if data.ndim == 2 else 1
for col_idx in range(num_columns):
if data.ndim == 2:
tokens = data[:, col_idx]
else:
tokens = data
# Count each token's occurrences in this column
unique, counts = np.unique(tokens, return_counts=True)
for token, count in zip(unique, counts):
column_counters[col_idx][int(token)] += int(count)
return dict(column_counters)
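# Small illustration (hypothetical token ids): two octuple rows yield one Counter per column;
# for column 0 (pitch) below, token 60 is counted twice.
def _demo_count_tokens():
    data = np.array([
        [60, 0, 0, 90, 96, 0, 120, 16],
        [60, 48, 0, 88, 96, 0, 120, 16],
    ])
    counters = count_tokens_by_column([data])
    return counters[0][60]  # -> 2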
def generate_report(column_counters, output_file=None):
"""
Generate the analysis report.
Args:
column_counters: dict of counting results
output_file: output file path (optional)
"""
report_lines = []
report_lines.append("=" * 80)
report_lines.append("OCTUPLE TOKENIZATION STATISTICS REPORT")
report_lines.append("=" * 80)
report_lines.append("")
# Overall statistics
total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
report_lines.append(f"Total tokens: {total_tokens:,}")
report_lines.append(f"Columns analyzed: {len(column_counters)}")
report_lines.append("")
# Detailed per-column statistics
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
report_lines.append("-" * 80)
report_lines.append(f"{col_idx}: {col_name}")
report_lines.append("-" * 80)
total_in_column = sum(counter.values())
unique_tokens = len(counter)
min_token = min(counter.keys()) if counter else 0
max_token = max(counter.keys()) if counter else 0
report_lines.append(f" Total tokens: {total_in_column:,}")
report_lines.append(f" Unique tokens: {unique_tokens:,}")
report_lines.append(f" Token value range: [{min_token}, {max_token}]")
report_lines.append(f" Average occurrences: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else " Average occurrences: N/A")
report_lines.append("")
# Top 20 most frequent tokens
report_lines.append(" Top 20 most frequent tokens:")
top_tokens = counter.most_common(20)
for rank, (token, count) in enumerate(top_tokens, 1):
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
report_lines.append("")
# Top 20 least frequent tokens (with count > 0)
report_lines.append(" Top 20 least frequent tokens (count > 0):")
bottom_tokens = counter.most_common()[-20:]
bottom_tokens.reverse()
for rank, (token, count) in enumerate(bottom_tokens, 1):
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} times ({percentage:5.2f}%)")
report_lines.append("")
# Distribution statistics
counts_list = list(counter.values())
if counts_list:
report_lines.append(" Distribution statistics:")
report_lines.append(f" Min occurrences: {min(counts_list):,}")
report_lines.append(f" Max occurrences: {max(counts_list):,}")
report_lines.append(f" Median occurrences: {np.median(counts_list):,.0f}")
report_lines.append(f" Std deviation: {np.std(counts_list):,.2f}")
report_lines.append("")
# Cross-column analysis
report_lines.append("=" * 80)
report_lines.append("CROSS-COLUMN ANALYSIS")
report_lines.append("=" * 80)
report_lines.append("")
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
total_in_column = sum(counter.values())
percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
report_lines.append(f" {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
report_lines.append("")
report_lines.append("=" * 80)
report_lines.append("REPORT COMPLETE")
report_lines.append("=" * 80)
# Print the report
report_text = "\n".join(report_lines)
print("\n" + report_text)
# Save to file
if output_file:
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(report_text)
print(f"\nReport saved to: {output_path}")
# Also save the detailed data in JSON format
if output_file:
json_output = output_path.with_suffix('.json')
json_data = {
'summary': {
'total_tokens': total_tokens,
'num_columns': len(column_counters)
},
'columns': {}
}
for col_idx in sorted(column_counters.keys()):
counter = column_counters[col_idx]
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
json_data['columns'][col_name] = {
'total_tokens': sum(counter.values()),
'unique_tokens': len(counter),
'token_counts': dict(counter),
'top_20': dict(counter.most_common(20)),
'bottom_20': dict(counter.most_common()[-20:])
}
with open(json_output, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=2, ensure_ascii=False)
print(f"Detailed data saved to: {json_output}")
def main():
"""Main entry point."""
# Default data directory - change as needed
default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"
# A specific data directory can be given, for example:
data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
# Or scan all oct8 directories under the default root:
# import sys
# if len(sys.argv) > 1:
# data_dir = sys.argv[1]
# else:
# # Automatically find all oct8 directories
# base_dir = Path(default_data_dir)
# oct8_dirs = list(base_dir.rglob("oct8"))
# if oct8_dirs:
# print(f"Found the following oct8 directories:")
# for i, d in enumerate(oct8_dirs, 1):
# print(f" {i}. {d}")
# if len(oct8_dirs) == 1:
# data_dir = str(oct8_dirs[0])
# print(f"\nUsing directory: {data_dir}")
# else:
# # Use the first directory found, or merge all of them
# print(f"\nUsing the first directory: {oct8_dirs[0]}")
# print("To analyze another directory, pass its path as an argument")
# data_dir = str(oct8_dirs[0])
# else:
# data_dir = default_data_dir
# print(f"No oct8 directory found, using the default: {data_dir}")
# Load the data
all_data = load_octuple_data(data_dir)
if not all_data:
print("Error: no data was loaded")
return
# Check the data format
if all_data:
sample = all_data[0]
print("\nData format check:")
print(f" Sample shape: {sample.shape}")
print(f" Sample dtype: {sample.dtype}")
print(f" Number of columns: {sample.shape[1] if sample.ndim == 2 else 1}")
print()
# Count token occurrences
column_counters = count_tokens_by_column(all_data)
# Generate the report
output_file = "octuple_token_analysis_report_part.txt"
generate_report(column_counters, output_file)
if __name__ == "__main__":
main()
print("Token sequence length:", len(sequence))
print("First 20 tokens:", sequence[:20])