1029 add octuple
@@ -1,5 +1,6 @@
 import pickle
 from pathlib import Path
 from re import L
 from typing import Union
 from multiprocessing import Pool, cpu_count
+from collections import defaultdict
@@ -58,8 +59,8 @@ class LangTokenVocab:
         if in_vocab_file_path is not None:
             with open(in_vocab_file_path, 'r') as f:
                 idx2event_temp = json.load(f)
-            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
-                for key in idx2event_temp.keys():
+            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb' or self.encoding_scheme == 'oct':
+                for key in idx2event_temp.keys():
                     idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
             elif self.encoding_scheme == 'remi':
                 idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
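The `int(idx)` conversion that this hunk extends to the new 'oct' scheme exists because JSON object keys are always strings: a vocabulary written with `json.dump` and read back needs its indices cast back to `int` before token ids can be looked up. A minimal standalone sketch of that round trip (the `pitch_*` tokens are hypothetical, not taken from the repository):

import json

# Hypothetical per-feature vocabulary; json.dumps turns the integer keys into "0", "1".
idx2event = {"pitch": {0: "pitch_60", 1: "pitch_62"}}
reloaded = json.loads(json.dumps(idx2event))
# Cast the keys back to int, mirroring the comprehension in the hunk above.
restored = {feat: {int(i): tok for i, tok in d.items()} for feat, d in reloaded.items()}
assert restored == idx2event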
@@ -71,13 +72,18 @@ class LangTokenVocab:
     # Extracts features depending on the number of features chosen (4, 5, 7, 8).
     def _get_features(self):
-        feature_args = {
-            4: ["type", "beat", "pitch", "duration"],
-            5: ["type", "beat", "instrument", "pitch", "duration"],
-            7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
-            8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
-        self.feature_list = feature_args[self.num_features]
+        if self.encoding_scheme != 'oct':
+            feature_args = {
+                4: ["type", "beat", "pitch", "duration"],
+                5: ["type", "beat", "instrument", "pitch", "duration"],
+                7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
+                8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
+            self.feature_list = feature_args[self.num_features]
+        else:
+            feature_args = {
+                7: ["pitch", "position", "bar", "duration", "program", "tempo", "timesig"],
+                8: ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]}
+            self.feature_list = feature_args[self.num_features]

     # Saves the current vocabulary to a specified JSON path.
     def save_vocab(self, json_path):
         with open(json_path, 'w') as f:
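For orientation on the new 'oct' branch: with `num_features == 8`, `self.feature_list` becomes the octuple field order used for each note. A purely illustrative example of one note event keyed by that order (all values are made up):

# Field order taken from the 8-feature 'oct' entry above; the values are invented for illustration.
feature_list = ["pitch", "position", "bar", "velocity", "duration", "program", "tempo", "timesig"]
note = dict(zip(feature_list, ["pitch_60", "position_4", "bar_1", "velocity_80",
                               "duration_8", "program_0", "tempo_120", "timesig_4/4"]))
print(note["pitch"], note["timesig"])  # pitch_60 timesig_4/4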
@@ -93,13 +99,17 @@ class LangTokenVocab:
             self.sos_token = [self.event2idx['SOS_None']]
             self.eos_token = [[self.event2idx['EOS_None']]]
         else:
-            self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
-            self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
+            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
+                self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
+                self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
+            else: # oct
+                self.sos_token = [[self.event2idx['pitch']['BOS_None']] + [0] * (self.num_features - 1)]
+                self.eos_token = [[self.event2idx['pitch']['EOS_None']] + [0] * (self.num_features - 1)]

     # Generates vocabularies by either loading from a file or creating them based on the event data.
     def _get_vocab(self, event_data, unique_vocabs=None):
         # make new vocab from given event_data
-        if event_data is not None:
+        if event_data is not None and self.encoding_scheme != 'oct':
             unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
             unique_vocabs = sorted(unique_char_list)
             unique_vocabs.remove('SOS_None')
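What the new 'oct' branch produces: a single compound frame whose first slot holds the pitch-feature BOS/EOS id and whose remaining slots are zero-padded. A small self-contained sketch, assuming a hypothetical `event2idx` where those ids are 1 and 2:

# Hypothetical ids; only the shape of the resulting tokens matters here.
event2idx = {"pitch": {"BOS_None": 1, "EOS_None": 2}}
num_features = 8
sos_token = [[event2idx['pitch']['BOS_None']] + [0] * (num_features - 1)]
eos_token = [[event2idx['pitch']['EOS_None']] + [0] * (num_features - 1)]
print(sos_token)  # [[1, 0, 0, 0, 0, 0, 0, 0]]
print(eos_token)  # [[2, 0, 0, 0, 0, 0, 0, 0]]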
@@ -119,6 +129,7 @@ class LangTokenVocab:
         # load premade vocab
         else:
             idx2event = unique_vocabs
+            print(idx2event)
             event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
         return idx2event, event2idx

@@ -392,4 +403,47 @@ class MusicTokenVocabNB(MusicTokenVocabCP):
             unique_vocabs.insert(3, 'SSS')
             unique_vocabs.insert(4, 'SSN')
             unique_vocabs.insert(5, 'SNN')
         return unique_vocabs
+
+
+class MusicTokenVocabOct(LangTokenVocab):
+    def __init__(
+        self,
+        in_vocab_file_path:Union[Path, None],
+        event_data: list,
+        encoding_scheme: str,
+        num_features: int
+    ):
+        super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
+
+    def _get_vocab(self, event_data, unique_vocabs=None):
+        if event_data is not None:
+            # Create vocab mappings (event2idx, idx2event) from the provided event data
+            print('start to get unique vocab')
+            event2idx = {}
+            idx2event = {}
+            unique_vocabs = defaultdict(set)
+            # Use multiprocessing to extract unique vocabularies for each event
+            with Pool(16) as p:
+                results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
+            # Combine results from different processes
+            for result in results:
+                for key in self.feature_list:
+                    unique_vocabs[key].update(result[key])
+            # Process each feature type
+            for key in self.feature_list:
+                unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
+                # Create event2idx and idx2event mappings for each feature
+                event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
+                idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
+            return idx2event, event2idx
+        else:
+            # If no event data, simply map unique vocab to indexes
+            event2idx = {}
+            for key in self.feature_list:
+                event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
+            return unique_vocabs, event2idx
+
+    def get_vocab_size(self):
+        # Return the size of the vocabulary for each feature
+        return {key: len(self.idx2event[key]) for key in self.feature_list}
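The sort key in `MusicTokenVocabOct._get_vocab` orders integer tokens first (numerically) and then string tokens by their numeric suffix; note it assumes every string token ends in an integer after its last underscore. A standalone check with made-up tokens:

# Same lambda as in the diff above, applied to a toy vocabulary.
vocab = {"pitch_72", 0, "pitch_60", 3, 1}
ordered = sorted(vocab, key=lambda x: (not isinstance(x, int),
                                       int(x.split('_')[-1] if isinstance(x, str) else x)))
print(ordered)  # [0, 1, 3, 'pitch_60', 'pitch_72']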