1127 update to latest
This commit is contained in:
@ -20,10 +20,22 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from tqdm import tqdm
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
from itertools import groupby, chain
|
||||
from random import shuffle, seed
|
||||
|
||||
|
||||
# Module-level re-entrant lock.
# NOTE(review): nothing in the visible code acquires this lock — confirm it is
# still needed, and whether a per-process copy is what the worker pool expects.
lock = RLock()
|
||||
|
||||
def shuffled(seq):
    """Return a new list holding the items of ``seq`` in random order.

    Mirrors the ``sorted`` / ``list.sort`` split: the input is left untouched
    and any iterable is accepted. Randomness comes from the module-level
    ``random`` generator, so ``random.seed`` makes the result reproducible.

    Args:
        seq: Iterable of items to shuffle.

    Returns:
        A freshly allocated list with the same items in shuffled order.
    """
    # Copy first so the caller's sequence is never reordered as a side effect.
    items = list(seq)
    shuffle(items)
    return items
|
||||
|
||||
def permute_inside_and_across_tracks(seq, program_index=5):
    """Randomly reorder tokens within each track and the tracks themselves.

    Tokens are grouped by the value found at ``program_index``; each group is
    treated as one track. The order of the tracks and the order of the tokens
    inside every track are both shuffled, while tokens that belong to the
    same track stay contiguous in the result.

    Args:
        seq: Iterable of indexable tokens (e.g. lists of ids), each of which
            has a sortable value at ``program_index``.
        program_index: Position of the program field inside a token.
            Defaults to 5, the position this script has always used.

    Returns:
        A new flat list containing the same token objects; ``seq`` itself is
        not modified.
    """
    def program_of(token):
        return token[program_index]

    # groupby only merges adjacent items, so sort by the same key first.
    ordered = sorted(seq, key=program_of)
    tracks = [list(group) for _, group in groupby(ordered, key=program_of)]

    # Shuffle the track order first, then the tokens of each track in turn.
    shuffle(tracks)
    for track in tracks:
        shuffle(track)
    return list(chain.from_iterable(tracks))
|
||||
|
||||
|
||||
def convert_event_dicts(dict_list):
|
||||
"""
|
||||
将 event 词表列表按顺序转换为结构化输出
|
||||
@ -59,7 +71,7 @@ def convert_event_dicts(dict_list):
|
||||
|
||||
return result
|
||||
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str, whether_shuffle: bool):
|
||||
try:
|
||||
score = Score(midi_path, ttype="tick")
|
||||
|
||||
@ -71,6 +83,8 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
# 分词
|
||||
tok_seq = tokenizer(score)
|
||||
token_ids = tok_seq.ids
|
||||
if whether_shuffle:
|
||||
token_ids = permute_inside_and_across_tracks(token_ids)
|
||||
# add sos token at the beginning
|
||||
vocab = tokenizer.vocab
|
||||
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
|
||||
@ -86,7 +100,7 @@ def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
|
||||
print(f" × 处理文件时出错:{midi_path} -> {e}")
|
||||
|
||||
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
|
||||
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2), whether_shuffle: bool = False):
|
||||
# === 1. 初始化分词器并保存词表 ===
|
||||
print("初始化分词器 Octuple...")
|
||||
config = TokenizerConfig(
|
||||
@ -108,15 +122,20 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# === 3. 收集 MIDI 文件 ===
|
||||
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
midi_paths = list(midi_paths)
|
||||
# midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
|
||||
# glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
|
||||
# midi_paths = list(midi_paths)
|
||||
midi_paths = []
|
||||
for root, _, files in os.walk(midi_dir):
|
||||
for file in files:
|
||||
if file.lower().endswith(('.mid', '.midi')):
|
||||
midi_paths.append(os.path.join(root, file))
|
||||
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
|
||||
|
||||
# === 4. 并行处理 ===
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
|
||||
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir, whether_shuffle): path for path in midi_paths}
|
||||
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
res = future.result()
|
||||
@ -128,8 +147,18 @@ def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", nu
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Default MIDI directory; change it here or pass --midi_dir on the
    # command line.
    midi_directory = "dataset/Melody"

    parser = argparse.ArgumentParser(description="MIDI 预处理脚本(并行版)")
    parser.add_argument("--midi_dir", type=str, default=midi_directory, help="MIDI 文件目录")
    # The flag enables token permutation inside and across tracks; it does
    # not reorder the input files, so the help text says so.
    parser.add_argument("--shuffle", action="store_true", help="是否随机打乱轨道顺序及每个轨道内部的 token 顺序")
    args = parser.parse_args()

    # Honor the command-line value; everything below derives from it.
    midi_directory = args.midi_dir

    # normpath strips a trailing separator so the dataset name is never
    # empty, and basename handles the platform's path separator.
    dataset_name = os.path.basename(os.path.normpath(midi_directory))
    tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"

    output_dir = tuneidx_prefix
    preprocess_midi_directory(midi_directory, output_dir, whether_shuffle=args.shuffle)
|
||||
Reference in New Issue
Block a user