1127 update to latest
This commit is contained in:
@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
from itertools import groupby, chain
|
||||
from random import shuffle, seed
|
||||
|
||||
|
||||
lock = RLock()
|
||||
|
||||
def shuffled(seq):
|
||||
shuffle(seq)
|
||||
return seq
|
||||
|
||||
def permute_inside_and_across_tracks(seq):
|
||||
seq_sorted = sorted(seq, key=lambda x: x[5])
|
||||
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
|
||||
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
|
||||
|
||||
|
||||
def convert_event_dicts(dict_list):
|
||||
"""
|
||||
将 event 词表列表按顺序转换为结构化输出
|
||||
@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
|
||||
|
||||
return result
|
||||
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
|
||||
try:
|
||||
score = Score(midi_path, ttype="tick")
|
||||
|
||||
@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
# 分词
|
||||
tok_seq = tokenizer(score)
|
||||
token_ids = tok_seq.ids
|
||||
if whether_shuffle:
|
||||
token_ids = permute_inside_and_across_tracks(token_ids)
|
||||
# add sos token at the beginning
|
||||
vocab = tokenizer.vocab
|
||||
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
|
||||
@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
print(f" × 处理文件时出错:{midi_path} -> {e}")
|
||||
|
||||
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
|
||||
# === 1. 初始化分词器并保存词表 ===
|
||||
print("初始化分词器 Octuple...")
|
||||
config = TokenizerConfig(
|
||||
@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# === 3. 收集 MIDI 文件 ===
|
||||
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
midi_paths = list(midi_paths)
|
||||
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
# midi_paths = list(midi_paths)
|
||||
midi_paths = []
|
||||
for root, _, files in os.walk(midi_dir):
|
||||
for file in files:
|
||||
if file.lower().endswith(('.mid', '.midi')):
|
||||
midi_paths.append(os.path.join(root, file))
|
||||
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
|
||||
|
||||
# === 4. 并行处理 ===
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
res = future.result()
|
||||
@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
|
||||
|
||||
parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
|
||||
parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录")
|
||||
parser.add_argument("--shuffle", action="store_true", help="是否在处理前打乱文件顺序")
|
||||
|
||||
dataset_name = midi_directory.split("/")[-1]
|
||||
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
|
||||
|
||||
|
||||
output_dir = tuneidx_prefix
|
||||
preprocess_midi_directory(midi_directory, output_dir)
|
||||
args = parser.parse_args()
|
||||
preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)
|
||||
39
data_representation/permute.py
Normal file
39
data_representation/permute.py
Normal file
@ -0,0 +1,39 @@
|
||||
from itertools import groupby, chain
|
||||
from random import shuffle, seed
|
||||
|
||||
def shuffled(seq):
|
||||
shuffle(seq)
|
||||
return seq
|
||||
|
||||
def permute_inside_and_across_tracks(seq):
|
||||
seq_sorted = sorted(seq, key=lambda x: x[5])
|
||||
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
|
||||
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
|
||||
|
||||
# (PitchDrum, Position, Bar, Velocity, Duration, Program, Tempo, TimeSignature)
|
||||
seq = [
|
||||
# Program 0
|
||||
(60, 0, 0, 90, 96, 0, 120, 16),
|
||||
(64, 48,0, 88, 96, 0, 120, 16),
|
||||
(67, 96,0, 92, 96, 0, 120, 16),
|
||||
|
||||
# Program 32
|
||||
(40, 0, 0, 80, 192, 32, 120, 16),
|
||||
(43, 0, 1, 78, 192, 32, 120, 16),
|
||||
|
||||
# Program 40
|
||||
(72, 24,0, 85, 72, 40, 120, 16),
|
||||
(74, 72,0, 83, 72, 40, 120, 16),
|
||||
(76, 24,1, 86, 72, 40, 120, 16),
|
||||
]
|
||||
|
||||
# seed(42)
|
||||
|
||||
inside_track_permuted_and_track_permuted = permute_inside_and_across_tracks(seq)
|
||||
|
||||
print("原始 seq:")
|
||||
for e in seq:
|
||||
print(e)
|
||||
print("\n打乱后的 seq:")
|
||||
for e in inside_track_permuted_and_track_permuted:
|
||||
print(e)
|
||||
634
data_representation/resample.py
Normal file
634
data_representation/resample.py
Normal file
@ -0,0 +1,634 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于token分布距离的重采样脚本
|
||||
读取octuple_token_analysis_report.json,计算每个数据与整体分布的距离,
|
||||
按照距离加权采样,距离越远的越容易被采样
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from scipy.stats import entropy, wasserstein_distance
|
||||
from scipy.spatial.distance import jensenshannon
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
|
||||
def load_distribution_from_json(json_path):
|
||||
"""
|
||||
从JSON文件中加载整体token分布
|
||||
|
||||
Args:
|
||||
json_path: JSON文件路径
|
||||
|
||||
Returns:
|
||||
dict: {column_name: {token: probability}}
|
||||
"""
|
||||
print(f"读取分布文件: {json_path}")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
report = json.load(f)
|
||||
|
||||
distributions = {}
|
||||
columns = report.get('columns', {})
|
||||
|
||||
for col_name in COLUMN_NAMES:
|
||||
if col_name not in columns:
|
||||
print(f"警告: 列 {col_name} 不在报告中")
|
||||
distributions[col_name] = {}
|
||||
continue
|
||||
|
||||
col_data = columns[col_name]
|
||||
token_counts = col_data.get('token_counts', {})
|
||||
total_tokens = col_data.get('total_tokens', 1)
|
||||
|
||||
# 转换为概率分布
|
||||
distribution = {}
|
||||
for token_str, count in token_counts.items():
|
||||
token = int(token_str)
|
||||
distribution[token] = count / total_tokens
|
||||
|
||||
distributions[col_name] = distribution
|
||||
print(f" 列 {col_name}: {len(distribution)} 个唯一token, 总token数: {total_tokens:,}")
|
||||
|
||||
return distributions
|
||||
|
||||
|
||||
def compute_data_distribution(data, col_idx):
|
||||
"""
|
||||
计算单个数据在指定列的token分布
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns)
|
||||
col_idx: 列索引
|
||||
|
||||
Returns:
|
||||
dict: {token: probability}
|
||||
"""
|
||||
if data.size == 0:
|
||||
return {}
|
||||
|
||||
tokens = data[:, col_idx]
|
||||
unique, counts = np.unique(tokens, return_counts=True)
|
||||
total = len(tokens)
|
||||
|
||||
distribution = {}
|
||||
for token, count in zip(unique, counts):
|
||||
distribution[int(token)] = count / total
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def compute_emd_distance(dist1, dist2):
|
||||
"""
|
||||
使用推土机距离(Earth Mover's Distance / Wasserstein距离)计算两个分布之间的距离
|
||||
|
||||
Args:
|
||||
dist1: 分布1,dict {token: probability},已归一化
|
||||
dist2: 分布2,dict {token: probability},已归一化
|
||||
|
||||
Returns:
|
||||
float: EMD距离
|
||||
"""
|
||||
# 获取所有token的并集,并排序
|
||||
all_tokens = sorted(set(dist1.keys()) | set(dist2.keys()))
|
||||
|
||||
if not all_tokens:
|
||||
return 0.0
|
||||
|
||||
# 构建概率向量和token值向量
|
||||
p_weights = np.array([dist1.get(token, 0.0) for token in all_tokens])
|
||||
q_weights = np.array([dist2.get(token, 0.0) for token in all_tokens])
|
||||
token_values = np.array(all_tokens, dtype=float)
|
||||
|
||||
# 归一化(处理数值误差)
|
||||
p_sum = p_weights.sum()
|
||||
q_sum = q_weights.sum()
|
||||
|
||||
if p_sum < 1e-10 or q_sum < 1e-10:
|
||||
return 0.0
|
||||
|
||||
p_weights = p_weights / p_sum
|
||||
q_weights = q_weights / q_sum
|
||||
|
||||
# 使用Wasserstein距离(1-Wasserstein距离,即推土机距离)
|
||||
# wasserstein_distance需要两个分布的样本值位置和权重
|
||||
# 对于离散分布,我们使用token值作为位置
|
||||
emd = wasserstein_distance(token_values, token_values, p_weights, q_weights)
|
||||
|
||||
return emd
|
||||
|
||||
|
||||
def compute_distribution_distance(dist1, dist2, method='emd'):
|
||||
"""
|
||||
计算两个分布之间的距离
|
||||
|
||||
Args:
|
||||
dist1: 分布1,dict {token: probability}
|
||||
dist2: 分布2,dict {token: probability}
|
||||
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
|
||||
|
||||
Returns:
|
||||
float: 分布距离
|
||||
"""
|
||||
if method == 'emd':
|
||||
return compute_emd_distance(dist1, dist2)
|
||||
|
||||
# 获取所有token的并集
|
||||
all_tokens = set(dist1.keys()) | set(dist2.keys())
|
||||
|
||||
if not all_tokens:
|
||||
return 0.0
|
||||
|
||||
# 构建概率向量
|
||||
p = np.array([dist1.get(token, 0.0) for token in all_tokens])
|
||||
q = np.array([dist2.get(token, 0.0) for token in all_tokens])
|
||||
|
||||
# 归一化(处理数值误差)
|
||||
p = p / (p.sum() + 1e-10)
|
||||
q = q / (q.sum() + 1e-10)
|
||||
|
||||
if method == 'js':
|
||||
# Jensen-Shannon散度(对称,范围[0, 1])
|
||||
return jensenshannon(p, q)
|
||||
elif method == 'kl':
|
||||
# KL散度(非对称,需要处理零值)
|
||||
# 添加小的平滑项避免log(0)
|
||||
p = p + 1e-10
|
||||
q = q + 1e-10
|
||||
p = p / p.sum()
|
||||
q = q / q.sum()
|
||||
return entropy(p, q)
|
||||
else:
|
||||
raise ValueError(f"未知的距离方法: {method}")
|
||||
|
||||
|
||||
def extract_subdistribution(global_dist, data_tokens):
|
||||
"""
|
||||
从全局分布中提取只包含数据中出现的token的子分布,并归一化
|
||||
|
||||
Args:
|
||||
global_dist: 全局分布,dict {token: probability}
|
||||
data_tokens: 数据中出现的token集合,set或list
|
||||
|
||||
Returns:
|
||||
dict: 子分布,dict {token: probability},已归一化
|
||||
"""
|
||||
if not data_tokens or not global_dist:
|
||||
return {}
|
||||
|
||||
# 提取子分布
|
||||
sub_dist = {token: global_dist.get(token, 0.0) for token in data_tokens}
|
||||
|
||||
# 归一化
|
||||
total = sum(sub_dist.values())
|
||||
if total < 1e-10:
|
||||
return {}
|
||||
|
||||
normalized_sub_dist = {token: prob / total for token, prob in sub_dist.items()}
|
||||
|
||||
return normalized_sub_dist
|
||||
|
||||
|
||||
def compute_data_distance(data, global_distributions, method='emd'):
|
||||
"""
|
||||
计算单个数据与整体分布的距离
|
||||
对每首歌,从数据集分布中找出和这首歌的分布包含的数据相同的子分布,
|
||||
都进行归一化然后计算推土机距离
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径(如果是延迟加载)
|
||||
global_distributions: 整体分布,dict {column_name: {token: probability}}
|
||||
method: 距离计算方法,'emd' (推土机距离), 'js' (Jensen-Shannon) 或 'kl' (KL散度)
|
||||
|
||||
Returns:
|
||||
float: 平均距离(跨所有列)
|
||||
"""
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
try:
|
||||
data = np.load(data)['arr_0']
|
||||
except Exception as e:
|
||||
# 不打印错误,让调用者处理
|
||||
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
|
||||
|
||||
distances = []
|
||||
|
||||
for col_idx, col_name in enumerate(COLUMN_NAMES):
|
||||
# 计算该数据在该列的分布
|
||||
data_dist = compute_data_distribution(data, col_idx)
|
||||
|
||||
# 获取整体分布
|
||||
global_dist = global_distributions.get(col_name, {})
|
||||
|
||||
if not data_dist or not global_dist:
|
||||
continue
|
||||
|
||||
# 从全局分布中提取只包含数据中出现的token的子分布
|
||||
data_tokens = set(data_dist.keys())
|
||||
sub_global_dist = extract_subdistribution(global_dist, data_tokens)
|
||||
|
||||
if not sub_global_dist:
|
||||
continue
|
||||
|
||||
# 归一化数据分布
|
||||
data_dist_sum = sum(data_dist.values())
|
||||
if data_dist_sum < 1e-10:
|
||||
continue
|
||||
normalized_data_dist = {token: prob / data_dist_sum
|
||||
for token, prob in data_dist.items()}
|
||||
|
||||
# 计算距离(两个分布都已归一化)
|
||||
dist = compute_distribution_distance(normalized_data_dist, sub_global_dist, method=method)
|
||||
distances.append(dist)
|
||||
|
||||
# 返回平均距离
|
||||
return np.mean(distances) if distances else 0.0
|
||||
|
||||
|
||||
def _load_single_file(npz_file):
|
||||
"""
|
||||
加载单个npz文件的辅助函数(用于多线程)
|
||||
|
||||
Args:
|
||||
npz_file: npz文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (data, file_path) 或 None(如果加载失败)
|
||||
"""
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
if data.ndim == 2:
|
||||
return (data, npz_file)
|
||||
elif data.ndim == 1:
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_data_file_paths(data_dir):
|
||||
"""
|
||||
获取所有数据文件路径(不加载数据)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
|
||||
Returns:
|
||||
list: 文件路径列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件")
|
||||
return npz_files
|
||||
|
||||
|
||||
def load_data_with_paths(data_dir, num_threads=None, lazy=False):
|
||||
"""
|
||||
加载所有数据并返回数据路径列表(多线程版本)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
lazy: 如果为True,只返回文件路径,不加载数据
|
||||
|
||||
Returns:
|
||||
tuple: (data_list, file_paths_list) 或 (None, file_paths_list) 如果lazy=True
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return [], []
|
||||
|
||||
if lazy:
|
||||
print(f"找到 {len(npz_files)} 个.npz文件(延迟加载模式)")
|
||||
return None, npz_files
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件,开始加载...")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(npz_files))
|
||||
|
||||
all_data = []
|
||||
file_paths = []
|
||||
|
||||
# 使用多线程加载文件
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(_load_single_file, npz_file): npz_file
|
||||
for npz_file in npz_files}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="加载数据"):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
data, file_path = result
|
||||
all_data.append(data)
|
||||
file_paths.append(file_path)
|
||||
|
||||
# 保持原始顺序
|
||||
if file_paths:
|
||||
sorted_pairs = sorted(zip(file_paths, all_data), key=lambda x: str(x[0]))
|
||||
file_paths, all_data = zip(*sorted_pairs)
|
||||
file_paths = list(file_paths)
|
||||
all_data = list(all_data)
|
||||
|
||||
return all_data, file_paths
|
||||
|
||||
|
||||
def weighted_resample(file_paths, distances, sample_ratio=0.3, method='js', lazy=True):
|
||||
"""
|
||||
根据距离进行加权重采样
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
distances: 距离列表
|
||||
sample_ratio: 采样比例
|
||||
method: 距离计算方法(用于确定权重方向)
|
||||
lazy: 如果为True,返回文件路径而不是数据
|
||||
|
||||
Returns:
|
||||
tuple: (sampled_data_or_paths, sampled_paths, sampled_indices)
|
||||
"""
|
||||
n_samples = int(len(file_paths) * sample_ratio)
|
||||
print(f"\n准备采样 {n_samples} 个数据 (占总数的 {sample_ratio*100:.1f}%)")
|
||||
|
||||
# 将距离转换为权重
|
||||
# 距离越远,权重越大
|
||||
distances = np.array(distances)
|
||||
|
||||
# 处理零距离或异常值
|
||||
min_dist = np.min(distances[distances > 0]) if np.any(distances > 0) else 1e-10
|
||||
distances = np.maximum(distances, min_dist * 0.1)
|
||||
|
||||
# 归一化距离到[0, 1],然后转换为权重
|
||||
# 使用指数函数增强距离差异
|
||||
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min() + 1e-10)
|
||||
weights = np.exp(normalized_distances * 3) # 指数放大,使距离远的更容易被采样
|
||||
|
||||
# 归一化权重
|
||||
weights = weights / weights.sum()
|
||||
|
||||
# 加权随机采样
|
||||
indices = np.arange(len(file_paths))
|
||||
sampled_indices = np.random.choice(indices, size=n_samples, replace=False, p=weights)
|
||||
|
||||
sampled_paths = [file_paths[i] for i in sampled_indices]
|
||||
|
||||
# 如果lazy=True,返回路径;否则加载数据
|
||||
if lazy:
|
||||
sampled_data = sampled_paths # 返回路径,延迟加载
|
||||
else:
|
||||
# 加载采样后的数据
|
||||
sampled_data = []
|
||||
for path in tqdm(sampled_paths, desc="加载采样数据"):
|
||||
try:
|
||||
data = np.load(path)['arr_0']
|
||||
sampled_data.append(data)
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {path} 时出错: {e}")
|
||||
sampled_data.append(None)
|
||||
|
||||
print(f"采样完成:")
|
||||
print(f" 采样数据数量: {len(sampled_paths)}")
|
||||
print(f" 平均距离: {distances[sampled_indices].mean():.6f}")
|
||||
print(f" 最小距离: {distances[sampled_indices].min():.6f}")
|
||||
print(f" 最大距离: {distances[sampled_indices].max():.6f}")
|
||||
|
||||
return sampled_data, sampled_paths, sampled_indices
|
||||
|
||||
|
||||
def _save_single_file(args_tuple):
|
||||
"""
|
||||
保存单个文件的辅助函数(用于多线程)
|
||||
支持延迟加载:如果data是路径,则从文件加载
|
||||
|
||||
Args:
|
||||
args_tuple: (data, original_path, output_dir)
|
||||
|
||||
Returns:
|
||||
tuple: (success, original_path) 或 (False, original_path, error_msg)
|
||||
"""
|
||||
data, original_path, output_dir = args_tuple
|
||||
try:
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
data = np.load(data)['arr_0']
|
||||
|
||||
# 保持相对路径结构
|
||||
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
|
||||
output_path = output_dir / relative_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
np.savez_compressed(output_path, data)
|
||||
return (True, original_path)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
|
||||
return (False, original_path, error_msg)
|
||||
|
||||
|
||||
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
|
||||
"""
|
||||
保存采样后的数据(多线程版本)
|
||||
|
||||
Args:
|
||||
sampled_data: 采样后的数据列表
|
||||
sampled_paths: 采样后的文件路径列表
|
||||
output_dir: 输出目录
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n保存采样数据到: {output_dir}")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(sampled_data))
|
||||
|
||||
# 准备参数
|
||||
save_args = [(data, original_path, output_dir)
|
||||
for data, original_path in zip(sampled_data, sampled_paths)]
|
||||
|
||||
# 使用多线程保存文件
|
||||
success_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_save_single_file, args)
|
||||
for args in save_args]
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
|
||||
try:
|
||||
result = future.result(timeout=300) # 设置超时避免卡死
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
success = result[0]
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"错误: 获取保存结果时出错: {e}")
|
||||
|
||||
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="基于token分布距离的重采样")
|
||||
parser.add_argument("--json_path", type=str,
|
||||
default="octuple_token_analysis_report.json",
|
||||
help="token分析报告JSON文件路径")
|
||||
parser.add_argument("--data_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
|
||||
help="数据目录路径")
|
||||
parser.add_argument("--output_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled",
|
||||
help="输出目录路径")
|
||||
parser.add_argument("--sample_ratio", type=float, default=0.3,
|
||||
help="采样比例 (默认: 0.3)")
|
||||
parser.add_argument("--distance_method", type=str, default="emd",
|
||||
choices=["emd", "js", "kl"],
|
||||
help="距离计算方法: 'emd' (推土机距离/EMD), 'js' (Jensen-Shannon) 或 'kl' (KL散度)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子")
|
||||
parser.add_argument("--num_threads", type=int, default=1,
|
||||
help="线程数,None表示使用CPU核心数 (默认: None)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 设置随机种子
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# 1. 加载整体分布
|
||||
global_distributions = load_distribution_from_json(args.json_path)
|
||||
|
||||
# 2. 获取所有数据文件路径(延迟加载模式,避免一次性加载所有数据)
|
||||
_, file_paths = load_data_with_paths(args.data_dir, lazy=True)
|
||||
|
||||
if not file_paths:
|
||||
print("错误: 未找到任何数据文件")
|
||||
return
|
||||
|
||||
print(f"\n共找到 {len(file_paths)} 个数据文件")
|
||||
|
||||
# 3. 计算每个数据与整体分布的距离(多线程版本,延迟加载)
|
||||
print("\n计算每个数据与整体分布的距离(延迟加载模式)...")
|
||||
|
||||
def _compute_distance_wrapper(args_tuple):
|
||||
"""计算距离的包装函数(用于多线程,支持延迟加载)"""
|
||||
idx, file_path, global_dists, method = args_tuple
|
||||
try:
|
||||
distance = compute_data_distance(file_path, global_dists, method=method)
|
||||
return (idx, distance, None)
|
||||
except Exception as e:
|
||||
return (idx, 0.0, str(e))
|
||||
|
||||
if args.num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(file_paths))
|
||||
else:
|
||||
num_threads = args.num_threads
|
||||
|
||||
# 准备参数(使用文件路径而不是数据,包含索引)
|
||||
distance_args = [(i, file_path, global_distributions, args.distance_method)
|
||||
for i, file_path in enumerate(file_paths)]
|
||||
|
||||
# 使用多线程计算距离(按需加载数据)
|
||||
# 初始化结果列表,保持顺序
|
||||
distances = [0.0] * len(file_paths)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_compute_distance_wrapper, args)
|
||||
for args in distance_args]
|
||||
|
||||
# 收集结果,使用 tqdm 显示进度
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="计算距离"):
|
||||
try:
|
||||
idx, distance, error = future.result(timeout=300) # 设置超时避免卡死
|
||||
distances[idx] = distance
|
||||
if error:
|
||||
print(f"警告: 计算距离时出错 (索引 {idx}): {error}")
|
||||
except Exception as e:
|
||||
print(f"错误: 获取结果时出错: {e}")
|
||||
# 如果无法获取结果,保持默认值 0.0
|
||||
|
||||
distances = np.array(distances)
|
||||
print(f"\n距离统计:")
|
||||
print(f" 平均距离: {distances.mean():.6f}")
|
||||
print(f" 最小距离: {distances.min():.6f}")
|
||||
print(f" 最大距离: {distances.max():.6f}")
|
||||
print(f" 标准差: {distances.std():.6f}")
|
||||
|
||||
# 4. 根据距离进行加权采样(延迟加载模式)
|
||||
sampled_data, sampled_paths, sampled_indices = weighted_resample(
|
||||
file_paths, distances,
|
||||
sample_ratio=args.sample_ratio,
|
||||
method=args.distance_method,
|
||||
lazy=True # 使用延迟加载,避免重复加载数据
|
||||
)
|
||||
|
||||
# 5. 保存采样结果(多线程,延迟加载)
|
||||
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
|
||||
|
||||
# 6. 保存采样索引(可选,用于后续分析)
|
||||
indices_file = Path(args.output_dir) / "sampled_indices.npy"
|
||||
np.save(indices_file, sampled_indices)
|
||||
print(f"\n采样索引已保存到: {indices_file}")
|
||||
|
||||
# 保存采样信息
|
||||
info = {
|
||||
"total_samples": len(file_paths),
|
||||
"sampled_samples": len(sampled_data),
|
||||
"sample_ratio": args.sample_ratio,
|
||||
"distance_method": args.distance_method,
|
||||
"distance_stats": {
|
||||
"mean": float(distances.mean()),
|
||||
"min": float(distances.min()),
|
||||
"max": float(distances.max()),
|
||||
"std": float(distances.std())
|
||||
},
|
||||
"sampled_distance_stats": {
|
||||
"mean": float(distances[sampled_indices].mean()),
|
||||
"min": float(distances[sampled_indices].min()),
|
||||
"max": float(distances[sampled_indices].max()),
|
||||
"std": float(distances[sampled_indices].std())
|
||||
}
|
||||
}
|
||||
|
||||
info_file = Path(args.output_dir) / "resample_info.json"
|
||||
with open(info_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(info, f, indent=2, ensure_ascii=False)
|
||||
print(f"采样信息已保存到: {info_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
472
data_representation/resampleV2.py
Normal file
472
data_representation/resampleV2.py
Normal file
@ -0,0 +1,472 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于position和duration token的重采样脚本V2
|
||||
对于每首歌:
|
||||
1. 如果包含的position和duration不在总数据集前3个,则必定采样
|
||||
2. 对于包含的,以某个固定的百分比采样
|
||||
3. 两个条件满足一个即可
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
|
||||
def load_top_tokens_from_json(json_path, column_name, top_k=3):
|
||||
"""
|
||||
从JSON文件中加载指定列的前top_k个最常见的token
|
||||
|
||||
Args:
|
||||
json_path: JSON文件路径
|
||||
column_name: 列名(如'position'或'duration')
|
||||
top_k: 返回前k个最常见的token
|
||||
|
||||
Returns:
|
||||
set: 前top_k个最常见的token集合
|
||||
"""
|
||||
print(f"读取分布文件: {json_path}")
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
report = json.load(f)
|
||||
|
||||
columns = report.get('columns', {})
|
||||
|
||||
if column_name not in columns:
|
||||
print(f"警告: 列 {column_name} 不在报告中")
|
||||
return set()
|
||||
|
||||
col_data = columns[column_name]
|
||||
token_counts = col_data.get('token_counts', {})
|
||||
|
||||
# 按出现次数排序,获取前top_k个
|
||||
sorted_tokens = sorted(token_counts.items(), key=lambda x: int(x[1]), reverse=True)
|
||||
top_tokens = {int(token_str) for token_str, _ in sorted_tokens[:top_k]}
|
||||
|
||||
print(f" 列 {column_name} 的前{top_k}个最常见token: {sorted(top_tokens)}")
|
||||
|
||||
return top_tokens
|
||||
|
||||
|
||||
def get_data_tokens(data, col_idx):
|
||||
"""
|
||||
获取单个数据在指定列的所有唯一token
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径
|
||||
col_idx: 列索引
|
||||
|
||||
Returns:
|
||||
set: 唯一token集合
|
||||
"""
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
try:
|
||||
data = np.load(data)['arr_0']
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加载文件 {data} 时出错: {e}")
|
||||
|
||||
if data.size == 0:
|
||||
return set()
|
||||
|
||||
tokens = data[:, col_idx]
|
||||
unique_tokens = set(int(token) for token in np.unique(tokens))
|
||||
|
||||
return unique_tokens
|
||||
|
||||
|
||||
def should_sample_song(data, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9, rng=None):
|
||||
"""
|
||||
判断一首歌是否应该被采样
|
||||
|
||||
Args:
|
||||
data: numpy数组 (num_tokens, num_columns) 或文件路径
|
||||
top_position_tokens: position列的前3个最常见token集合
|
||||
top_duration_tokens: duration列的前3个最常见token集合
|
||||
contain_sample_ratio: 对于包含前3个token的歌曲,采样比例
|
||||
not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率)
|
||||
rng: 随机数生成器,如果为None则使用全局的np.random
|
||||
|
||||
Returns:
|
||||
tuple: (是否应该采样, 是否在前3个) - 在前3个指position和duration都在前3个
|
||||
"""
|
||||
# 获取position和duration列的唯一token
|
||||
position_idx = COLUMN_NAMES.index("position")
|
||||
duration_idx = COLUMN_NAMES.index("duration")
|
||||
|
||||
position_tokens = get_data_tokens(data, position_idx)
|
||||
duration_tokens = get_data_tokens(data, duration_idx)
|
||||
|
||||
# 判断是否在前3个
|
||||
position_in_top3 = bool(position_tokens & top_position_tokens)
|
||||
duration_in_top3 = bool(duration_tokens & top_duration_tokens)
|
||||
in_top3 = position_in_top3 and duration_in_top3
|
||||
|
||||
if rng is None:
|
||||
rng = np.random
|
||||
|
||||
# 条件1: 如果position或duration不包含前3个token,以更高概率采样
|
||||
if not position_in_top3 or not duration_in_top3:
|
||||
should_sample = rng.random() < not_contain_sample_ratio
|
||||
return should_sample, False
|
||||
|
||||
# 条件2: 如果包含前3个token,则以固定百分比采样
|
||||
should_sample = rng.random() < contain_sample_ratio
|
||||
return should_sample, True
|
||||
|
||||
|
||||
def _load_single_file(npz_file):
|
||||
"""
|
||||
加载单个npz文件的辅助函数(用于多线程)
|
||||
|
||||
Args:
|
||||
npz_file: npz文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (data, file_path) 或 None(如果加载失败)
|
||||
"""
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
if data.ndim == 2:
|
||||
return (data, npz_file)
|
||||
elif data.ndim == 1:
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_data_file_paths(data_dir):
|
||||
"""
|
||||
获取所有数据文件路径(不加载数据)
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径
|
||||
|
||||
Returns:
|
||||
list: 文件路径列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
if data_dir.exists():
|
||||
npz_files = sorted(list(data_dir.rglob("*.npz")))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件")
|
||||
return npz_files
|
||||
|
||||
|
||||
def resample_songs(file_paths, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio=0.3, not_contain_sample_ratio=0.9,
|
||||
num_threads=None, seed=42):
|
||||
"""
|
||||
根据新逻辑进行重采样
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
top_position_tokens: position列的前3个最常见token集合
|
||||
top_duration_tokens: duration列的前3个最常见token集合
|
||||
contain_sample_ratio: 对于包含前3个token的歌曲,采样比例
|
||||
not_contain_sample_ratio: 对于不包含前3个token的歌曲,采样比例(更高概率)
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
seed: 随机种子
|
||||
|
||||
Returns:
|
||||
tuple: (sampled_paths, sampled_indices, stats)
|
||||
"""
|
||||
import threading
|
||||
|
||||
# 为每个线程创建独立的随机数生成器
|
||||
thread_local = threading.local()
|
||||
|
||||
def get_thread_rng():
|
||||
"""获取当前线程的随机数生成器"""
|
||||
if not hasattr(thread_local, 'rng'):
|
||||
# 使用线程ID和种子创建独立的随机数生成器
|
||||
thread_id = threading.current_thread().ident
|
||||
thread_local.rng = np.random.RandomState(seed + hash(thread_id) % 1000000)
|
||||
return thread_local.rng
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(file_paths))
|
||||
|
||||
print(f"\n开始重采样,使用 {num_threads} 个线程...")
|
||||
print(f" 包含前3个token的采样比例: {contain_sample_ratio*100:.1f}%")
|
||||
print(f" 不包含前3个token的采样比例: {not_contain_sample_ratio*100:.1f}%")
|
||||
|
||||
def _should_sample_wrapper(args_tuple):
|
||||
"""判断是否采样的包装函数(用于多线程)"""
|
||||
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio = args_tuple
|
||||
try:
|
||||
# 使用线程本地的随机数生成器
|
||||
thread_rng = get_thread_rng()
|
||||
should_sample, in_top3 = should_sample_song(
|
||||
file_path, top_pos, top_dur, contain_ratio, not_contain_ratio, thread_rng
|
||||
)
|
||||
return (file_path, should_sample, in_top3, None)
|
||||
except Exception as e:
|
||||
return (file_path, False, False, str(e))
|
||||
|
||||
# 准备参数
|
||||
sample_args = [(file_path, top_position_tokens, top_duration_tokens,
|
||||
contain_sample_ratio, not_contain_sample_ratio)
|
||||
for file_path in file_paths]
|
||||
|
||||
# 使用多线程判断每首歌是否应该采样
|
||||
sampled_paths = []
|
||||
sampled_indices = []
|
||||
stats = {
|
||||
'not_in_top3_count': 0, # 不在前3个的歌曲数量
|
||||
'not_in_top3_sampled': 0, # 不在前3个且被采样的歌曲数量
|
||||
'in_top3_count': 0, # 在前3个的歌曲数量
|
||||
'in_top3_sampled': 0 # 在前3个且被采样的歌曲数量
|
||||
}
|
||||
|
||||
# 限制并发任务数量,避免一次性提交过多任务
|
||||
batch_size = min(1000, len(file_paths))
|
||||
results = {}
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 分批提交任务
|
||||
for batch_start in range(0, len(sample_args), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(sample_args))
|
||||
batch_args = sample_args[batch_start:batch_end]
|
||||
|
||||
futures = {executor.submit(_should_sample_wrapper, args): args[0]
|
||||
for args in batch_args}
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures),
|
||||
desc=f"判断采样 [{batch_start+1}-{batch_end}/{len(file_paths)}]",
|
||||
leave=False):
|
||||
try:
|
||||
file_path, should_sample, in_top3, error = future.result(timeout=60)
|
||||
results[file_path] = (should_sample, in_top3, error)
|
||||
if error:
|
||||
print(f"警告: 处理 {file_path} 时出错: {error}")
|
||||
except Exception as e:
|
||||
print(f"错误: 获取结果时出错: {e}")
|
||||
|
||||
# 按原始顺序处理结果,并统计
|
||||
for idx, file_path in enumerate(file_paths):
|
||||
if file_path not in results:
|
||||
continue
|
||||
|
||||
should_sample, in_top3, error = results[file_path]
|
||||
if error:
|
||||
continue
|
||||
|
||||
# 统计信息
|
||||
if in_top3:
|
||||
stats['in_top3_count'] += 1
|
||||
if should_sample:
|
||||
stats['in_top3_sampled'] += 1
|
||||
else:
|
||||
stats['not_in_top3_count'] += 1
|
||||
if should_sample:
|
||||
stats['not_in_top3_sampled'] += 1
|
||||
|
||||
if should_sample:
|
||||
sampled_paths.append(file_path)
|
||||
sampled_indices.append(idx)
|
||||
|
||||
print(f"\n采样完成:")
|
||||
print(f" 总歌曲数: {len(file_paths)}")
|
||||
print(f" 采样歌曲数: {len(sampled_paths)}")
|
||||
print(f" 采样比例: {len(sampled_paths)/len(file_paths)*100:.2f}%")
|
||||
print(f" 不在前3个的歌曲数: {stats['not_in_top3_count']}")
|
||||
print(f" 不在前3个且被采样的歌曲数: {stats['not_in_top3_sampled']}")
|
||||
if stats['not_in_top3_count'] > 0:
|
||||
print(f" 不在前3个的歌曲采样比例: {stats['not_in_top3_sampled']/stats['not_in_top3_count']*100:.2f}%")
|
||||
print(f" 在前3个的歌曲数: {stats['in_top3_count']}")
|
||||
print(f" 在前3个且被采样的歌曲数: {stats['in_top3_sampled']}")
|
||||
if stats['in_top3_count'] > 0:
|
||||
print(f" 在前3个的歌曲采样比例: {stats['in_top3_sampled']/stats['in_top3_count']*100:.2f}%")
|
||||
|
||||
return sampled_paths, sampled_indices, stats
|
||||
|
||||
|
||||
def _save_single_file(args_tuple):
|
||||
"""
|
||||
保存单个文件的辅助函数(用于多线程)
|
||||
支持延迟加载:如果data是路径,则从文件加载
|
||||
|
||||
Args:
|
||||
args_tuple: (data, original_path, output_dir)
|
||||
|
||||
Returns:
|
||||
tuple: (success, original_path) 或 (False, original_path, error_msg)
|
||||
"""
|
||||
data, original_path, output_dir = args_tuple
|
||||
try:
|
||||
# 如果data是路径,则加载它
|
||||
if isinstance(data, (str, Path)):
|
||||
data = np.load(data)['arr_0']
|
||||
|
||||
# 保持相对路径结构
|
||||
relative_path = original_path.relative_to(original_path.parents[len(original_path.parts) - 3])
|
||||
output_path = output_dir / relative_path
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
np.savez_compressed(output_path, data)
|
||||
return (True, original_path)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"错误: 保存 {original_path} 时出错: {error_msg}")
|
||||
return (False, original_path, error_msg)
|
||||
|
||||
|
||||
def save_sampled_data(sampled_data, sampled_paths, output_dir, num_threads=None):
|
||||
"""
|
||||
保存采样后的数据(多线程版本)
|
||||
|
||||
Args:
|
||||
sampled_data: 采样后的数据列表
|
||||
sampled_paths: 采样后的文件路径列表
|
||||
output_dir: 输出目录
|
||||
num_threads: 线程数,None表示使用CPU核心数
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n保存采样数据到: {output_dir}")
|
||||
|
||||
if num_threads is None:
|
||||
num_threads = min(mp.cpu_count(), len(sampled_data))
|
||||
|
||||
# 准备参数
|
||||
save_args = [(data, original_path, output_dir)
|
||||
for data, original_path in zip(sampled_data, sampled_paths)]
|
||||
|
||||
# 使用多线程保存文件
|
||||
success_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
# 提交所有任务
|
||||
futures = [executor.submit(_save_single_file, args)
|
||||
for args in save_args]
|
||||
|
||||
# 收集结果
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="保存数据"):
|
||||
try:
|
||||
result = future.result(timeout=300) # 设置超时避免卡死
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
success = result[0]
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
print(f"错误: 获取保存结果时出错: {e}")
|
||||
|
||||
print(f"保存完成,共保存 {success_count}/{len(sampled_data)} 个文件")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="基于position和duration token的重采样V2")
|
||||
parser.add_argument("--json_path", type=str,
|
||||
default="octuple_token_analysis_report.json",
|
||||
help="token分析报告JSON文件路径")
|
||||
parser.add_argument("--data_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8",
|
||||
help="数据目录路径")
|
||||
parser.add_argument("--output_dir", type=str,
|
||||
default="dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2",
|
||||
help="输出目录路径")
|
||||
parser.add_argument("--contain_sample_ratio", type=float, default=0.1,
|
||||
help="对于包含前3个token的歌曲,采样比例 (默认: 0.1)")
|
||||
parser.add_argument("--not_contain_sample_ratio", type=float, default=0.9,
|
||||
help="对于不包含前3个token的歌曲,采样比例 (默认: 0.9)")
|
||||
parser.add_argument("--top_k", type=int, default=3,
|
||||
help="使用前k个最常见的token (默认: 3)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="随机种子")
|
||||
parser.add_argument("--num_threads", type=int, default=None,
|
||||
help="线程数,None表示使用CPU核心数 (默认: None)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 1. 加载position和duration的前top_k个最常见token
|
||||
top_position_tokens = load_top_tokens_from_json(
|
||||
args.json_path, "position", top_k=args.top_k
|
||||
)
|
||||
top_duration_tokens = load_top_tokens_from_json(
|
||||
args.json_path, "duration", top_k=args.top_k
|
||||
)
|
||||
|
||||
if not top_position_tokens or not top_duration_tokens:
|
||||
print("错误: 无法加载前top_k个token")
|
||||
return
|
||||
|
||||
# 2. 获取所有数据文件路径
|
||||
file_paths = get_data_file_paths(args.data_dir)
|
||||
|
||||
if not file_paths:
|
||||
print("错误: 未找到任何数据文件")
|
||||
return
|
||||
|
||||
print(f"\n共找到 {len(file_paths)} 个数据文件")
|
||||
|
||||
# 3. 根据新逻辑进行重采样
|
||||
sampled_paths, sampled_indices, stats = resample_songs(
|
||||
file_paths,
|
||||
top_position_tokens,
|
||||
top_duration_tokens,
|
||||
contain_sample_ratio=args.contain_sample_ratio,
|
||||
not_contain_sample_ratio=args.not_contain_sample_ratio,
|
||||
num_threads=args.num_threads,
|
||||
seed=args.seed
|
||||
)
|
||||
|
||||
# 4. 保存采样结果(延迟加载)
|
||||
sampled_data = sampled_paths # 使用路径,延迟加载
|
||||
save_sampled_data(sampled_data, sampled_paths, args.output_dir, num_threads=args.num_threads)
|
||||
|
||||
# 5. 保存采样索引(可选,用于后续分析)
|
||||
indices_file = Path(args.output_dir) / "sampled_indices.npy"
|
||||
np.save(indices_file, np.array(sampled_indices))
|
||||
print(f"\n采样索引已保存到: {indices_file}")
|
||||
|
||||
# 6. 保存采样信息
|
||||
info = {
|
||||
"total_samples": len(file_paths),
|
||||
"sampled_samples": len(sampled_paths),
|
||||
"contain_sample_ratio": args.contain_sample_ratio,
|
||||
"not_contain_sample_ratio": args.not_contain_sample_ratio,
|
||||
"top_k": args.top_k,
|
||||
"top_position_tokens": sorted(list(top_position_tokens)),
|
||||
"top_duration_tokens": sorted(list(top_duration_tokens)),
|
||||
"stats": stats
|
||||
}
|
||||
|
||||
info_file = Path(args.output_dir) / "resample_info.json"
|
||||
with open(info_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(info, f, indent=2, ensure_ascii=False)
|
||||
print(f"采样信息已保存到: {info_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,14 +1,282 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
统计octuple分词结果中每一列每个token的出现次数,并生成分析报告
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import defaultdict, Counter
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
|
||||
# 读取 npz 文件
|
||||
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
|
||||
# Octuple的列名定义
|
||||
COLUMN_NAMES = [
|
||||
"pitch", # 0: Pitch/PitchDrum
|
||||
"position", # 1: Position
|
||||
"bar", # 2: Bar
|
||||
"velocity", # 3: Velocity
|
||||
"duration", # 4: Duration
|
||||
"program", # 5: Program
|
||||
"tempo", # 6: Tempo
|
||||
"timesig" # 7: TimeSignature
|
||||
]
|
||||
|
||||
# 查看保存的键
|
||||
print(data.files)
|
||||
# 输出:['filename', 'sequence']
|
||||
|
||||
# 访问数据
|
||||
sequence = data["arr_0"]
|
||||
def load_octuple_data(data_dir):
|
||||
"""
|
||||
加载所有octuple分词后的.npz文件
|
||||
|
||||
Args:
|
||||
data_dir: 数据目录路径,可以是单个目录或包含多个子目录的根目录
|
||||
|
||||
Returns:
|
||||
list: 所有加载的numpy数组列表
|
||||
"""
|
||||
data_dir = Path(data_dir)
|
||||
npz_files = []
|
||||
|
||||
# 如果目录存在,查找所有.npz文件
|
||||
if data_dir.exists():
|
||||
npz_files = list(data_dir.rglob("*.npz"))
|
||||
|
||||
if not npz_files:
|
||||
print(f"警告: 在 {data_dir} 中未找到.npz文件")
|
||||
return []
|
||||
|
||||
print(f"找到 {len(npz_files)} 个.npz文件,开始加载...")
|
||||
|
||||
all_data = []
|
||||
for npz_file in tqdm(npz_files, desc="加载数据"):
|
||||
try:
|
||||
data = np.load(npz_file)['arr_0']
|
||||
# 确保数据是二维数组 (num_tokens, num_columns)
|
||||
if data.ndim == 2:
|
||||
all_data.append(data)
|
||||
elif data.ndim == 1:
|
||||
# 如果是一维,可能需要reshape,但octuple应该是二维的
|
||||
print(f"警告: {npz_file} 是一维数组,跳过")
|
||||
except Exception as e:
|
||||
print(f"错误: 加载 {npz_file} 时出错: {e}")
|
||||
continue
|
||||
|
||||
return all_data
|
||||
|
||||
|
||||
def count_tokens_by_column(all_data):
|
||||
"""
|
||||
统计每一列每个token的出现次数
|
||||
|
||||
Args:
|
||||
all_data: 所有数据的列表,每个元素是一个numpy数组 (num_tokens, num_columns)
|
||||
|
||||
Returns:
|
||||
dict: {column_index: Counter({token_value: count})}
|
||||
"""
|
||||
column_counters = defaultdict(Counter)
|
||||
|
||||
print("统计token出现次数...")
|
||||
for data in tqdm(all_data, desc="处理数据"):
|
||||
if data.size == 0:
|
||||
continue
|
||||
|
||||
num_columns = data.shape[1] if data.ndim == 2 else 1
|
||||
|
||||
for col_idx in range(num_columns):
|
||||
if data.ndim == 2:
|
||||
tokens = data[:, col_idx]
|
||||
else:
|
||||
tokens = data
|
||||
|
||||
# 统计该列中每个token的出现次数
|
||||
unique, counts = np.unique(tokens, return_counts=True)
|
||||
for token, count in zip(unique, counts):
|
||||
column_counters[col_idx][int(token)] += int(count)
|
||||
|
||||
return dict(column_counters)
|
||||
|
||||
|
||||
def generate_report(column_counters, output_file=None):
|
||||
"""
|
||||
生成分析报告
|
||||
|
||||
Args:
|
||||
column_counters: 统计结果字典
|
||||
output_file: 输出文件路径(可选)
|
||||
"""
|
||||
report_lines = []
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("OCTUPLE分词结果统计分析报告")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("")
|
||||
|
||||
# 总体统计
|
||||
total_tokens = sum(sum(counter.values()) for counter in column_counters.values())
|
||||
report_lines.append(f"总token数: {total_tokens:,}")
|
||||
report_lines.append(f"分析的列数: {len(column_counters)}")
|
||||
report_lines.append("")
|
||||
|
||||
# 每一列的详细统计
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
|
||||
report_lines.append("-" * 80)
|
||||
report_lines.append(f"列 {col_idx}: {col_name}")
|
||||
report_lines.append("-" * 80)
|
||||
|
||||
total_in_column = sum(counter.values())
|
||||
unique_tokens = len(counter)
|
||||
min_token = min(counter.keys()) if counter else 0
|
||||
max_token = max(counter.keys()) if counter else 0
|
||||
|
||||
report_lines.append(f" 总token数: {total_in_column:,}")
|
||||
report_lines.append(f" 唯一token数: {unique_tokens:,}")
|
||||
report_lines.append(f" Token值范围: [{min_token}, {max_token}]")
|
||||
report_lines.append(f" 平均出现次数: {total_in_column / unique_tokens:.2f}" if unique_tokens > 0 else " 平均出现次数: N/A")
|
||||
report_lines.append("")
|
||||
|
||||
# Top 20 最常见的token
|
||||
report_lines.append(f" Top 20 最常见的token:")
|
||||
top_tokens = counter.most_common(20)
|
||||
for rank, (token, count) in enumerate(top_tokens, 1):
|
||||
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
|
||||
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
|
||||
report_lines.append("")
|
||||
|
||||
# Top 20 最不常见的token(出现次数>0的)
|
||||
report_lines.append(f" Top 20 最不常见的token (出现次数>0):")
|
||||
bottom_tokens = counter.most_common()[-20:]
|
||||
bottom_tokens.reverse()
|
||||
for rank, (token, count) in enumerate(bottom_tokens, 1):
|
||||
percentage = (count / total_in_column * 100) if total_in_column > 0 else 0
|
||||
report_lines.append(f" {rank:2d}. Token {token:6d}: {count:10,} 次 ({percentage:5.2f}%)")
|
||||
report_lines.append("")
|
||||
|
||||
# 分布统计
|
||||
counts_list = list(counter.values())
|
||||
if counts_list:
|
||||
report_lines.append(f" 分布统计:")
|
||||
report_lines.append(f" 最小出现次数: {min(counts_list):,}")
|
||||
report_lines.append(f" 最大出现次数: {max(counts_list):,}")
|
||||
report_lines.append(f" 中位数出现次数: {np.median(counts_list):,.0f}")
|
||||
report_lines.append(f" 标准差: {np.std(counts_list):,.2f}")
|
||||
report_lines.append("")
|
||||
|
||||
# 跨列分析
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("跨列分析")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("")
|
||||
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
total_in_column = sum(counter.values())
|
||||
percentage = (total_in_column / total_tokens * 100) if total_tokens > 0 else 0
|
||||
report_lines.append(f" {col_name:12s}: {total_in_column:12,} tokens ({percentage:5.2f}%)")
|
||||
|
||||
report_lines.append("")
|
||||
report_lines.append("=" * 80)
|
||||
report_lines.append("报告生成完成")
|
||||
report_lines.append("=" * 80)
|
||||
|
||||
# 输出报告
|
||||
report_text = "\n".join(report_lines)
|
||||
print("\n" + report_text)
|
||||
|
||||
# 保存到文件
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(report_text)
|
||||
print(f"\n报告已保存到: {output_path}")
|
||||
|
||||
# 同时保存JSON格式的详细数据
|
||||
if output_file:
|
||||
json_output = output_path.with_suffix('.json')
|
||||
json_data = {
|
||||
'summary': {
|
||||
'total_tokens': total_tokens,
|
||||
'num_columns': len(column_counters)
|
||||
},
|
||||
'columns': {}
|
||||
}
|
||||
|
||||
for col_idx in sorted(column_counters.keys()):
|
||||
counter = column_counters[col_idx]
|
||||
col_name = COLUMN_NAMES[col_idx] if col_idx < len(COLUMN_NAMES) else f"column_{col_idx}"
|
||||
json_data['columns'][col_name] = {
|
||||
'total_tokens': sum(counter.values()),
|
||||
'unique_tokens': len(counter),
|
||||
'token_counts': dict(counter),
|
||||
'top_20': dict(counter.most_common(20)),
|
||||
'bottom_20': dict(counter.most_common()[-20:])
|
||||
}
|
||||
|
||||
with open(json_output, 'w', encoding='utf-8') as f:
|
||||
json.dump(json_data, f, indent=2, ensure_ascii=False)
|
||||
print(f"详细数据已保存到: {json_output}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 默认数据目录 - 可以根据需要修改
|
||||
default_data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi"
|
||||
|
||||
# 可以指定具体的数据目录,例如:
|
||||
data_dir = "dataset/represented_data/tuneidx/tuneidx_msmidi/oct8_resampled_v2"
|
||||
# 或者使用默认目录扫描所有oct8目录
|
||||
|
||||
# import sys
|
||||
# if len(sys.argv) > 1:
|
||||
# data_dir = sys.argv[1]
|
||||
# else:
|
||||
# # 自动查找所有oct8目录
|
||||
# base_dir = Path(default_data_dir)
|
||||
# oct8_dirs = list(base_dir.rglob("oct8"))
|
||||
# if oct8_dirs:
|
||||
# print(f"找到以下oct8目录:")
|
||||
# for i, d in enumerate(oct8_dirs, 1):
|
||||
# print(f" {i}. {d}")
|
||||
# if len(oct8_dirs) == 1:
|
||||
# data_dir = str(oct8_dirs[0])
|
||||
# print(f"\n使用目录: {data_dir}")
|
||||
# else:
|
||||
# # 使用第一个找到的目录,或者合并所有目录
|
||||
# print(f"\n使用第一个目录: {oct8_dirs[0]}")
|
||||
# print("如需分析其他目录,请指定路径作为参数")
|
||||
# data_dir = str(oct8_dirs[0])
|
||||
# else:
|
||||
# data_dir = default_data_dir
|
||||
# print(f"未找到oct8目录,使用默认目录: {data_dir}")
|
||||
|
||||
# 加载数据
|
||||
all_data = load_octuple_data(data_dir)
|
||||
|
||||
if not all_data:
|
||||
print("错误: 未加载到任何数据")
|
||||
return
|
||||
|
||||
# 检查数据格式
|
||||
if all_data:
|
||||
sample = all_data[0]
|
||||
print(f"\n数据格式检查:")
|
||||
print(f" 样本形状: {sample.shape}")
|
||||
print(f" 样本数据类型: {sample.dtype}")
|
||||
print(f" 列数: {sample.shape[1] if sample.ndim == 2 else 1}")
|
||||
print()
|
||||
|
||||
# 统计token出现次数
|
||||
column_counters = count_tokens_by_column(all_data)
|
||||
|
||||
# 生成报告
|
||||
output_file = "octuple_token_analysis_report_part.txt"
|
||||
generate_report(column_counters, output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
print("token 序列长度:", len(sequence))
|
||||
print("前 20 个 token:", sequence[:20])
|
||||
Reference in New Issue
Block a user