1029 add octuple
135 data_representation/octuple2tuneinidx.py Normal file
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+MIDI preprocessing script (parallel version).
+
+Features:
+1. Uses miditok's Octuple tokenizer.
+2. Keeps only MIDI files whose duration is between 8 and 2000 seconds.
+3. Defaults to 120 BPM when the tempo is missing and to 4/4 when the time signature is missing.
+4. Saves the vocabulary as a JSON file.
+5. Walks all MIDI files under a directory with a process pool, tokenizes each one,
+   and saves it as {filename}.npz.
+"""
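+# Note: each saved {filename}.npz holds a single unnamed array (NumPy's default
+# key "arr_0") of shape (n_tokens, n_fields), one Octuple field set per row,
+# where n_fields matches len(tokenizer.vocab); see np.savez_compressed below.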
+
+import os
+import glob
+import json
+import numpy as np
+from multiprocessing import RLock
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+from tqdm import tqdm
+from symusic import Score
+from miditok import Octuple, TokenizerConfig
+
+
+# Inherited by worker processes when the start method is fork; serializes
+# console output so messages from different workers do not interleave.
+lock = RLock()
+
+def convert_event_dicts(dict_list):
+    """
+    Convert the ordered list of event vocabularies into a structured output.
+    Input: list[dict]
+    Each dict corresponds to one category, in a fixed order:
+    0: Pitch/PitchDrum
+    1: Position
+    2: Bar
+    (+ optional) Velocity
+    (+ optional) Duration
+    (+ optional) Program
+    (+ optional) Tempo
+    (+ optional) TimeSignature
+    Example output:
+    {
+        "pitch": {"0": "PAD_None", "1": "Pitch_60", ...},
+        "position": {...},
+        ...
+    }
+    """
+    keys_order = [
+        "pitch", "position", "bar",
+        "velocity", "duration", "program",
+        "tempo", "timesig"
+    ]
+
+    result = {}
+    for i, d in enumerate(dict_list):
+        if i >= len(keys_order):
+            break  # ignore categories beyond the defined order
+        category = keys_order[i]
+        result[category] = {str(v): k for k, v in d.items()}
+
+    return result
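+# Illustrative example (hypothetical vocab entries): passing
+# [{"PAD_None": 0, "Pitch_60": 1}] would return
+# {"pitch": {"0": "PAD_None", "1": "Pitch_60"}}.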
+
+
+def process_single_midi(midi_path: str, tokenizer: Octuple, output_dir: str):
+    try:
+        score = Score(midi_path, ttype="tick")
+
+        if not (8 <= (duration := score.to('second').end()) <= 2000):
+            with lock:
+                print(f" × duration out of range: {midi_path} -> {duration}s")
+            return None
+
+        # Tokenize
+        tok_seq = tokenizer(score)
+        token_ids = tok_seq.ids
+        # Add the SOS token at the beginning
+        vocab = tokenizer.vocab
+        sos_token = [vocab[0]['BOS_None']] + [0] * (len(vocab) - 1)
+        token_ids.insert(0, sos_token)
+        token_ids.sort(key=lambda x: (x[2], x[1]))  # sort by bar (index 2), then position (index 1)
+
+        # Save one npz file per MIDI
+        filename = os.path.splitext(os.path.basename(midi_path))[0]
+        save_path = os.path.join(output_dir, f"{filename}.npz")
+        np.savez_compressed(save_path, np.array(token_ids))
+        return save_path  # report success so the caller can tally results
+    except Exception as e:
+        with lock:
+            print(f" × error while processing {midi_path} -> {e}")
+        return None
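+# Shape note (follows from the code above): the saved array is (n_tokens, F)
+# with F == len(tokenizer.vocab); row 0 is the SOS row [bos_id, 0, ..., 0], and
+# Python's stable sort keeps it first while ordering note rows by (bar, position).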
+
+
+def preprocess_midi_directory(midi_dir: str, output_dir: str = "dataset_npz", num_threads: int = int(os.cpu_count() // 2)):
+    # === 1. Initialize the tokenizer and save its vocabulary ===
+    print("Initializing the Octuple tokenizer...")
+    config = TokenizerConfig(
+        use_time_signatures=True,
+        use_tempos=True,
+        use_velocities=True,
+        use_programs=True,
+        remove_duplicated_notes=True,
+        delete_equal_successive_tempo_changes=True,
+    )
+    config.additional_params["max_bar_embedding"] = 512
+    tokenizer = Octuple(config)
+    vocab = tokenizer.vocab
+    vocab_structured = convert_event_dicts(vocab)
+    os.makedirs("vocab", exist_ok=True)  # make sure the target directory exists
+    with open("vocab/oct_vocab.json", "w", encoding="utf-8") as f:
+        json.dump(vocab_structured, f, ensure_ascii=False, indent=4)
+
+    # === 2. Create the output directory ===
+    os.makedirs(output_dir, exist_ok=True)
+
+    # === 3. Collect MIDI files ===
+    midi_paths = glob.glob(os.path.join(midi_dir, "**", "*.mid"), recursive=True) + \
+                 glob.glob(os.path.join(midi_dir, "**", "*.midi"), recursive=True)
+    print(f"Found {len(midi_paths)} MIDI files; processing with {num_threads} worker processes.\n")
+
+    # === 4. Process in parallel ===
+    results = []
+    with ProcessPoolExecutor(max_workers=num_threads) as executor:
+        futures = {executor.submit(process_single_midi, path, tokenizer, output_dir): path for path in midi_paths}
+
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            res = future.result()  # only successful saves return a path
+            if res:
+                results.append(res)
+
+    # === 5. Summarize ===
+    print(f"\nDone: wrote {len(results)} .npz files to {output_dir}/.")
+
+
+if __name__ == "__main__":
+    midi_directory = "dataset/Melody"  # change this to your MIDI directory
+    dataset_name = midi_directory.split("/")[-1]
+    tuneidx_prefix = f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/oct8"
+    output_dir = tuneidx_prefix
+    preprocess_midi_directory(midi_directory, output_dir)
14 data_representation/test.py Normal file
@@ -0,0 +1,14 @@
+import numpy as np
+
+# Load the npz file
+data = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz", allow_pickle=True)
+
+# Inspect the stored keys
+print(data.files)
+# Output: ['arr_0'] (savez_compressed was given a positional array,
+# so NumPy stores it under its default key)
+
+# Access the data
+sequence = data["arr_0"]
+
+print("token sequence length:", len(sequence))
+print("first 20 tokens:", sequence[:20])
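A minimal decode sketch (not part of the commit; the field order is assumed to
match keys_order in octuple2tuneinidx.py) for mapping one saved row back to
token strings via the oct_vocab.json written above:

    import json
    import numpy as np

    with open("vocab/oct_vocab.json", encoding="utf-8") as f:
        idx2tok = json.load(f)

    keys_order = ["pitch", "position", "bar", "velocity",
                  "duration", "program", "tempo", "timesig"]
    row = np.load("dataset/represented_data/tuneidx/tuneidx_Melody/octuple8/AIDemo-recuKqEwVxsfij.npz",
                  allow_pickle=True)["arr_0"][1]  # first note after the SOS row
    print([idx2tok[key][str(idx)] for key, idx in zip(keys_order, row)])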
@@ -1,5 +1,6 @@
 import pickle
 from pathlib import Path
 from re import L
 from typing import Union
 from multiprocessing import Pool, cpu_count
+from collections import defaultdict
@@ -58,8 +59,8 @@ class LangTokenVocab:
         if in_vocab_file_path is not None:
             with open(in_vocab_file_path, 'r') as f:
                 idx2event_temp = json.load(f)
-            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
-                for key in idx2event_temp.keys():
+            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
+                for key in idx2event_temp.keys():
                     idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
             elif self.encoding_scheme == 'remi':
                 idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
@@ -71,13 +72,18 @@ class LangTokenVocab:
 
     # Extracts features depending on the number of features chosen (4, 5, 7, 8).
     def _get_features(self):
-        feature_args = {
-            4: ["type", "beat", "pitch", "duration"],
-            5: ["type", "beat", "instrument", "pitch", "duration"],
-            7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
-            8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
-        self.feature_list = feature_args[self.num_features]
-
+        if self.encoding_scheme != 'oct':
+            feature_args = {
+                4: ["type", "beat", "pitch", "duration"],
+                5: ["type", "beat", "instrument", "pitch", "duration"],
+                7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
+                8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
+            self.feature_list = feature_args[self.num_features]
+        else:
+            feature_args = {
+                7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
+                8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
+            self.feature_list = feature_args[self.num_features]
 
     # Saves the current vocabulary to a specified JSON path.
     def save_vocab(self, json_path):
         with open(json_path, 'w') as f:
@@ -93,13 +99,17 @@ class LangTokenVocab:
             self.sos_token = [self.event2idx['SOS_None']]
             self.eos_token = [[self.event2idx['EOS_None']]]
         else:
-            self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
-            self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
+            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
+                self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
+                self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
+            else: # oct
+                self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
+                self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]
 
     # Generates vocabularies by either loading from a file or creating them based on the event data.
     def _get_vocab(self, event_data, unique_vocabs=None):
         # make new vocab from given event_data
-        if event_data is not None:
+        if event_data is not None and self.encoding_scheme != 'oct':
             unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
             unique_vocabs = sorted(unique_char_list)
             unique_vocabs.remove('SOS_None')
@@ -119,6 +129,7 @@ class LangTokenVocab:
         # load premade vocab
         else:
             idx2event = unique_vocabs
+            print(idx2event)
             event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
         return idx2event, event2idx
 
@@ -392,4 +403,47 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
         unique_vocabs.insert(3, 'SSS')
         unique_vocabs.insert(4, 'SSN')
         unique_vocabs.insert(5, 'SNN')
-        return unique_vocabs
+        return unique_vocabs
+
+
+class MusicTokenVocabOct(LangTokenVocab):
+    def __init__(
+        self,
+        in_vocab_file_path: Union[Path, None],
+        event_data: list,
+        encoding_scheme: str,
+        num_features: int
+    ):
+        super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
+
+    def _get_vocab(self, event_data, unique_vocabs=None):
+        if event_data is not None:
+            # Create vocab mappings (event2idx, idx2event) from the provided event data
+            print('start to get unique vocab')
+            event2idx = {}
+            idx2event = {}
+            unique_vocabs = defaultdict(set)
+            # Use multiprocessing to extract unique vocabularies for each tune
+            with Pool(16) as p:
+                results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
+            # Combine results from the worker processes
+            for result in results:
+                for key in self.feature_list:
+                    unique_vocabs[key].update(result[key])
+            # Process each feature type
+            for key in self.feature_list:
+                unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
+                # Create event2idx and idx2event mappings for each feature
+                event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
+                idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
+            return idx2event, event2idx
+        else:
+            # Without event data, simply map the premade vocab to indices
+            event2idx = {}
+            for key in self.feature_list:
+                event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
+            return unique_vocabs, event2idx
+
+    def get_vocab_size(self):
+        # Return the size of the vocabulary for each feature
+        return {key: len(self.idx2event[key]) for key in self.feature_list}
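A minimal usage sketch (assumptions: num_features=8, an 'oct' vocab JSON such as
the oct_vocab.json written by octuple2tuneinidx.py, and LangTokenVocab's loader
accepting that file; none of this is shown in the commit):

    from pathlib import Path

    vocab = MusicTokenVocabOct(
        in_vocab_file_path=Path("vocab/oct_vocab.json"),  # premade vocab, so no event data
        event_data=None,
        encoding_scheme="oct",
        num_features=8,
    )
    print(vocab.get_vocab_size())  # per-feature sizes, e.g. {'pitch': ..., 'position': ..., ...}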