1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
from itertools import groupby, chain
from random import shuffle, seed
lock = RLock()
def shuffled(seq):
shuffle(seq)
return seq
def permute_inside_and_across_tracks(seq):
seq_sorted = sorted(seq, key=lambda x: x[5])
tracks = [list(g) for _, g in groupby(seq_sorted, key=lambda x: x[5])] # 5 is program
return list(chain.from_iterable(shuffled(t) for t in shuffled(tracks)))
def convert_event_dicts(dict_list):
"""
将 event 词表列表按顺序转换为结构化输出
@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
return result
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
try:
score = Score(midi_path, ttype="tick")
@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
# 分词
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
if whether_shuffle:
token_ids = permute_inside_and_across_tracks(token_ids)
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
print(f" × 处理文件时出错:{midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
# === 1. 初始化分词器并保存词表 ===
print("初始化分词器 Octuple...")
config = TokenizerConfig(
@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
os.makedirs(output_dir, exist_ok=True)
# === 3. 收集 MIDI 文件 ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
# midi_paths = list(midi_paths)
midi_paths = []
for root, _, files in os.walk(midi_dir):
for file in files:
if file.lower().endswith(('.mid', '.midi')):
midi_paths.append(os.path.join(root, file))
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
# === 4. 并行处理 ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
if __name__ == "__main__":
import argparse
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录")
parser.add_argument("--shuffle", action="store_true", help="是否在处理前打乱文件顺序")
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)
args = parser.parse_args()
preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)