1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MIDI 预处理脚本(并行版)
功能:
1. 使用 miditok 的 Octuple 分词器。
2. 限制 MIDI 文件时长在 8~2000 秒。
3. 缺失 tempo 时默认 120 BPM缺失 time signature 时默认 4/4。
4. 保存 vocab.json。
5. 使用多线程遍历目录下所有 MIDI 文件分词,每个文件单独保存为 {filename}.npz。
"""
import os
import glob
import struct
import numpy as np
from multiprocessing import RLock
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
lock = RLock()
def convert_event_dicts(dict_list):
"""
将 event 词表列表按顺序转换为结构化输出
输入: list[dict]
每个 dict 对应一个类别,按固定顺序排列:
0: Pitch/PitchDrum
1: Position
2: Bar
(+ Optional) Velocity
(+ Optional) Duration
(+ Optional) Program
(+ Optional) Tempo
(+ Optional) TimeSignature
输出示例:
{
"pitch": {"0": 0, "1": "Pitch_60", ...},
"position": {...},
...
}
"""
keys_order = [
"pitch", "position", "bar",
"velocity", "duration", "program",
"tempo", "timesig"
]
result = {}
for i, d in enumerate(dict_list):
if i >= len(keys_order):
break # 超出定义范围的忽略
category = keys_order[i]
result[category] = {str(v): k for k, v in d.items()}
return result
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
try:
score = Score(midi_path, ttype="tick")
if(not (8 <= (duration := score.to('second').end()) <= 2000)):
with lock:
print(f" × 时长不符合要求:{midi_path} -> {duration}s")
return
# 分词
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
token_ids.insert(0, sos_token)
token_ids.sort(key=lambda x: (x[2], x[1])) # pos in 1, bar in 2
# 保存单个 npz 文件
filename = os.path.splitext(os.path.basename(midi_path))[0]
save_path = os.path.join(output_dir, f"{filename}.npz")
np.savez_compressed(save_path, np.array(token_ids))
except Exception as e:
with lock:
print(f" × 处理文件时出错:{midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
# === 1. 初始化分词器并保存词表 ===
print("初始化分词器 Octuple...")
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
vocab = tokenizer.vocab
vocab_structured = convert_event_dicts(vocab)
with open( "vocab/oct_vocab.json", "w", encoding="utf-8") as f:
import json
json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
# === 2. 创建输出目录 ===
os.makedirs(output_dir, exist_ok=True)
# === 3. 收集 MIDI 文件 ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
# === 4. 并行处理 ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
if res:
results.append(res)
# === 5. 汇总结果 ===
print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")
if __name__ == "__main__":
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)