1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MIDI 预处理脚本(并行版)
功能:
1. 使用 miditok 的 Octuple 分词器。
2. 限制 MIDI 文件时长在 8~2000 秒。
3. 缺失 tempo 时默认 120 BPM缺失 time signature 时默认 4/4。
4. 保存 vocab.json。
5. 使用多线程遍历目录下所有 MIDI 文件分词,每个文件单独保存为 {filename}.npz。
"""
import os
import glob
import struct
import numpy as np
from multiprocessing import RLock
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
from symusic import Score
from miditok import Octuple, TokenizerConfig
lock = RLock()
def convert_event_dicts(dict_list):
"""
将 event 词表列表按顺序转换为结构化输出
输入: list[dict]
每个 dict 对应一个类别,按固定顺序排列:
0: Pitch/PitchDrum
1: Position
2: Bar
(+ Optional) Velocity
(+ Optional) Duration
(+ Optional) Program
(+ Optional) Tempo
(+ Optional) TimeSignature
输出示例:
{
"pitch": {"0": 0, "1": "Pitch_60", ...},
"position": {...},
...
}
"""
keys_order = [
"pitch", "position", "bar",
"velocity", "duration", "program",
"tempo", "timesig"
]
result = {}
for i, d in enumerate(dict_list):
if i >= len(keys_order):
break # 超出定义范围的忽略
category = keys_order[i]
result[category] = {str(v): k for k, v in d.items()}
return result
def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
try:
score = Score(midi_path, ttype="tick")
if(not (8 <= (duration := score.to('second').end()) <= 2000)):
with lock:
print(f" × 时长不符合要求:{midi_path} -> {duration}s")
return
# 分词
tok_seq = tokenizer(score)
token_ids = tok_seq.ids
# add sos token at the beginning
vocab = tokenizer.vocab
sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
token_ids.insert(0, sos_token)
token_ids.sort(key=lambda x: (x[2], x[1])) # pos in 1, bar in 2
# 保存单个 npz 文件
filename = os.path.splitext(os.path.basename(midi_path))[0]
save_path = os.path.join(output_dir, f"{filename}.npz")
np.savez_compressed(save_path, np.array(token_ids))
except Exception as e:
with lock:
print(f" × 处理文件时出错:{midi_path} -> {e}")
def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
# === 1. 初始化分词器并保存词表 ===
print("初始化分词器 Octuple...")
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
vocab = tokenizer.vocab
vocab_structured = convert_event_dicts(vocab)
with open( "vocab/oct_vocab.json", "w", encoding="utf-8") as f:
import json
json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
# === 2. 创建输出目录 ===
os.makedirs(output_dir, exist_ok=True)
# === 3. 收集 MIDI 文件 ===
midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
midi_paths = list(midi_paths)
print(f"共发现 {len(midi_paths)} 个 MIDI 文件,使用 {num_threads} 个线程处理。\n")
# === 4. 并行处理 ===
results = []
with ProcessPoolExecutor(max_workers=num_threads) as executor:
futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
for future in tqdm(as_completed(futures), total=len(futures)):
res = future.result()
if res:
results.append(res)
# === 5. 汇总结果 ===
print(f"\n处理完成:成功生成 {len(results)} 个 .npz 文件,保存在 {output_dir}/ 中。")
if __name__ == "__main__":
midi_directory = "dataset/Melody" # 修改为你的 MIDI 文件目录
dataset_name = midi_directory.split("/")[-1]
tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
output_dir = tuneidx_prefix
preprocess_midi_directory(midi_directory, output_dir)

View File

@ -0,0 +1,14 @@
import numpy as np
# 读取 npz 文件
data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
# 查看保存的键
print(data.files)
# 输出:['filename', 'sequence']
# 访问数据
sequence = data["arr_0"]
print("token 序列长度:", len(sequence))
print("前 20 个 token", sequence[:20])

View File

@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from re import L
from typing import Union
from multiprocessing import Pool, cpu_count
from collections import defaultdict
@ -58,8 +59,8 @@ class LangTokenVocab:
if in_vocab_file_path is not None:
with open(in_vocab_file_path, 'r') as f:
idx2event_temp = json.load(f)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
for key in idx2event_temp.keys():
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
for key in idx2event_temp.keys():
idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
elif self.encoding_scheme == 'remi':
idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
@ -71,13 +72,18 @@ class LangTokenVocab:
# Extracts features depending on the number of features chosen (4, 5, 7, 8).
def _get_features(self):
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
if self.encoding_scheme != 'oct':
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
else:
feature_args = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
self.feature_list = feature_args[self.num_features]
# Saves the current vocabulary to a specified JSON path.
def save_vocab(self, json_path):
with open(json_path, 'w') as f:
@ -93,13 +99,17 @@ class LangTokenVocab:
self.sos_token = [self.event2idx['SOS_None']]
self.eos_token = [[self.event2idx['EOS_None']]]
else:
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
else: # oct
self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]
# Generates vocabularies by either loading from a file or creating them based on the event data.
def _get_vocab(self, event_data, unique_vocabs=None):
# make new vocab from given event_data
if event_data is not None:
if event_data is not None and self.encoding_scheme != 'oct':
unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
unique_vocabs = sorted(unique_char_list)
unique_vocabs.remove('SOS_None')
@ -119,6 +129,7 @@ class LangTokenVocab:
# load premade vocab
else:
idx2event = unique_vocabs
print(idx2event)
event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
return idx2event, event2idx
@ -392,4 +403,47 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
unique_vocabs.insert(3, 'SSS')
unique_vocabs.insert(4, 'SSN')
unique_vocabs.insert(5, 'SNN')
return unique_vocabs
return unique_vocabs
class MusicTokenVocabOct(LangTokenVocab):
def __init__(
self,
in_vocab_file_path:Union[Path, None],
event_data: list,
encoding_scheme: str,
num_features: int
):
super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
def _get_vocab(self, event_data, unique_vocabs=None):
if event_data is not None:
# Create vocab mappings (event2idx, idx2event) from the provided event data
print('start to get unique vocab')
event2idx = {}
idx2event = {}
unique_vocabs = defaultdict(set)
# Use multiprocessing to extract unique vocabularies for each event
with Pool(16) as p:
results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
# Combine results from different processes
for result in results:
for key in self.feature_list:
unique_vocabs[key].update(result[key])
# Process each feature type
for key in self.feature_list:
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
# Create event2idx and idx2event mappings for each feature
event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
return idx2event, event2idx
else:
# If no event data, simply map unique vocab to indexes
event2idx = {}
for key in self.feature_list:
event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
return unique_vocabs, event2idx
def get_vocab_size(self):
# Return the size of the vocabulary for each feature
return {key: len(self.idx2event[key]) for key in self.feature_list}