#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
|
||
MIDI 预处理脚本(并行版)
|
||
功能:
|
||
1. 使用 miditok 的 Octuple 分词器。
|
||
2. 限制 MIDI 文件时长在 8~2000 秒。
|
||
3. 缺失 tempo 时默认 120 BPM;缺失 time signature 时默认 4/4。
|
||
4. 保存 vocab.json。
|
||
5. 使用多进程(ProcessPoolExecutor)遍历目录下所有 MIDI 文件分词,每个文件单独保存为 {filename}.npz。
|
||
"""
|
||
|
||
import os
|
||
import glob
|
||
import struct
|
||
import numpy as np
|
||
from multiprocessing import RLock
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
|
||
from tqdm import tqdm
|
||
from symusic import Score
|
||
from miditok import Octuple, TokenizerConfig
|
||
|
||
|
||
# Serializes console output from process_single_midi (both its skip and its
# error messages are printed under this lock).
# NOTE(review): this lock is created at import time. Under the "spawn" start
# method (default on Windows/macOS) each worker process re-imports this module
# and builds its own lock, so output may not actually be serialized across
# processes — confirm the start method, or hand the lock to the workers via
# ProcessPoolExecutor(initializer=..., initargs=(lock,)).
lock = RLock()
|
||
|
||
def convert_event_dicts(dict_list):
    """Label each per-field vocabulary and invert it to ``{str(id): token}``.

    ``dict_list`` is the tokenizer's vocab as passed by the caller: a list of
    dicts, one per Octuple field, presumably mapping token string -> integer id.
    Fields are labelled purely by their position in the list:

        0: pitch      (Pitch / PitchDrum tokens)
        1: position
        2: bar
        3: velocity   (optional)
        4: duration   (optional)
        5: program    (optional)
        6: tempo      (optional)
        7: timesig    (optional)

    Any dicts beyond the eighth are ignored.

    NOTE(review): labelling is positional only. If an optional field is
    disabled in the tokenizer config, every later field shifts up and gets
    the wrong label.

    Example:
        >>> convert_event_dicts([{"Pitch_60": 4}])
        {'pitch': {'4': 'Pitch_60'}}
    """
    labels = (
        "pitch", "position", "bar",
        "velocity", "duration", "program",
        "tempo", "timesig",
    )
    # zip stops at the shorter input, so vocab entries past the last label are dropped.
    return {
        label: {str(token_id): token for token, token_id in field_vocab.items()}
        for label, field_vocab in zip(labels, dict_list)
    }
|
||
|
||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
    """Tokenize one MIDI file and save its token ids as ``{stem}.npz``.

    Files whose length in seconds lies outside [8, 2000] are skipped. A BOS
    row is prepended to the token matrix, and rows are then ordered by
    (bar id, position id). The array is stored with ``np.savez_compressed``
    under its default key (``arr_0``).

    Args:
        midi_path: path of the MIDI file to read.
        tokenizer: configured Octuple tokenizer; its ``vocab`` must contain a
            ``"BOS_None"`` entry in the first field.
        output_dir: directory to write the ``.npz`` into. It is expected to
            exist already; it is not created here.

    Returns:
        The path of the written ``.npz`` file on success, or ``None`` when the
        file was skipped (duration) or failed (any exception, which is
        reported and swallowed so one bad file does not stop a batch run).
        Returning the path lets the caller count successful outputs.
    """
    try:
        score = Score(midi_path, ttype="tick")

        # The length filter is applied in seconds, not in ticks.
        duration = score.to("second").end()
        if not (8 <= duration <= 2000):
            with lock:
                print(f" × 时长不符合要求:{midi_path} -> {duration}s")
            return None

        # Each Octuple token is a list holding one id per vocabulary field;
        # field 1 is Position and field 2 is Bar.
        token_ids = list(tokenizer(score).ids)

        # BOS row: the BOS id in the first field, 0 in every other field.
        vocab = tokenizer.vocab
        sos_token = [vocab[0]["BOS_None"]] + [0] * (len(vocab) - 1)
        token_ids.insert(0, sos_token)

        # Stable sort by (bar id, position id). The BOS row's (0, 0) key is the
        # smallest possible for non-negative ids, and it was inserted first,
        # so it stays at index 0.
        token_ids.sort(key=lambda tok: (tok[2], tok[1]))

        # Outputs are named by file stem only (the directory part is dropped).
        filename = os.path.splitext(os.path.basename(midi_path))[0]
        save_path = os.path.join(output_dir, f"{filename}.npz")
        np.savez_compressed(save_path, np.array(token_ids))
    except Exception as e:
        with lock:
            print(f" × 处理文件时出错:{midi_path} -> {e}")
        return None

    return save_path
|
||
|
||
|
||
def preprocess_midi_directory(
    midi_dir: str,
    output_dir: str = "dataset_npz",
    num_threads: int = max(1, (os.cpu_count() or 2) // 2),
    vocab_path: str = "vocab/oct_vocab.json",
):
    """Tokenize every ``.mid``/``.midi`` file under *midi_dir* in parallel.

    Steps: build an Octuple tokenizer, dump its structured vocabulary to
    *vocab_path* as JSON, then tokenize each MIDI file in a process pool,
    writing one ``{stem}.npz`` per file into *output_dir*.

    Args:
        midi_dir: root directory, searched recursively.
        output_dir: destination for the ``.npz`` files (created if missing).
        num_threads: number of worker *processes* (the name is historical).
            Defaults to half the CPU count, and never less than 1. Note that
            ``os.cpu_count()`` may return ``None``.
        vocab_path: where the vocabulary JSON is written. Its parent directory
            is created if missing.
    """
    import json
    from collections import Counter

    # === 1. Initialize the tokenizer and save the vocabulary ===
    print("初始化分词器 Octuple...")
    config = TokenizerConfig(
        use_time_signatures=True,
        use_tempos=True,
        use_velocities=True,
        use_programs=True,
        remove_duplicated_notes=True,
        delete_equal_successive_tempo_changes=True,
    )
    config.additional_params["max_bar_embedding"] = 512
    tokenizer = Octuple(config)

    vocab_structured = convert_event_dicts(tokenizer.vocab)
    # open() fails when the parent directory does not exist, so create it first.
    vocab_dir = os.path.dirname(vocab_path)
    if vocab_dir:
        os.makedirs(vocab_dir, exist_ok=True)
    with open(vocab_path, "w", encoding="utf-8") as f:
        json.dump(vocab_structured, f, ensure_ascii=False, indent=4)

    # === 2. Create the output directory ===
    os.makedirs(output_dir, exist_ok=True)

    # === 3. Collect MIDI files (recursively) ===
    midi_paths = []
    for pattern in ("*.mid", "*.midi"):
        midi_paths += glob.glob(os.path.join(midi_dir, "**", pattern), recursive=True)
    print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个进程处理。\n")

    # Outputs are named by file stem only. Same-stem files in different
    # sub-directories (or a.mid next to a.midi) therefore overwrite each
    # other's .npz silently, so warn about them up front.
    stem_counts = Counter(os.path.splitext(os.path.basename(p))[0] for p in midi_paths)
    duplicated = sorted(stem for stem, count in stem_counts.items() if count > 1)
    if duplicated:
        print(f"警告:{len(duplicated)} 个文件名重复,对应的 .npz 会互相覆盖,例如:{duplicated[:5]}\n")

    # === 4. Parallel processing ===
    results = []
    with ProcessPoolExecutor(max_workers=num_threads) as executor:
        futures = {
            executor.submit(process_single_midi, path, tokenizer, output_dir): path
            for path in midi_paths
        }
        for future in tqdm(as_completed(futures), total=len(futures)):
            res = future.result()
            # Only a truthy result from the worker is counted as a written file.
            if res:
                results.append(res)

    # === 5. Summary ===
    print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")
|
||
|
||
|
||
if __name__ == "__main__":
    midi_directory = "dataset/Melody"  # change this to your MIDI directory
    # normpath strips a trailing separator and normalizes separators, so the
    # basename is never empty (e.g. for "dataset/Melody/") and also works on Windows.
    dataset_name = os.path.basename(os.path.normpath(midi_directory))
    output_dir = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
    preprocess_midi_directory(midi_directory, output_dir)