1029 add octuple
This commit is contained in:
135
data_representation/octuple2tuneinidx.py
Normal file
135
data_representation/octuple2tuneinidx.py
Normal file
@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MIDI 预处理脚本(并行版)
|
||||
功能:
|
||||
1. 使用 miditok 的 Octuple 分词器。
|
||||
2. 限制 MIDI 文件时长在 8~2000 秒。
|
||||
3. 缺失 tempo 时默认 120 BPM;缺失 time signature 时默认 4/4。
|
||||
4. 保存 vocab.json。
|
||||
5. 使用多线程遍历目录下所有 MIDI 文件分词,每个文件单独保存为 {filename}.npz。
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import struct
|
||||
import numpy as np
|
||||
from multiprocessing import RLock
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
from tqdm import tqdm
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
|
||||
|
||||
lock = RLock()
|
||||
|
||||
def convert_event_dicts(dict_list):
|
||||
"""
|
||||
将 event 词表列表按顺序转换为结构化输出
|
||||
输入: list[dict]
|
||||
每个 dict 对应一个类别,按固定顺序排列:
|
||||
0: Pitch/PitchDrum
|
||||
1: Position
|
||||
2: Bar
|
||||
(+ Optional) Velocity
|
||||
(+ Optional) Duration
|
||||
(+ Optional) Program
|
||||
(+ Optional) Tempo
|
||||
(+ Optional) TimeSignature
|
||||
输出示例:
|
||||
{
|
||||
"pitch": {"0": 0, "1": "Pitch_60", ...},
|
||||
"position": {...},
|
||||
...
|
||||
}
|
||||
"""
|
||||
keys_order = [
|
||||
"pitch", "position", "bar",
|
||||
"velocity", "duration", "program",
|
||||
"tempo", "timesig"
|
||||
]
|
||||
|
||||
result = {}
|
||||
for i, d in enumerate(dict_list):
|
||||
if i >= len(keys_order):
|
||||
break # 超出定义范围的忽略
|
||||
category = keys_order[i]
|
||||
result[category] = {str(v): k for k, v in d.items()}
|
||||
|
||||
return result
|
||||
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
try:
|
||||
score = Score(midi_path, ttype="tick")
|
||||
|
||||
if(not (8 <= (duration := score.to('second').end()) <= 2000)):
|
||||
with lock:
|
||||
print(f" × 时长不符合要求:{midi_path} -> {duration}s")
|
||||
return
|
||||
|
||||
# 分词
|
||||
tok_seq = tokenizer(score)
|
||||
token_ids = tok_seq.ids
|
||||
# add sos token at the beginning
|
||||
vocab = tokenizer.vocab
|
||||
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
|
||||
token_ids.insert(0, sos_token)
|
||||
token_ids.sort(key=lambda x: (x[2], x[1])) # pos in 1, bar in 2
|
||||
|
||||
# 保存单个 npz 文件
|
||||
filename = os.path.splitext(os.path.basename(midi_path))[0]
|
||||
save_path = os.path.join(output_dir, f"{filename}.npz")
|
||||
np.savez_compressed(save_path, np.array(token_ids))
|
||||
except Exception as e:
|
||||
with lock:
|
||||
print(f" × 处理文件时出错:{midi_path} -> {e}")
|
||||
|
||||
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
|
||||
# === 1. 初始化分词器并保存词表 ===
|
||||
print("初始化分词器 Octuple...")
|
||||
config = TokenizerConfig(
|
||||
use_time_signatures=True,
|
||||
use_tempos=True,
|
||||
use_velocities=True,
|
||||
use_programs=True,
|
||||
remove_duplicated_notes=True,
|
||||
delete_equal_successive_tempo_changes=True,
|
||||
)
|
||||
config.additional_params["max_bar_embedding"] = 512
|
||||
tokenizer = Octuple(config)
|
||||
vocab = tokenizer.vocab
|
||||
vocab_structured = convert_event_dicts(vocab)
|
||||
with open( "vocab/oct_vocab.json", "w", encoding="utf-8") as f:
|
||||
import json
|
||||
json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
|
||||
# === 2. 创建输出目录 ===
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# === 3. 收集 MIDI 文件 ===
|
||||
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
midi_paths = list(midi_paths)
|
||||
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
|
||||
|
||||
# === 4. 并行处理 ===
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
res = future.result()
|
||||
if res:
|
||||
results.append(res)
|
||||
|
||||
# === 5. 汇总结果 ===
|
||||
print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
|
||||
dataset_name = midi_directory.split("/")[-1]
|
||||
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
|
||||
output_dir = tuneidx_prefix
|
||||
preprocess_midi_directory(midi_directory, output_dir)
|
||||
Reference in New Issue
Block a user