1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from re import L
from typing import Union
from multiprocessing import Pool, cpu_count
from collections import defaultdict
@ -58,8 +59,8 @@ class LangTokenVocab:
if in_vocab_file_path is not None:
with open(in_vocab_file_path, 'r') as f:
idx2event_temp = json.load(f)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
for key in idx2event_temp.keys():
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
for key in idx2event_temp.keys():
idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
elif self.encoding_scheme == 'remi':
idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
@ -71,13 +72,18 @@ class LangTokenVocab:
# Extracts features depending on the number of features chosen (4, 5, 7, 8).
def _get_features(self):
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
if self.encoding_scheme != 'oct':
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
else:
feature_args = {
7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
self.feature_list = feature_args[self.num_features]
# Saves the current vocabulary to a specified JSON path.
def save_vocab(self, json_path):
with open(json_path, 'w') as f:
@ -93,13 +99,17 @@ class LangTokenVocab:
self.sos_token = [self.event2idx['SOS_None']]
self.eos_token = [[self.event2idx['EOS_None']]]
else:
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
else: # oct
self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]
# Generates vocabularies by either loading from a file or creating them based on the event data.
def _get_vocab(self, event_data, unique_vocabs=None):
# make new vocab from given event_data
if event_data is not None:
if event_data is not None and self.encoding_scheme != 'oct':
unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
unique_vocabs = sorted(unique_char_list)
unique_vocabs.remove('SOS_None')
@ -119,6 +129,7 @@ class LangTokenVocab:
# load premade vocab
else:
idx2event = unique_vocabs
print(idx2event)
event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
return idx2event, event2idx
@ -392,4 +403,47 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
unique_vocabs.insert(3, 'SSS')
unique_vocabs.insert(4, 'SSN')
unique_vocabs.insert(5, 'SNN')
return unique_vocabs
return unique_vocabs
class MusicTokenVocabOct(LangTokenVocab):
    """Vocabulary for the octuple ('oct') encoding scheme.

    Unlike the flat REMI vocabulary, each feature (pitch, position, bar, ...)
    gets its own independent token <-> index mapping.
    """

    def __init__(
        self,
        in_vocab_file_path: Union[Path, None],
        event_data: list,
        encoding_scheme: str,
        num_features: int
    ):
        super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)

    def _get_vocab(self, event_data, unique_vocabs=None):
        """Build (idx2event, event2idx) per feature.

        When ``event_data`` is given, unique tokens are gathered from the
        tunes in parallel; otherwise a premade ``unique_vocabs`` mapping
        (feature -> {idx: token}) is inverted.
        """
        if event_data is None:
            # Premade vocab: only the token -> index direction needs building.
            tok2idx = {
                feat: {tok: int(idx) for idx, tok in unique_vocabs[feat].items()}
                for feat in self.feature_list
            }
            return unique_vocabs, tok2idx

        print('start to get unique vocab')
        collected = defaultdict(set)
        # Gather per-tune token sets in parallel, then merge per feature.
        with Pool(16) as pool:
            jobs = [(tune, self.feature_list) for tune in event_data]
            per_tune = pool.starmap(self._mp_get_unique_vocab, tqdm(jobs))
        for partial in per_tune:
            for feat in self.feature_list:
                collected[feat].update(partial[feat])

        tok2idx = {}
        idx2tok = {}
        for feat in self.feature_list:
            # Ints first, then strings ordered by the integer after the last
            # '_' (NOTE(review): assumes every string token ends in a numeric
            # suffix — tokens like 'BOS_None' would raise here; confirm).
            ordered = sorted(
                collected[feat],
                key=lambda t: (not isinstance(t, int),
                               int(t.split('_')[-1] if isinstance(t, str) else t)))
            collected[feat] = ordered
            tok2idx[feat] = {tok: int(pos) for pos, tok in enumerate(ordered)}
            idx2tok[feat] = {int(pos): tok for pos, tok in enumerate(ordered)}
        return idx2tok, tok2idx

    def get_vocab_size(self):
        """Return {feature: vocabulary size} for every feature."""
        return {feat: len(self.idx2event[feat]) for feat in self.feature_list}