#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""MIDI preprocessing script (parallel version).

Features:
    1. Tokenizes with miditok's Octuple tokenizer.
    2. Keeps only MIDI files whose duration lies within 8~2000 seconds.
    3. Per the original author's notes, a missing tempo defaults to 120 BPM
       and a missing time signature defaults to 4/4 (this script does not
       set those defaults itself).
    4. Saves the vocabulary as JSON.
    5. Walks a directory, tokenizes every MIDI file in a pool of worker
       processes, and stores each file as its own ``{filename}.npz``.
"""

import argparse
import glob
import json
import os
import struct
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import chain, groupby
from multiprocessing import RLock
from random import seed, shuffle
from typing import Optional

import numpy as np
from miditok import Octuple, TokenizerConfig
from symusic import Score
from tqdm import tqdm

# Guards console output so messages from different workers do not interleave.
# NOTE(review): with the "spawn" start method every worker re-imports this
# module and therefore gets its own lock object -- confirm whether cross-process
# serialization of prints is actually required.
lock = RLock()

# Column indices inside one Octuple token (see convert_event_dicts).
_POSITION_IDX = 1
_BAR_IDX = 2
_PROGRAM_IDX = 5

# Accepted duration window, in seconds (both bounds inclusive).
_MIN_DURATION_S = 8
_MAX_DURATION_S = 2000

# Where the structured vocabulary is written (relative to the working dir).
_VOCAB_PATH = "vocab/oct_vocab.json"


def shuffled(seq):
    """Return a new list holding the items of *seq* in random order.

    The input is left untouched; the permutation is drawn with
    ``random.shuffle`` on a copy.
    """
    items = list(seq)
    shuffle(items)
    return items


def permute_inside_and_across_tracks(seq):
    """Randomly reorder tokens within each track and reorder the tracks.

    Tokens are grouped by their program column (index 5). The order of the
    groups is shuffled, and the tokens inside every group are shuffled too.

    Args:
        seq: iterable of tokens, each indexable at position 5.

    Returns:
        A flat list with the same tokens in permuted order.
    """
    def program_of(token):
        return token[_PROGRAM_IDX]

    # groupby only merges adjacent items, hence the sort on the same key.
    by_program = sorted(seq, key=program_of)
    tracks = [list(group) for _, group in groupby(by_program, key=program_of)]
    return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))


def convert_event_dicts(dict_list):
    """Convert the per-category vocabularies into one structured mapping.

    Args:
        dict_list: list of dicts, one per token category, in this fixed
            order: Pitch/PitchDrum, Position, Bar, Velocity, Duration,
            Program, Tempo, TimeSignature. Each dict maps a token string to
            its id.

    Returns:
        A dict keyed by category name ("pitch", "position", "bar",
        "velocity", "duration", "program", "tempo", "timesig"). Each value
        maps the stringified id back to the token string, e.g.
        ``{"pitch": {"0": "PAD_None", ...}, "position": {...}, ...}``.
        Entries of *dict_list* beyond the eight known categories are ignored.
    """
    keys_order = [
        "pitch", "position", "bar", "velocity",
        "duration", "program", "tempo", "timesig",
    ]
    # zip stops at the shorter input, which drops any surplus categories.
    return {
        category: {str(idx): token for token, idx in vocab.items()}
        for category, vocab in zip(keys_order, dict_list)
    }


def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str,
                        whether_shuffle: bool) -> Optional[str]:
    """Tokenize one MIDI file and store the ids as a compressed ``.npz``.

    Args:
        midi_path: path of the MIDI file to read.
        tokenizer: Octuple tokenizer used to encode the score.
        output_dir: directory that receives ``{basename}.npz``.
        whether_shuffle: when true, tokens are permuted inside and across
            tracks before the final ordering by (bar, position).

    Returns:
        The path of the written ``.npz`` file, or ``None`` when the file was
        skipped (duration outside the accepted window) or an error occurred.
        Errors are reported on stdout and never propagate to the caller.
    """
    # NOTE(review): the output name uses only the basename, so two inputs with
    # the same stem in different sub-directories overwrite each other.
    try:
        score = Score(midi_path, ttype="tick")
        duration = score.to('second').end()
        if not (_MIN_DURATION_S <= duration <= _MAX_DURATION_S):
            with lock:
                print(f" × 时长不符合要求:{midi_path} -> {duration}s")
            return None

        # Tokenize.
        tok_seq = tokenizer(score)
        token_ids = tok_seq.ids
        if whether_shuffle:
            token_ids = permute_inside_and_across_tracks(token_ids)

        # Prepend a start-of-sequence token: BOS id in the first column,
        # zeros in all remaining columns.
        vocab = tokenizer.vocab
        sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
        token_ids.insert(0, sos_token)

        # Stable sort by (bar, position): any shuffling above only survives
        # among tokens sharing the same bar and position.
        # NOTE(review): the original file's indentation was lost; this sort is
        # assumed to run whether or not shuffling is enabled.
        token_ids.sort(key=lambda tok: (tok[_BAR_IDX], tok[_POSITION_IDX]))

        # Save one npz file per input.
        filename = os.path.splitext(os.path.basename(midi_path))[0]
        save_path = os.path.join(output_dir, f"{filename}.npz")
        np.savez_compressed(save_path, np.array(token_ids))
        return save_path
    except Exception as e:
        # Best-effort batch job: report the failure and carry on.
        with lock:
            print(f" × 处理文件时出错:{midi_path} -> {e}")
        return None


def preprocess_midi_directory(midi_dir: str,
                              output_dir: str = "dataset_npz",
                              num_threads: int = max(1, (os.cpu_count() or 2) // 2),
                              whether_shuffle: bool = False):
    """Tokenize every MIDI file under *midi_dir* using worker processes.

    Args:
        midi_dir: directory searched recursively for ``.mid`` / ``.midi``
            files (extension match is case-insensitive).
        output_dir: directory receiving one ``.npz`` per input file; created
            if missing.
        num_threads: number of worker processes (the name is historical);
            values below 1 are raised to 1.
        whether_shuffle: forwarded to ``process_single_midi``.

    Side effects:
        Writes the structured vocabulary to ``vocab/oct_vocab.json`` and
        prints progress information to stdout.
    """
    num_threads = max(1, int(num_threads))

    # === 1. Build the tokenizer and save its vocabulary ===
    print("初始化分词器 Octuple...")
    config = TokenizerConfig(
        use_time_signatures=True,
        use_tempos=True,
        use_velocities=True,
        use_programs=True,
        remove_duplicated_notes=True,
        delete_equal_successive_tempo_changes=True,
    )
    config.additional_params["max_bar_embedding"] = 512
    tokenizer = Octuple(config)

    vocab_structured = convert_event_dicts(tokenizer.vocab)
    # The vocab directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(_VOCAB_PATH), exist_ok=True)
    with open(_VOCAB_PATH, "w", encoding="utf-8") as f:
        json.dump(vocab_structured, f, ensure_ascii=False, indent=4)

    # === 2. Create the output directory ===
    os.makedirs(output_dir, exist_ok=True)

    # === 3. Collect MIDI files ===
    midi_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(midi_dir)
        for file in files
        if file.lower().endswith(('.mid', '.midi'))
    ]
    print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")

    # === 4. Process in parallel ===
    results = []
    with ProcessPoolExecutor(max_workers=num_threads) as executor:
        futures = {
            executor.submit(process_single_midi, path, tokenizer,
                            output_dir, whether_shuffle): path
            for path in midi_paths
        }
        for future in tqdm(as_completed(futures), total=len(futures)):
            res = future.result()
            if res:  # a saved path means success; None means skipped/failed
                results.append(res)

    # === 5. Summary ===
    print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")


def main() -> None:
    """Parse command-line options and run the preprocessing."""
    midi_directory = "dataset/Melody"  # default MIDI directory; change as needed
    parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
    parser.add_argument("--midi_dir", type=str, default=midi_directory,
                        help="MIDI 文件目录")
    # NOTE(review): the help text talks about shuffling file order, but the
    # flag actually permutes notes inside/across tracks; text left unchanged.
    parser.add_argument("--shuffle", action="store_true",
                        help="是否在处理前打乱文件顺序")
    args = parser.parse_args()

    # Honour --midi_dir; normpath makes a trailing slash harmless.
    midi_dir = args.midi_dir
    dataset_name = os.path.basename(os.path.normpath(midi_dir))
    tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
    output_dir = tuneidx_prefix

    preprocess_midi_directory(midi_dir, output_dir, whether_shuffle=args.shuffle)


if __name__ == "__main__":
    main()