Files
MIDIFoundationModel/data_representation/octuple2tuneinidx.py
2025-10-29 17:14:33 +08:00

135 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
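MIDI preprocessing script (parallel version).

What it does:
1. Tokenizes with miditok's Octuple tokenizer.
2. Keeps only MIDI files whose duration is between 8 and 2000 seconds.
3. Defaults to 120 BPM when the tempo is missing, and to 4/4 when the time signature is missing.
4. Saves the vocabulary as JSON (vocab/oct_vocab.json).
5. Tokenizes every MIDI file under a directory in parallel using worker processes; each file is saved separately as {filename}.npz.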
"""
import os
import glob
import struct
import numpy as np
from multiprocessing import RLock
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
# Re-entrant lock held around the diagnostic print() calls in process_single_midi.
# NOTE(review): it is created at import time, so whether the worker processes
# actually share it depends on the multiprocessing start method (inherited under
# "fork", re-created per worker under "spawn") — confirm this is acceptable.
lock = RLock()
def convert_event_dicts(dict_list):
    """Re-key a list of per-category vocabularies by category name.

    ``dict_list`` is expected to hold one dict per token category, in this
    fixed order:

        0: pitch (Pitch/PitchDrum)
        1: position
        2: bar
        then, when enabled: velocity, duration, program, tempo, timesig

    Each dict is inverted: its values are converted to ``str`` and become the
    keys, and its keys become the values (presumably token -> id turning into
    id-string -> token; verify against the tokenizer's vocabulary layout).

    Args:
        dict_list: iterable of dicts, one per category, in the order above.

    Returns:
        dict mapping each category name to its inverted vocabulary. Entries
        beyond the eight known categories are ignored; if fewer are supplied,
        only the leading category names appear in the result.
    """
    category_names = (
        "pitch", "position", "bar",
        "velocity", "duration", "program",
        "tempo", "timesig",
    )

    structured = {}
    # zip() stops at the shorter input, so surplus dicts past the last known
    # category name are silently dropped.
    for name, mapping in zip(category_names, dict_list):
        inverted = {}
        for token, index in mapping.items():
            inverted[str(index)] = token
        structured[name] = inverted
    return structured
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
    """Tokenize one MIDI file and write its token matrix to ``{filename}.npz``.

    The file is loaded with symusic, rejected unless its end time (in seconds)
    lies in [8, 2000], tokenized, prefixed with a BOS row, sorted by
    (bar, position) and saved with ``np.savez_compressed``.

    Args:
        midi_path: path of the MIDI file to process.
        tokenizer: the Octuple tokenizer used to encode the score.
        output_dir: existing directory that receives the ``.npz`` file.

    Returns:
        The path of the written ``.npz`` file on success, or ``None`` when the
        file is skipped (duration out of range) or any exception occurs.
        Returning the path lets the caller count successful conversions.
        Errors are printed, never raised.
    """
    try:
        score = Score(midi_path, ttype="tick")
        duration = score.to('second').end()
        if not (8 <= duration <= 2000):
            with lock:
                print(f" × 时长不符合要求:{midi_path} -> {duration}s")
            return None

        # Tokenize.
        tok_seq = tokenizer(score)
        token_ids = tok_seq.ids

        # Prepend a start-of-sequence row: the BOS id from the first vocabulary
        # followed by zeros for every remaining vocabulary.
        vocab = tokenizer.vocab
        sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
        token_ids.insert(0, sos_token)

        # Order rows by bar (column 2), then position (column 1). list.sort is
        # stable, so rows with equal keys keep their relative order.
        token_ids.sort(key=lambda x: (x[2], x[1]))

        # Save one npz file per MIDI file; the unnamed array is stored under
        # numpy's default key "arr_0".
        # NOTE(review): only the basename is used, so two inputs with the same
        # filename in different sub-directories write to the same output path.
        filename = os.path.splitext(os.path.basename(midi_path))[0]
        save_path = os.path.join(output_dir, f"{filename}.npz")
        np.savez_compressed(save_path, np.array(token_ids))
        return save_path
    except Exception as e:
        # Best-effort batch processing: report the failure and carry on.
        with lock:
            print(f" × 处理文件时出错:{midi_path} -> {e}")
        return None
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = max(1, (os.cpu_count() or 1) // 2)):
    """Tokenize every MIDI file under ``midi_dir`` in parallel.

    Steps:
      1. Build an Octuple tokenizer and dump its vocabulary, restructured by
         ``convert_event_dicts``, to ``vocab/oct_vocab.json``.
      2. Create ``output_dir`` if needed.
      3. Recursively collect ``*.mid`` and ``*.midi`` files.
      4. Run ``process_single_midi`` on each file in a process pool.
      5. Print how many workers returned a truthy result.

    Args:
        midi_dir: root directory searched recursively for MIDI files.
        output_dir: directory receiving one ``.npz`` per input file.
        num_threads: number of worker *processes* (the name is historical).
            Defaults to half the CPU count, but never less than 1, so the
            default stays valid on single-core hosts or when the CPU count
            cannot be determined.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    import json  # local import: only needed for the vocabulary dump

    # === 1. Initialise the tokenizer and save the vocabulary ===
    print("初始化分词器 Octuple...")
    config = TokenizerConfig(
        use_time_signatures=True,
        use_tempos=True,
        use_velocities=True,
        use_programs=True,
        remove_duplicated_notes=True,
        delete_equal_successive_tempo_changes=True,
    )
    config.additional_params["max_bar_embedding"] = 512
    tokenizer = Octuple(config)
    vocab_structured = convert_event_dicts(tokenizer.vocab)

    # The vocabulary directory may not exist on a fresh checkout; without this
    # the open() below fails with FileNotFoundError.
    os.makedirs("vocab", exist_ok=True)
    with open("vocab/oct_vocab.json", "w", encoding="utf-8") as f:
        json.dump(vocab_structured, f, ensure_ascii=False, indent=4)

    # === 2. Create the output directory ===
    os.makedirs(output_dir, exist_ok=True)

    # === 3. Collect MIDI files ===
    midi_paths = (
        glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True)
        + glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
    )
    print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")

    # === 4. Process in parallel ===
    results = []
    with ProcessPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(process_single_midi, path, tokenizer, output_dir)
            for path in midi_paths
        ]
        for future in tqdm(as_completed(futures), total=len(futures)):
            res = future.result()
            # Only truthy worker results are counted as successes.
            if res:
                results.append(res)

    # === 5. Summary ===
    print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")
if __name__ == "__main__":
    # Directory holding the input MIDI files; change this to your own dataset.
    midi_directory = "dataset/Melody"

    # The last path component names the dataset inside the output layout.
    dataset_name = midi_directory.rsplit("/", 1)[-1]
    tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
    output_dir = tuneidx_prefix

    preprocess_midi_directory(midi_directory, output_dir=output_dir)