Files
MIDIFoundationModel/Amadeus/symbolic_encoding/data_utils.py
2025-10-29 17:14:33 +08:00

1896 lines
87 KiB
Python

import re
import random
from pathlib import Path
from collections import OrderedDict
from typing import Union, List, Tuple, Dict
import numpy as np
import matplotlib.pyplot as plt
# lock of thread
from threading import Lock
import json
from tqdm import tqdm
from torch.utils.data import Dataset,IterableDataset
from transformers import T5Tokenizer
from .augmentor import Augmentor
from .compile_utils import VanillaTransformer_compiler
from data_representation import vocab_utils
def get_emb_total_size(config, vocab):
    '''
    Turn the per-feature embedding ratios in ``config.nn_params.emb`` into sizes.

    For every feature in ``vocab.feature_list`` the stored value is multiplied by
    ``emb.emb_size`` and truncated with ``int``; the result overwrites the stored
    value. The sum of all sizes is written to ``emb.total_size``.

    Because the entries are overwritten in place, calling this twice on the same
    config rescales values that were already converted.

    Returns:
        The same ``config`` object that was passed in.
    '''
    emb_cfg = config.nn_params.emb
    running_total = 0
    for feature_name in vocab.feature_list:
        scaled = int(emb_cfg[feature_name] * emb_cfg.emb_size)
        emb_cfg[feature_name] = scaled
        running_total += scaled
    emb_cfg.total_size = running_total
    config.nn_params.emb = emb_cfg
    return config
class TuneCompiler(Dataset):
    '''
    Map-style dataset that serves segments produced by ``VanillaTransformer_compiler``.

    All segments are compiled once in ``__init__`` via ``make_segments`` and kept
    in ``self.segments``; ``__getitem__`` additionally tokenizes the tune name
    with a T5 tokenizer and returns it as the caption encoding.
    '''

    # One lock shared by every instance. ``with Lock():`` (as previously written)
    # creates a brand-new lock on each call, so nothing was ever excluded.
    # The lock lives on the class rather than on the instance so that instances
    # remain picklable (a threading lock stored in ``__dict__`` is not).
    _segment_lock = Lock()

    def __init__(
        self,
        data: List[Tuple[np.ndarray, str]],
        data_type: str,
        augmentor: Augmentor,
        vocab: vocab_utils.LangTokenVocab,
        input_length: int,
        first_pred_feature: str,
        caption_path: Union[str, None] = None,
        for_evaluation: bool = False
    ):
        '''
        The data is distributed on-the-fly by the TuneCompiler
        Pitch, Chord augmentation is applied to the training data every iteration
        Segmentation is applied every epoch for the training data

        Args:
            data: list of ``(tune array, tune name)`` pairs handed to the compiler.
            data_type: ``'valid'`` / ``'test'`` build the validation-style index;
                any other value is treated as training data.
            augmentor: forwarded to the compiler.
            vocab: supplies ``eos_token`` and ``encoding_scheme``.
            input_length: forwarded to the compiler.
            first_pred_feature: forwarded to the compiler.
            caption_path: accepted for interface compatibility; not used here.
            for_evaluation: accepted for interface compatibility; not used here.
        '''
        super().__init__()
        self.data_list = data
        self.data_type = data_type
        self.augmentor = augmentor
        self.eos_token = vocab.eos_token
        self.compile_function = VanillaTransformer_compiler(
            data_list=self.data_list,
            augmentor=self.augmentor,
            eos_token=self.eos_token,
            input_length=input_length,
            first_pred_feature=first_pred_feature,
            encoding_scheme=vocab.encoding_scheme
        )
        self.segment2tune_name = None
        self.tune_name2segment = None
        # T5 tokenizer used to encode the caption text in __getitem__
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large", legacy=False)
        if self.data_type == 'valid' or self.data_type == 'test':
            self._update_segments_for_validset()
        else:
            self._update_segments_for_trainset()

    def _update_segments_for_trainset(self, random_seed=0):
        '''Compile the training segments once; later calls are no-ops (apart from reseeding ``random``).'''
        random.seed(random_seed)
        if self.segment2tune_name is not None:
            # If segments are already compiled, we can skip the compilation
            print("Segments are already compiled, skipping compilation")
            return
        print("Compiling segments for training data")
        with self._segment_lock:
            self.segments, _, self.segment2tune_name = self.compile_function.make_segments(self.data_type)
        print(f"number of trainset segments: {len(self.segments)}")

    def _update_segments_for_validset(self, random_seed=0):
        '''Compile segments plus the tune-name -> segment-index map used for validation/test lookups.'''
        random.seed(random_seed)
        with self._segment_lock:
            self.segments, self.tune_name2segment, self.segment2tune_name = self.compile_function.make_segments(self.data_type)
        print(f"number of testset segments: {len(self.segments)}")

    def _encode_caption(self, text):
        '''Tokenize ``text`` with the settings used for every caption in this dataset.'''
        return self.t5_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=128)

    def __getitem__(self, idx):
        '''Return ``(segment, tensor_mask, tune_name, encoded_caption)`` for segment ``idx``.'''
        segment, tensor_mask = self.segments[idx]
        tune_name = self.segment2tune_name[idx]
        try:
            encoded_caption = self._encode_caption(tune_name)
        except Exception as e:
            # Best effort: a caption that cannot be tokenized must not abort the batch.
            print(f"Error encoding caption for tune {tune_name}: {e}")
            encoded_caption = self._encode_caption("No caption available")
        return segment, tensor_mask, tune_name, encoded_caption

    def get_segments_with_tune_idx(self, tune_name, seg_order):
        '''
        This function is used to retrieve the segment with the tune name and segment order during the validation

        ``self.tune_name2segment`` is only populated for 'valid' / 'test' data.
        '''
        segment_idx = self.tune_name2segment[tune_name][seg_order]
        segment, mask = self.segments[segment_idx][0], self.segments[segment_idx][1]
        return segment, mask

    def __len__(self):
        '''Number of compiled segments.'''
        return len(self.segments)
class _NoLengthError(TypeError, NotImplementedError):
    '''
    Raised by ``IterTuneCompiler.__len__``.

    It is a ``TypeError`` so that the length-hint protocol (used by ``list(ds)``)
    treats the dataset as "length unknown" instead of propagating the error, and
    still a ``NotImplementedError`` so existing handlers for that type keep working.
    '''


class IterTuneCompiler(IterableDataset):
    '''
    Iterable dataset that streams segments from ``VanillaTransformer_compiler.make_segments_iters``.

    NOTE(review): nothing in this class shards the stream between DataLoader
    workers — confirm how it is used with ``num_workers > 1``.
    '''

    def __init__(
        self,
        data: List[Tuple[np.ndarray, str]],
        data_type: str,
        augmentor: Augmentor,
        vocab: vocab_utils.LangTokenVocab,
        input_length: int,
        first_pred_feature: str,
        caption_path: Union[str, None] = None,
        for_evaluation: bool = False
    ):
        '''
        The data is distributed on-the-fly by the IterTuneCompiler.
        Pitch, Chord augmentation is applied to the training data every iteration.
        Segmentation is applied every epoch for the training data.

        ``caption_path`` and ``for_evaluation`` are accepted for interface
        compatibility; they are not used here.
        '''
        super().__init__()
        self.data_list = data
        self.data_type = data_type
        self.augmentor = augmentor
        self.eos_token = vocab.eos_token
        self.vocab = vocab
        self.compile_function = VanillaTransformer_compiler(
            data_list=self.data_list,
            augmentor=self.augmentor,
            eos_token=self.eos_token,
            input_length=input_length,
            first_pred_feature=first_pred_feature,
            encoding_scheme=vocab.encoding_scheme
        )
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base", legacy=False)
        self.random_seed = 0

    def _encode_caption(self, text):
        '''Tokenize ``text`` with the settings used for every caption in this dataset.'''
        return self.t5_tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=128)

    def __iter__(self):
        '''
        Yield ``(segment, mask, caption_input_ids, encoded_caption)`` per segment.

        The generator yields ``([segment, mask], tune_name2segment, segment2tune_name)``;
        the last entry of ``segment2tune_name`` is taken as the name of the current tune.
        '''
        generator = self.compile_function.make_segments_iters(self.data_type)
        for ([segment, mask], _tune_name2segment, segment2tune_name) in generator:
            tune_name = segment2tune_name[-1]
            try:
                encoded_caption = self._encode_caption(tune_name)
            except Exception:
                # Best effort: fall back to a placeholder caption rather than stop the stream.
                encoded_caption = self._encode_caption("No caption available")
            if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
                segment = self.augmentor(segment)
            # The third element carries the caption token ids instead of the raw tune name.
            caption_ids = encoded_caption['input_ids'][0]
            yield segment, mask, caption_ids, encoded_caption

    def __len__(self):
        '''
        Always raises: the number of segments is not known in advance.

        Raises:
            _NoLengthError: a subclass of both ``TypeError`` and ``NotImplementedError``.
        '''
        raise _NoLengthError("IterTuneCompiler is an iterable dataset and does not support __len__.")
class SymbolicMusicDataset(Dataset):
    '''
    Base class for the symbolic-music datasets in this module.

    It loads the pre-tokenized tunes of the concrete dataset (the directory is
    derived from the subclass name), records a histogram of tune lengths, and
    builds the train / valid / test datasets from a persisted split.
    '''

    def __init__(
        self,
        vocab: vocab_utils.LangTokenVocab,
        encoding_scheme: str,
        num_features: int,
        debug: bool,
        aug_type: Union[str, None],
        input_length: int,
        first_pred_feature: str,
        caption_path: Union[str, None] = None,
        for_evaluation: bool = False
    ):
        '''
        Args:
            vocab: The vocabulary containing token representations for the dataset
            encoding_scheme: The encoding scheme used for representing symbolic music (e.g., REMI, NB, etc.)
            num_features: The number of features used for the dataset
            debug: Debug mode; limits dataset size for faster testing if enabled
            aug_type: Type of data augmentation to apply, if 'random' the compiler will apply pitch and chord augmentation
            input_length: Length of the input sequence for each sample
            first_pred_feature: Feature to predict first which is used for compound shift for NB, if not shift, 'type' is used
            caption_path: stored on the instance; not read anywhere in this class
            for_evaluation: when true, no tunes are loaded and no histogram is plotted
        '''
        super().__init__()
        self.encoding_scheme = encoding_scheme
        self.num_features = num_features
        self.debug = debug
        self.input_length = input_length
        self.first_pred_feature = first_pred_feature
        self.caption_path = caption_path
        self.for_evaluation = for_evaluation
        self.vocab = vocab
        self.augmentor = Augmentor(vocab=self.vocab, aug_type=aug_type, input_length=input_length)
        if self.for_evaluation:
            # Evaluation mode skips loading; all three containers are empty lists.
            self.tune_in_idx, self.len_tunes, self.file_name_list = [], [], []
        else:
            self.tune_in_idx, self.len_tunes, self.file_name_list = self._load_tune_in_idx()
        dataset_name = self.__class__.__name__
        len_dir_path = Path(f"len_tunes/{dataset_name}")
        len_dir_path.mkdir(parents=True, exist_ok=True)
        if self.for_evaluation is False:
            self._plot_hist(self.len_tunes, len_dir_path / f"len_{encoding_scheme}{num_features}.png")

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` file of this dataset / encoding.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        # If debug mode is enabled, limit the number of loaded files
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            stem = tune_in_idx_file.stem
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        return tune_in_idx_dict, len_tunes, file_name_list

    def _plot_hist(self, len_tunes, path_outfile):
        '''Save a histogram of tune lengths and store ``mean_len_tunes`` / ``total_len_tunes`` on the instance.'''
        Path(path_outfile).parent.mkdir(parents=True, exist_ok=True)
        data = np.array(list(len_tunes.values()))
        self.mean_len_tunes = np.mean(data)
        data_mean = np.mean(data)
        data_std = np.std(data)
        # compute the total length of all tunes
        self.total_len_tunes = np.sum(data)
        plt.figure(dpi=100)
        plt.hist(data, bins=50)
        plt.title(f"mean: {data_mean:.2f}, std: {data_std:.2f}, total: {self.total_len_tunes}, num_tunes: {len(data)}")
        plt.savefig(path_outfile)
        plt.close()  # Close the plot to free memory

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Shuffle the tune names with ``seed`` and cut them into train / valid / test.

        ``ratio`` is the training fraction; the remainder is halved between valid
        and test, with any rounding leftovers going to test.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        random.seed(seed)  # Set the seed for reproducibility
        random.shuffle(shuffled_tune_names)
        num_train = int(len(shuffled_tune_names) * ratio)
        # NOTE(review): float truncation can make this one smaller than expected
        # (e.g. 10 tunes at ratio 0.8 gives 0 valid tunes). Left unchanged so
        # that splits generated with a given seed stay reproducible.
        num_valid = int(len(shuffled_tune_names) * (1 - ratio) / 2)
        train_names = shuffled_tune_names[:num_train]
        valid_names = shuffled_tune_names[num_train:num_train + num_valid]
        test_names = shuffled_tune_names[num_train + num_valid:]
        return train_names, valid_names, test_names, shuffled_tune_names

    def split_train_valid_test_set(self, dataset_name=None, ratio=None, seed=42, save_dir=None, for_evaluation: bool = False):
        '''
        Build the train / valid / test datasets, creating or reusing the split metadata.

        The split is read from ``metadata/{dataset_name}_caption_metadata.json`` if
        that file exists; otherwise a new split is generated (``ratio`` required)
        and written there.

        Args:
            dataset_name: used in the metadata file names.
            ratio: training fraction; only needed when no metadata file exists yet.
            seed: shuffle seed for a newly generated split.
            save_dir: if given, a copy of the split metadata is written to
                ``{save_dir}/{dataset_name}_metadata.json``.
            for_evaluation: not read; the instance attribute ``self.for_evaluation`` decides.

        Returns:
            ``(train_dataset, valid_dataset, test_dataset)``
        '''
        metadata_path = Path(f"metadata/{dataset_name}_caption_metadata.json")
        if not metadata_path.exists():
            # If no metadata exists, perform a random split and save metadata
            assert ratio is not None, "ratio should be given when you make metadata for split"
            train_names, valid_names, test_names, shuffled_tune_names = self._get_split_list_from_tune_in_idx(ratio, seed)
            print(f"Randomly split train and test set using seed {seed}")
            out_dict = {'shuffle_seed': seed,
                        'shuffled_names': shuffled_tune_names,
                        'train': train_names,
                        'valid': valid_names,
                        'test': test_names}
            # The metadata directory may not exist on a fresh checkout.
            metadata_path.parent.mkdir(parents=True, exist_ok=True)
            with open(metadata_path, "w") as f:
                json.dump(out_dict, f, indent=2)
        else:
            # If metadata already exists, load it
            with open(metadata_path, "r") as f:
                out_dict = json.load(f)
            train_names, valid_names, test_names = out_dict['train'], out_dict['valid'], out_dict['test']
            # Ensure that the loaded data matches the current dataset
            if self.for_evaluation is False:
                assert set(out_dict['shuffled_names']) == set(self.tune_in_idx.keys()), "Loaded data is not matched with the recorded metadata"
        if self.for_evaluation:
            # For evaluation no tunes were loaded, so every split is empty
            train_data = []
            valid_data = []
            self.test_data = []
        else:
            train_data = [(self.tune_in_idx[tune_name], tune_name) for tune_name in train_names]
            valid_data = [(self.tune_in_idx[tune_name], tune_name) for tune_name in valid_names]
            self.test_data = [(self.tune_in_idx[tune_name], tune_name) for tune_name in test_names]
        # Training streams segments; validation and test are compiled up front.
        train_dataset = IterTuneCompiler(data=train_data, data_type='train', augmentor=self.augmentor, vocab=self.vocab, input_length=self.input_length, first_pred_feature=self.first_pred_feature)
        valid_dataset = TuneCompiler(data=valid_data, data_type='valid', augmentor=self.augmentor, vocab=self.vocab, input_length=self.input_length, first_pred_feature=self.first_pred_feature)
        test_dataset = TuneCompiler(data=self.test_data, data_type='test', augmentor=self.augmentor, vocab=self.vocab, input_length=self.input_length, first_pred_feature=self.first_pred_feature)
        # Save metadata to a directory if specified
        if save_dir is not None:
            Path(save_dir).mkdir(parents=True, exist_ok=True)
            with open(Path(save_dir) / f"{dataset_name}_metadata.json", "w") as f:
                json.dump(out_dict, f, indent=2)
        return train_dataset, valid_dataset, test_dataset
class Pop1k7(SymbolicMusicDataset):
    '''Pop1k7 dataset; uses the loading and splitting of SymbolicMusicDataset unchanged.'''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        # Pure pass-through to the base constructor, spelled out by keyword.
        super().__init__(
            vocab=vocab,
            encoding_scheme=encoding_scheme,
            num_features=num_features,
            debug=debug,
            aug_type=aug_type,
            input_length=input_length,
            first_pred_feature=first_pred_feature,
            caption_path=caption_path,
            for_evaluation=for_evaluation,
        )
class SymphonyMIDI(SymbolicMusicDataset):
    '''SymphonyMIDI dataset; uses the loading and splitting of SymbolicMusicDataset unchanged.'''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        # Pure pass-through to the base constructor, spelled out by keyword.
        super().__init__(
            vocab=vocab,
            encoding_scheme=encoding_scheme,
            num_features=num_features,
            debug=debug,
            aug_type=aug_type,
            input_length=input_length,
            first_pred_feature=first_pred_feature,
            caption_path=caption_path,
            for_evaluation=for_evaluation,
        )
class LakhClean(SymbolicMusicDataset):
    '''LakhClean dataset: irregular tunes are skipped on load and the split is made per song.'''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Irregular tunes are removed from the dataset for better generation quality
        It includes tunes that are not quantized properly; mostly they are expressive performance data

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name

        Tune names that differ only by a trailing ``.<digits>`` suffix are kept in the same split.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
class chorus(SymbolicMusicDataset):
    '''
    ``chorus`` dataset: tunes listed as irregular are skipped on load and the split is made per song.

    NOTE(review): the exclusion list is read from ``metadata/LakhClean_irregular_tunes.json``,
    i.e. the LakhClean list is reused here — confirm that this is intended.
    '''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` tune of this dataset whose stem is not in the irregular-tune list.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split by song name so that all versions of a song land in the same split.

        Tune names that differ only by a trailing ``.<digits>`` suffix count as one song.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
class Melody(SymbolicMusicDataset):
    '''
    ``Melody`` dataset: tunes listed as irregular are skipped on load and the split is made per song.

    NOTE(review): the exclusion list is read from ``metadata/LakhClean_irregular_tunes.json``,
    i.e. the LakhClean list is reused here — confirm that this is intended.
    '''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` tune of this dataset whose stem is not in the irregular-tune list.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split by song name so that all versions of a song land in the same split.

        Tune names that differ only by a trailing ``.<digits>`` suffix count as one song.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        # NOTE(review): the caller's ``ratio`` is ignored and 0.8 is always used.
        # Kept as-is to preserve existing splits — confirm whether this override
        # is deliberate or a leftover.
        ratio = 0.8
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
class msmidi(SymbolicMusicDataset):
    '''
    ``msmidi`` dataset: tunes listed as irregular are skipped on load and the split is made per song.

    NOTE(review): the exclusion list is read from ``metadata/LakhClean_irregular_tunes.json``,
    i.e. the LakhClean list is reused here — confirm that this is intended.
    '''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` tune of this dataset whose stem is not in the irregular-tune list.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split by song name so that all versions of a song land in the same split.

        Tune names that differ only by a trailing ``.<digits>`` suffix count as one song.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
class IrishMan(SymbolicMusicDataset):
    '''
    ``IrishMan`` dataset: tunes listed as irregular are skipped on load and the split is made per song.

    NOTE(review): the exclusion list is read from ``metadata/LakhClean_irregular_tunes.json``,
    i.e. the LakhClean list is reused here — confirm that this is intended.
    '''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` tune of this dataset whose stem is not in the irregular-tune list.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split by song name so that all versions of a song land in the same split.

        Tune names that differ only by a trailing ``.<digits>`` suffix count as one song.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            # Only that case is tolerated; other errors should surface.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
# NOTE(review): a second ``class ariamidi`` is defined later in this file and
# rebinds the name, so this first definition is shadowed at import time.
class ariamidi(SymbolicMusicDataset):
    '''
    ``ariamidi`` dataset: tunes listed as irregular are skipped on load and the split is made per song.

    NOTE(review): the exclusion list is read from ``metadata/LakhClean_irregular_tunes.json``,
    i.e. the LakhClean list is reused here — confirm that this is intended.
    '''

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every ``.npz`` tune of this dataset whose stem is not in the irregular-tune list.

        Returns:
            ``(tune_in_idx_dict, len_tunes, file_name_list)`` keyed / ordered by file stem.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        data_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(data_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            # A set makes the per-file membership test O(1) instead of a list scan.
            irregular_tunes = set(json.load(f))
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            stem = tune_in_idx_file.stem
            if stem in irregular_tunes:
                continue
            # Close the archive explicitly instead of waiting for garbage collection.
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[stem] = tune_in_idx
            len_tunes[stem] = len(tune_in_idx)
            file_name_list.append(stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split by song name so that all versions of a song land in the same split.

        Tune names that differ only by a trailing ``.<digits>`` suffix count as one song.

        Returns:
            ``(train_names, valid_names, test_names, shuffled_tune_names)``; the
            last list is in load order (only the unique song names are shuffled).
        '''
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # Evaluation mode stores an empty list, which has no ``keys``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song_dict.setdefault(re.sub(r"\.\d+$", "", orig_song), []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names) * ratio)
        num_valid = int(len(unique_song_names) * (1 - ratio) / 2)

        def versions_of(song_names):
            # Expand unique song names back into every stored version.
            expanded = []
            for song_name in song_names:
                expanded.extend(song_dict[song_name])
            return expanded

        train_names = versions_of(unique_song_names[:num_train])
        valid_names = versions_of(unique_song_names[num_train:num_train + num_valid])
        test_names = versions_of(unique_song_names[num_train + num_valid:])
        return train_names, valid_names, test_names, shuffled_tune_names
# class gigamidi(SymbolicMusicDataset):
# def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
# for_evaluation: bool = False):
# super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
# for_evaluation=for_evaluation)
# def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
# '''
# Irregular tunes are removed from the dataset for better generation quality
# It includes tunes that are not quantized properly; mostly they are expressive performance data
# '''
# print("preprocessed tune_in_idx data is being loaded")
# tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
# if self.debug:
# tune_in_idx_list = tune_in_idx_list[:5000]
# tune_in_idx_dict = OrderedDict()
# len_tunes = OrderedDict()
# file_name_list = []
# with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
# irregular_tunes = json.load(f)
# for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
# if tune_in_idx_file.stem in irregular_tunes:
# continue
# tune_in_idx = np.load(tune_in_idx_file)['arr_0']
# tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
# len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
# file_name_list.append(tune_in_idx_file.stem)
# print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
# return tune_in_idx_dict, len_tunes, file_name_list
# def _get_split_list_from_tune_in_idx(self, ratio, seed):
# '''
# As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
# '''
# shuffled_tune_names = list(self.tune_in_idx.keys())
# song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
# song_dict = {}
# for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
# if song not in song_dict:
# song_dict[song] = []
# song_dict[song].append(orig_song)
# unique_song_names = list(song_dict.keys())
# random.seed(seed)
# random.shuffle(unique_song_names)
# num_train = int(len(unique_song_names)*ratio)
# num_valid = int(len(unique_song_names)*(1-ratio)/2)
# train_names = []
# valid_names = []
# test_names = []
# for song_name in unique_song_names[:num_train]:
# train_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train:num_train+num_valid]:
# valid_names.extend(song_dict[song_name])
# for song_name in unique_song_names[num_train+num_valid:]:
# test_names.extend(song_dict[song_name])
# return train_names, valid_names, test_names, shuffled_tune_names
class ariamidi(SymbolicMusicDataset):
    """Dataset backed by the pre-computed ``tuneidx_ariamidi`` index files."""

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load every ``*.npz`` tune found under this dataset's index directory.

        Tunes whose file stem is listed in
        ``metadata/LakhClean_irregular_tunes.json`` are skipped. In debug mode
        only the first 5000 files (in sorted path order) are considered.

        NOTE(review): the irregular-tune list is the LakhClean one; confirm it
        is the intended filter for this dataset.

        Returns:
            ``(tune_in_idx, len_tunes, file_name_list)``: the ``arr_0`` array of
            each file keyed by file stem, the length of each array keyed by
            file stem, and the stems in load order.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(index_root.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            irregular_tunes = json.load(f)
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            tune_name = tune_in_idx_file.stem
            if tune_name in irregular_tunes:
                continue
            # Close the archive explicitly instead of relying on garbage collection.
            with np.load(tune_in_idx_file) as archive:
                tune_in_idx = archive['arr_0']
            tune_in_idx_dict[tune_name] = tune_in_idx
            len_tunes[tune_name] = len(tune_in_idx)
            file_name_list.append(tune_name)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Split the loaded tunes into train/valid/test lists at song level.

        Tune names that differ only by a trailing ``.<digits>`` suffix are
        treated as versions of one song and always land in the same split.

        Args:
            ratio: fraction of the unique songs assigned to the training split;
                half of the remainder (rounded down) goes to validation and the
                rest to test.
            seed: value passed to ``random.seed`` before shuffling the songs
                (this reseeds the global ``random`` generator).

        Returns:
            ``(train_names, valid_names, test_names, all_tune_names)`` where the
            last element is the unshuffled list of all tune names.
        """
        try:
            shuffled_tune_names = list(self.tune_in_idx.keys())
        except AttributeError:
            # tune_in_idx is not set (or is None): fall back to an empty split
            # rather than hiding unrelated errors behind a bare ``except``.
            shuffled_tune_names = []
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song = re.sub(r"\.\d+$", "", orig_song)
            song_dict.setdefault(song, []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        valid_end = num_train + num_valid
        train_names = [name for song in unique_song_names[:num_train] for name in song_dict[song]]
        valid_names = [name for song in unique_song_names[num_train:valid_end] for name in song_dict[song]]
        test_names = [name for song in unique_song_names[valid_end:] for name in song_dict[song]]
        return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
    """Dataset backed by the pre-computed ``tuneidx_gigamidi`` index files."""

    def __init__(
        self,
        vocab,
        encoding_scheme,
        num_features,
        debug,
        aug_type,
        input_length,
        first_pred_feature,
        caption_path=None,
        for_evaluation: bool = False,
    ):
        super().__init__(
            vocab,
            encoding_scheme,
            num_features,
            debug,
            aug_type,
            input_length,
            first_pred_feature,
            caption_path,
            for_evaluation=for_evaluation,
        )

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Read the ``arr_0`` array of every usable ``*.npz`` file of this dataset.

        A file is skipped when its stem appears in
        ``metadata/LakhClean_irregular_tunes.json`` or contains the substring
        ``drums-only``. Debug mode keeps only the first 5000 sorted paths.

        Returns:
            ``(arrays by stem, array lengths by stem, stems in load order)``.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        npz_paths = sorted(index_root.rglob("*.npz"))
        if self.debug:
            npz_paths = npz_paths[:5000]
        arrays_by_name = OrderedDict()
        lengths_by_name = OrderedDict()
        loaded_names = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            irregular_tunes = json.load(f)
        for npz_path in tqdm(npz_paths, total=len(npz_paths)):
            tune_name = npz_path.stem
            if tune_name in irregular_tunes:
                continue
            if "drums-only" in tune_name:
                print(f"skipping {tune_name} as it is a drums-only file")
                continue
            tune_array = np.load(npz_path)['arr_0']
            arrays_by_name[tune_name] = tune_array
            lengths_by_name[tune_name] = len(tune_array)
            loaded_names.append(tune_name)
        print(f"number of loaded tunes: {len(arrays_by_name)}")
        return arrays_by_name, lengths_by_name, loaded_names

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Partition tune names into train/valid/test, keeping song versions together.

        Names differing only by a trailing ``.<digits>`` suffix are grouped as
        one song. Songs are shuffled after ``random.seed(seed)``; the first
        ``ratio`` share goes to train, half of the remainder (rounded down) to
        valid and the rest to test. The fourth return value is the unshuffled
        list of all tune names.
        """
        all_tune_names = list(self.tune_in_idx.keys())
        versions_by_song = {}
        for tune_name in all_tune_names:
            base_name = re.sub(r"\.\d+$", "", tune_name)
            versions_by_song.setdefault(base_name, []).append(tune_name)
        song_order = list(versions_by_song)
        random.seed(seed)
        random.shuffle(song_order)
        train_end = int(len(song_order) * ratio)
        valid_end = train_end + int(len(song_order) * (1 - ratio) / 2)

        def expand(songs):
            # Replace each song by all of its versions, preserving order.
            return [tune for song in songs for tune in versions_by_song[song]]

        return (
            expand(song_order[:train_end]),
            expand(song_order[train_end:valid_end]),
            expand(song_order[valid_end:]),
            all_tune_names,
        )
class PretrainingDataset(SymbolicMusicDataset):
    """Combination of several pre-computed tune-index datasets.

    ``_load_tune_in_idx`` merges the ariamidi, gigamidi, LakhALLFined,
    XMIDI_Dataset and new_dataset index directories. Tunes of the first two are
    keyed by file stem; tunes of the last three are keyed by their caption text.
    """

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _list_tuneidx_files(self, dataset_name: str) -> List[Path]:
        """Return the sorted ``*.npz`` paths of ``dataset_name`` (first 5000 in debug mode)."""
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{dataset_name}/{self.encoding_scheme}{self.num_features}")
        npz_paths = sorted(index_root.rglob("*.npz"))
        if self.debug:
            npz_paths = npz_paths[:5000]
        return npz_paths

    @staticmethod
    def _read_tune(npz_path: Path) -> np.ndarray:
        """Read the ``arr_0`` array of one archive and close the file afterwards."""
        with np.load(npz_path) as archive:
            return archive['arr_0']

    def _load_plain_tuneidx(self, dataset_name: str, skip_drums_only: bool = False
                            ) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load a dataset whose tunes are keyed by file stem.

        Args:
            dataset_name: suffix of the ``tuneidx_<dataset_name>`` directory.
            skip_drums_only: when True, files whose stem contains
                ``drums-only`` are reported and skipped.

        Returns:
            ``(arrays by stem, array lengths by stem, stems in load order)``.
        """
        npz_paths = self._list_tuneidx_files(dataset_name)
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        for npz_path in tqdm(npz_paths, total=len(npz_paths)):
            tune_name = npz_path.stem
            if skip_drums_only and "drums-only" in tune_name:
                print(f"skipping {tune_name} as it is a drums-only file")
                continue
            tune_in_idx = self._read_tune(npz_path)
            tune_in_idx_dict[tune_name] = tune_in_idx
            len_tunes[tune_name] = len(tune_in_idx)
            file_name_list.append(tune_name)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _load_captioned_tuneidx(self, dataset_name: str, caption_path: str, to_location_key
                                ) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load a dataset whose tunes are keyed by their caption.

        The caption file is read line by line; each line is expected to be a
        JSON object with ``location`` and ``caption`` entries. Lines that do not
        parse or lack those entries are ignored. A tune is loaded only when
        ``to_location_key(stem)`` matches a ``location`` with a non-null caption.

        Side effect: ``self.caption_list`` is replaced by the stripped lines of
        ``caption_path``.

        NOTE: tunes sharing the same caption overwrite each other in the two
        returned dicts, while the returned list keeps one entry per loaded file.

        Args:
            dataset_name: suffix of the ``tuneidx_<dataset_name>`` directory.
            caption_path: path of the line-delimited caption file.
            to_location_key: callable mapping a file stem to the ``location``
                value used in the caption file.

        Returns:
            ``(arrays by caption, array lengths by caption, captions in load order)``.
        """
        npz_paths = self._list_tuneidx_files(dataset_name)
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        self.caption_list = []
        with open(caption_path, "r") as f:
            self.caption_list.extend(line.strip() for line in f)
        print(f"number of loaded captions: {len(self.caption_list)}")
        # Build the location -> caption mapping, skipping unusable lines.
        location2caption = {}
        for line in self.caption_list:
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Invalid JSON, a non-object line, or a missing field.
                continue
        for npz_path in tqdm(npz_paths, total=len(npz_paths)):
            caption = location2caption.get(to_location_key(npz_path.stem))
            if caption is None:
                continue
            tune_in_idx = self._read_tune(npz_path)
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _load_tune_in_idx_aria(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load preprocessed tune indices for the aria dataset
        '''
        return self._load_plain_tuneidx("ariamidi")

    def _load_tune_in_idx_giga(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load preprocessed tune indices for the gigamidi dataset, skipping drums-only files
        '''
        return self._load_plain_tuneidx("gigamidi", skip_drums_only=True)

    def _load_tune_in_idx_pop1k7(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load preprocessed tune indices for the Pop1k7 dataset
        '''
        return self._load_plain_tuneidx("pop1k7")

    def _load_tune_in_idx_sod(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load preprocessed tune indices for the SOD dataset
        '''
        return self._load_plain_tuneidx("SOD")

    def _load_tune_in_idx_lakh(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load captioned tune indices of LakhALLFined (captions from self.lakh_caption_path)
        '''
        # 0_06d3f5a5954848ba13b9128f68f0a1d1 -> lmd_full/0/06d3f5a5954848ba13b9128f68f0a1d1.mid
        return self._load_captioned_tuneidx(
            "LakhALLFined",
            self.lakh_caption_path,
            lambda stem: "lmd_full/" + stem.replace("_", "/", 1) + ".mid",
        )

    def _load_tune_in_idx_xmidi(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load captioned tune indices of XMIDI_Dataset (captions from self.xmidi_caption_path)
        '''
        return self._load_captioned_tuneidx(
            "XMIDI_Dataset",
            self.xmidi_caption_path,
            lambda stem: stem + ".midi",
        )

    def _load_tune_in_idx_new(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load captioned tune indices of new_dataset (captions from self.new_caption_path)
        '''
        return self._load_captioned_tuneidx(
            "new_dataset",
            self.new_caption_path,
            lambda stem: "new_data_new_dataset/" + stem + ".mid",
        )

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load and merge the gigamidi, ariamidi, Lakh, XMIDI and new_dataset tunes.

        Also sets the three caption-path attributes read by the captioned
        loaders. When two datasets produce the same key, the dataset merged
        later (order: aria, giga, lakh, xmidi, new) wins in the dicts, whereas
        the returned name list is a plain concatenation.
        """
        self.lakh_caption_path = "dataset/represented_data/tuneidx/train_set.json"
        self.xmidi_caption_path = "dataset/represented_data/tuneidx/all_captions.json"
        self.new_caption_path = "dataset/represented_data/tuneidx/new_dataset_captions_final.jsonl"
        # load tune_in_idx data from every source dataset
        tune_in_idx_giga, len_tunes_giga, file_name_list_giga = self._load_tune_in_idx_giga()
        tune_in_idx_aria, len_tunes_aria, file_name_list_aria = self._load_tune_in_idx_aria()
        tune_in_idx_lakh, len_tunes_lakh, file_name_list_lakh = self._load_tune_in_idx_lakh()
        tune_in_idx_xmidi, len_tunes_xmidi, file_name_list_xmidi = self._load_tune_in_idx_xmidi()
        tune_in_idx_new, len_tunes_new, file_name_list_new = self._load_tune_in_idx_new()
        # merge the five datasets
        tune_in_idx = {**tune_in_idx_aria, **tune_in_idx_giga, **tune_in_idx_lakh, **tune_in_idx_xmidi, **tune_in_idx_new}
        len_tunes = {**len_tunes_aria, **len_tunes_giga, **len_tunes_lakh, **len_tunes_xmidi, **len_tunes_new}
        file_name_list = file_name_list_aria + file_name_list_giga + file_name_list_lakh + file_name_list_xmidi + file_name_list_new
        print(f"number of loaded tunes: {len(tune_in_idx)}")
        return tune_in_idx, len_tunes, file_name_list
class SOD(SymbolicMusicDataset):
    """Dataset backed by the pre-computed ``tuneidx_SOD`` index files."""

    def __init__(
        self,
        vocab,
        encoding_scheme,
        num_features,
        debug,
        aug_type,
        input_length,
        first_pred_feature,
        caption_path=None,
        for_evaluation: bool = False,
    ):
        super().__init__(
            vocab,
            encoding_scheme,
            num_features,
            debug,
            aug_type,
            input_length,
            first_pred_feature,
            caption_path,
            for_evaluation=for_evaluation,
        )

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Read the ``arr_0`` array of every ``*.npz`` file of this dataset.

        Files whose stem is listed in ``metadata/SOD_irregular_tunes.json`` are
        left out. Debug mode keeps only the first 5000 sorted paths.

        Returns:
            ``(arrays by stem, array lengths by stem, stems in load order)``.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        npz_paths = sorted(index_root.rglob("*.npz"))
        if self.debug:
            npz_paths = npz_paths[:5000]
        arrays_by_name = OrderedDict()
        lengths_by_name = OrderedDict()
        loaded_names = []
        with open("metadata/SOD_irregular_tunes.json", "r") as f:
            irregular_tunes = json.load(f)
        for npz_path in tqdm(npz_paths, total=len(npz_paths)):
            tune_name = npz_path.stem
            if tune_name in irregular_tunes:
                continue
            tune_array = np.load(npz_path)['arr_0']
            arrays_by_name[tune_name] = tune_array
            lengths_by_name[tune_name] = len(tune_array)
            loaded_names.append(tune_name)
        print(f"number of loaded tunes: {len(arrays_by_name)}")
        return arrays_by_name, lengths_by_name, loaded_names
class BachChorale(SymbolicMusicDataset):
    """Dataset class that adds nothing to the base class beyond its own name."""

    def __init__(
        self,
        vocab,
        encoding_scheme,
        num_features,
        debug,
        aug_type,
        input_length,
        first_pred_feature,
        caption_path=None,
        for_evaluation: bool = False,
    ):
        # All arguments are forwarded unchanged to the parent initializer.
        super().__init__(
            vocab,
            encoding_scheme,
            num_features,
            debug,
            aug_type,
            input_length,
            first_pred_feature,
            caption_path,
            for_evaluation=for_evaluation,
        )
class Pop909(SymbolicMusicDataset):
    """Dataset class whose split keeps all ``-v<digits>`` versions of a song together."""

    def __init__(
        self,
        vocab,
        encoding_scheme,
        num_features,
        debug,
        aug_type,
        input_length,
        first_pred_feature,
        caption_path=None,
        for_evaluation: bool = False,
    ):
        super().__init__(
            vocab,
            encoding_scheme,
            num_features,
            debug,
            aug_type,
            input_length,
            first_pred_feature,
            caption_path,
            for_evaluation=for_evaluation,
        )

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Partition tune names into train/valid/test at song level.

        Names differing only by a trailing ``-v<digits>`` suffix count as
        versions of one song and share a split. Songs are shuffled after
        ``random.seed(seed)``; the first ``ratio`` share goes to train, half of
        the remainder (rounded down) to valid and the rest to test. The fourth
        return value is the unshuffled list of all tune names.
        """
        all_tune_names = list(self.tune_in_idx.keys())
        versions_by_song = {}
        for tune_name in all_tune_names:
            base_name = re.sub(r"-v\d+$", "", tune_name)
            versions_by_song.setdefault(base_name, []).append(tune_name)
        song_order = list(versions_by_song)
        random.seed(seed)
        random.shuffle(song_order)
        train_end = int(len(song_order) * ratio)
        valid_end = train_end + int(len(song_order) * (1 - ratio) / 2)

        def expand(songs):
            # Replace each song by all of its versions, preserving order.
            return [tune for song in songs for tune in versions_by_song[song]]

        return (
            expand(song_order[:train_end]),
            expand(song_order[train_end:valid_end]),
            expand(song_order[valid_end:]),
            all_tune_names,
        )
class LakhALL(SymbolicMusicDataset):
    """Captioned dataset backed by the ``tuneidx_LakhALL`` index files."""

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load the tunes that have a caption, keyed by that caption.

        ``self.caption_path`` is read line by line; each line is expected to be
        a JSON object with ``location`` and ``caption`` entries, and lines that
        do not parse or lack those entries are ignored. A file with stem
        ``A_B`` is looked up under the location ``lmd_full/A/B.mid``; files
        without a caption are reported and skipped. Debug mode keeps only the
        first 5000 sorted paths.

        Side effect: ``self.caption_list`` is replaced by the stripped lines of
        the caption file.

        NOTE: tunes sharing the same caption overwrite each other in the two
        returned dicts, while the returned list keeps one entry per loaded file.

        Returns:
            ``(arrays by caption, array lengths by caption, captions in load order)``.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(index_root.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        # load caption: every line is a caption record for one tune
        self.caption_list = []
        with open(self.caption_path, "r") as f:
            self.caption_list.extend(line.strip() for line in f)
        print(f"number of loaded captions: {len(self.caption_list)}")
        # Build the location -> caption mapping, skipping unusable lines.
        location2caption = {}
        for line in self.caption_list:
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Invalid JSON, a non-object line, or a missing field.
                continue
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            # 0_06d3f5a5954848ba13b9128f68f0a1d1 -> lmd_full/0/06d3f5a5954848ba13b9128f68f0a1d1.mid
            location_key = "lmd_full/" + tune_in_idx_file.stem.replace("_", "/", 1) + ".mid"
            caption = location2caption.get(location_key)
            if caption is None:
                print(f"Caption for {location_key} is None, skipping this tune")
                continue
            # Close the archive explicitly instead of relying on garbage collection.
            with np.load(tune_in_idx_file) as archive:
                tune_in_idx = archive['arr_0']
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Split the loaded tunes into train/valid/test lists.

        Keys of ``self.tune_in_idx`` that differ only by a trailing
        ``.<digits>`` suffix are grouped and kept in the same split.

        NOTE(review): ``_load_tune_in_idx`` keys tunes by caption text, so this
        suffix grouping presumably rarely applies here; confirm it is intended.

        Args:
            ratio: fraction of the unique groups assigned to train; half of the
                remainder (rounded down) goes to valid and the rest to test.
            seed: value passed to ``random.seed`` before shuffling (this
                reseeds the global ``random`` generator).

        Returns:
            ``(train_names, valid_names, test_names, all_tune_names)`` where the
            last element is the unshuffled list of all keys.
        """
        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song = re.sub(r"\.\d+$", "", orig_song)
            song_dict.setdefault(song, []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        valid_end = num_train + num_valid
        train_names = [name for song in unique_song_names[:num_train] for name in song_dict[song]]
        valid_names = [name for song in unique_song_names[num_train:valid_end] for name in song_dict[song]]
        test_names = [name for song in unique_song_names[valid_end:] for name in song_dict[song]]
        return train_names, valid_names, test_names, shuffled_tune_names
class LakhALLFined(SymbolicMusicDataset):
    """Captioned dataset backed by the ``tuneidx_LakhALLFined`` index files."""

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load the tunes that have a caption, keyed by that caption.

        ``self.caption_path`` is read line by line; each line is expected to be
        a JSON object with ``location`` and ``caption`` entries, and lines that
        do not parse or lack those entries are ignored. A file with stem
        ``A_B`` is looked up under the location ``lmd_full/A/B.mid``; files
        without a caption are silently skipped. Debug mode keeps only the first
        5000 sorted paths.

        Side effect: ``self.caption_list`` is replaced by the stripped lines of
        the caption file.

        NOTE: tunes sharing the same caption overwrite each other in the two
        returned dicts, while the returned list keeps one entry per loaded file.

        Returns:
            ``(arrays by caption, array lengths by caption, captions in load order)``.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(index_root.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        # load caption: every line is a caption record for one tune
        self.caption_list = []
        with open(self.caption_path, "r") as f:
            self.caption_list.extend(line.strip() for line in f)
        print(f"number of loaded captions: {len(self.caption_list)}")
        # Build the location -> caption mapping, skipping unusable lines.
        location2caption = {}
        for line in self.caption_list:
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Invalid JSON, a non-object line, or a missing field.
                continue
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            # 0_06d3f5a5954848ba13b9128f68f0a1d1 -> lmd_full/0/06d3f5a5954848ba13b9128f68f0a1d1.mid
            location_key = "lmd_full/" + tune_in_idx_file.stem.replace("_", "/", 1) + ".mid"
            caption = location2caption.get(location_key)
            if caption is None:
                continue
            # Close the archive explicitly instead of relying on garbage collection.
            with np.load(tune_in_idx_file) as archive:
                tune_in_idx = archive['arr_0']
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Drop ``None`` entries from ``self.tune_in_idx`` and split the rest.

        ``self.tune_in_idx`` is rebuilt as a plain dict without ``None`` values
        (its size before and after is printed). Remaining keys that differ only
        by a trailing ``.<digits>`` suffix are grouped and kept in one split.

        NOTE(review): ``_load_tune_in_idx`` keys tunes by caption text, so this
        suffix grouping presumably rarely applies here; confirm it is intended.

        Args:
            ratio: fraction of the unique groups assigned to train; half of the
                remainder (rounded down) goes to valid and the rest to test.
            seed: value passed to ``random.seed`` before shuffling (this
                reseeds the global ``random`` generator).

        Returns:
            ``(train_names, valid_names, test_names, all_tune_names)`` where the
            last element is the unshuffled list of all remaining keys.
        """
        # filter out none in tune_in_idx
        print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
        self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
        print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song = re.sub(r"\.\d+$", "", orig_song)
            song_dict.setdefault(song, []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        valid_end = num_train + num_valid
        train_names = [name for song in unique_song_names[:num_train] for name in song_dict[song]]
        valid_names = [name for song in unique_song_names[num_train:valid_end] for name in song_dict[song]]
        test_names = [name for song in unique_song_names[valid_end:] for name in song_dict[song]]
        return train_names, valid_names, test_names, shuffled_tune_names
class XMIDI_Dataset(SymbolicMusicDataset):
    """Captioned dataset backed by the ``tuneidx_XMIDI_Dataset`` index files."""

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        """Load the tunes that have a caption, keyed by that caption.

        ``self.caption_path`` is read line by line; each line is expected to be
        a JSON object with ``location`` and ``caption`` entries, and lines that
        do not parse or lack those entries are ignored. A file with stem ``S``
        is looked up under the location ``S.midi``; files without a caption are
        silently skipped. Debug mode keeps only the first 5000 sorted paths.

        Side effect: ``self.caption_list`` is replaced by the stripped lines of
        the caption file.

        NOTE: tunes sharing the same caption overwrite each other in the two
        returned dicts, while the returned list keeps one entry per loaded file.

        Returns:
            ``(arrays by caption, array lengths by caption, captions in load order)``.
        """
        print("preprocessed tune_in_idx data is being loaded")
        index_root = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(index_root.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        # load caption: every line is a caption record for one tune
        self.caption_list = []
        with open(self.caption_path, "r") as f:
            self.caption_list.extend(line.strip() for line in f)
        print(f"number of loaded captions: {len(self.caption_list)}")
        # Build the location -> caption mapping, skipping unusable lines.
        location2caption = {}
        for line in self.caption_list:
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Invalid JSON, a non-object line, or a missing field.
                continue
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            # No per-file print here: it would flood the log and break the progress bar.
            location_key = tune_in_idx_file.stem + ".midi"
            caption = location2caption.get(location_key)
            if caption is None:
                continue
            # Close the archive explicitly instead of relying on garbage collection.
            with np.load(tune_in_idx_file) as archive:
                tune_in_idx = archive['arr_0']
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        """Drop ``None`` entries from ``self.tune_in_idx`` and split the rest.

        ``self.tune_in_idx`` is rebuilt as a plain dict without ``None`` values
        (its size before and after is printed). Remaining keys that differ only
        by a trailing ``.<digits>`` suffix are grouped and kept in one split.

        NOTE(review): ``_load_tune_in_idx`` keys tunes by caption text, so this
        suffix grouping presumably rarely applies here; confirm it is intended.

        Args:
            ratio: fraction of the unique groups assigned to train; half of the
                remainder (rounded down) goes to valid and the rest to test.
            seed: value passed to ``random.seed`` before shuffling (this
                reseeds the global ``random`` generator).

        Returns:
            ``(train_names, valid_names, test_names, all_tune_names)`` where the
            last element is the unshuffled list of all remaining keys.
        """
        # filter out none in tune_in_idx
        print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
        self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
        print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song = re.sub(r"\.\d+$", "", orig_song)
            song_dict.setdefault(song, []).append(orig_song)
        unique_song_names = list(song_dict)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        valid_end = num_train + num_valid
        train_names = [name for song in unique_song_names[:num_train] for name in song_dict[song]]
        valid_names = [name for song in unique_song_names[num_train:valid_end] for name in song_dict[song]]
        test_names = [name for song in unique_song_names[valid_end:] for name in song_dict[song]]
        return train_names, valid_names, test_names, shuffled_tune_names
class new_dataset(SymbolicMusicDataset):
def __init__(
    self,
    vocab,
    encoding_scheme,
    num_features,
    debug,
    aug_type,
    input_length,
    first_pred_feature,
    caption_path=None,
    for_evaluation: bool = False,
):
    """Forward every construction argument unchanged to the parent initializer."""
    super().__init__(
        vocab,
        encoding_scheme,
        num_features,
        debug,
        aug_type,
        input_length,
        first_pred_feature,
        caption_path,
        for_evaluation=for_evaluation,
    )
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly, mostly theay are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
# load caption
self.caption_list = []
with open(self.caption_path, "r") as f:
# every line is a caption for the tune
for line in f:
self.caption_list.append(line.strip())
print(f"number of loaded captions: {len(self.caption_list)}")
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
# 构建 location 到 caption 的映射
location2caption = {}
for line in self.caption_list:
try:
# 假设每行是一个json字符串
item = json.loads(line)
# if item["test_set"] is True:
# continue # skip test set tunes
location2caption[item["location"]] = item["caption"]
except Exception:
continue
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
# 0_06d3f5a5954848ba13b9128f68f0a1d1 -> 0/06d3f5a5954848ba13b9128f68f0a1d1
location_key = tune_in_idx_file.stem.split("/")[-1] + ".mid"
location_key = f"new_data_new_dataset/{location_key}"
try:
caption = location2caption.get(location_key)
except KeyError:
print(f"KeyError: {location_key} not found in location2caption")
continue
if caption is None:
continue
# print(tune_in_idx_file.stem, location_key, caption)
# print("*" * 20)
# 你可以在这里使用caption变量
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[caption] = tune_in_idx
len_tunes[caption] = len(tune_in_idx)
file_name_list.append(caption)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
# filter out none in tune_in_idx
print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class SymphonyNet_Dataset(SymbolicMusicDataset):
    '''
    Captioned dataset built from the SymphonyNet tune_in_idx files.

    Every tune is matched to a caption through the "location" field of the
    caption file (self.caption_path, one JSON record per line) and is stored
    under its caption text.
    '''
    # Prefix of the "location" values in the caption file. It is only used to
    # build lookup keys (never opened as a path), so it must match the caption
    # file verbatim.
    _LOCATION_PREFIX = "/data2/suhongju/research/music-generation/BandZero/SymphonyNet_Dataset"

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        # for_evaluation is accepted and forwarded exactly like the sibling dataset classes do
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load every preprocessed tune that has a caption.

        The *.npz files are searched recursively under
        dataset/represented_data/tuneidx/tuneidx_<class name>/<encoding_scheme><num_features>
        (only the first 5000 files, in sorted order, when self.debug is set).
        The lookup key of a file is its stem with the first "_" turned into "/",
        plus ".mid", under _LOCATION_PREFIX. Tunes without a caption are
        reported and skipped; blank or malformed caption lines are skipped.

        Side effect: self.caption_list is set to the stripped lines of the
        caption file.

        Returns (tune_in_idx_dict, len_tunes, file_name_list): caption -> token
        array, caption -> array length, and the captions in load order.

        NOTE: tunes are keyed by caption text, so tunes sharing an identical
        caption overwrite each other in the two dicts while file_name_list
        still records every occurrence.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        tuneidx_dir = Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(tuneidx_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]

        # Every line of the caption file is one caption record
        with open(self.caption_path, "r", encoding="utf-8") as f:
            self.caption_list = [line.strip() for line in f]
        print(f"number of loaded captions: {len(self.caption_list)}")

        # Build the location -> caption mapping (each line is expected to be a JSON object)
        location2caption = {}
        num_malformed = 0
        for line in self.caption_list:
            if not line:
                continue
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Best effort: a bad record must not abort loading, but it is counted
                num_malformed += 1
        if num_malformed:
            print(f"number of skipped malformed caption lines: {num_malformed}")

        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        for tune_in_idx_file in tqdm(tune_in_idx_list):
            # 0_06d3f5a5954848ba13b9128f68f0a1d1 -> 0/06d3f5a5954848ba13b9128f68f0a1d1
            relative_name = tune_in_idx_file.stem.replace("_", "/", 1) + ".mid"
            location_key = f"{self._LOCATION_PREFIX}/{relative_name}"
            caption = location2caption.get(location_key, None)
            if caption is None:
                print(f"Caption for {location_key} is None, skipping this tune")
                continue
            # Close the archive as soon as the array has been read
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split the loaded tune names into train / valid / test lists.

        No version grouping is done here: every key of self.tune_in_idx is its
        own unit. ratio is the train fraction; half of the remainder (truncated)
        goes to valid, the rest to test. seed reseeds the global random module.

        Returns (train_names, valid_names, test_names, all_names) with all_names
        in stored (unshuffled) order.
        '''
        shuffled_tune_names = list(self.tune_in_idx.keys())

        # Dict keys are already unique, so shuffle a copy of the name list directly
        unique_song_names = list(shuffled_tune_names)
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)

        train_names = unique_song_names[:num_train]
        valid_names = unique_song_names[num_train:num_train+num_valid]
        test_names = unique_song_names[num_train+num_valid:]
        return train_names, valid_names, test_names, shuffled_tune_names
# FinetuneDataset merges the LakhALLFined, XMIDI_Dataset and new_dataset corpora into one fine-tuning dataset
class FinetuneDataset(SymbolicMusicDataset):
    '''
    Fine-tuning dataset that merges three captioned corpora: LakhALLFined,
    XMIDI_Dataset and new_dataset.

    Each corpus has its own tune_in_idx directory and caption file (one JSON
    record per line with "location" and "caption" fields). The loaders in this
    class read the fixed caption paths below and do not read self.caption_path.
    '''
    # Caption files of the three source corpora
    lakh_caption_path = "dataset/represented_data/tuneidx/train_set.json"
    xmidi_caption_path = "dataset/represented_data/tuneidx/all_captions.json"
    new_caption_path = "dataset/represented_data/tuneidx/new_dataset_captions_final.jsonl"

    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_captioned_tunes(self, tuneidx_name, caption_path, make_location_key) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load the captioned tunes of one corpus.

        tuneidx_name is the directory name under dataset/represented_data/tuneidx,
        caption_path the caption file of that corpus, and make_location_key a
        callable mapping an .npz file stem to the "location" value used in the
        caption file. Only the first 5000 files (sorted) are used when
        self.debug is set. Tunes without a caption and blank or malformed
        caption lines are skipped.

        Side effect: self.caption_list is overwritten with the stripped lines of
        caption_path, so after several calls it holds the last corpus only.

        Returns (tune_in_idx_dict, len_tunes, file_name_list): caption -> token
        array, caption -> array length, and the captions in load order.

        NOTE: tunes are keyed by caption text, so tunes sharing an identical
        caption overwrite each other in the two dicts while file_name_list
        still records every occurrence.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        tuneidx_dir = Path(f"dataset/represented_data/tuneidx/{tuneidx_name}/{self.encoding_scheme}{self.num_features}")
        tune_in_idx_list = sorted(tuneidx_dir.rglob("*.npz"))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]

        # Every line of the caption file is one caption record
        with open(caption_path, "r", encoding="utf-8") as f:
            self.caption_list = [line.strip() for line in f]
        print(f"number of loaded captions: {len(self.caption_list)}")

        # Build the location -> caption mapping (each line is expected to be a JSON object)
        location2caption = {}
        num_malformed = 0
        for line in self.caption_list:
            if not line:
                continue
            try:
                item = json.loads(line)
                location2caption[item["location"]] = item["caption"]
            except (ValueError, KeyError, TypeError):
                # Best effort: a bad record must not abort loading, but it is counted
                num_malformed += 1
        if num_malformed:
            print(f"number of skipped malformed caption lines: {num_malformed}")

        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        for tune_in_idx_file in tqdm(tune_in_idx_list):
            caption = location2caption.get(make_location_key(tune_in_idx_file.stem))
            if caption is None:
                continue
            # Close the archive as soon as the array has been read
            with np.load(tune_in_idx_file) as npz:
                tune_in_idx = npz['arr_0']
            tune_in_idx_dict[caption] = tune_in_idx
            len_tunes[caption] = len(tune_in_idx)
            file_name_list.append(caption)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _load_tune_in_idx_lakh(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load the captioned tunes of the LakhALLFined corpus.
        '''
        # 0_06d3f5a5954848ba13b9128f68f0a1d1 -> lmd_full/0/06d3f5a5954848ba13b9128f68f0a1d1.mid
        return self._load_captioned_tunes(
            "tuneidx_LakhALLFined",
            self.lakh_caption_path,
            lambda stem: "lmd_full/" + stem.replace("_", "/", 1) + ".mid",
        )

    def _load_tune_in_idx_xmidi(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load the captioned tunes of the XMIDI_Dataset corpus.
        '''
        # <stem>.npz -> <stem>.midi
        return self._load_captioned_tunes(
            "tuneidx_XMIDI_Dataset",
            self.xmidi_caption_path,
            lambda stem: stem + ".midi",
        )

    def _load_tune_in_idx_new(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load the captioned tunes of the new_dataset corpus.
        '''
        # <stem>.npz -> new_data_new_dataset/<stem>.mid
        return self._load_captioned_tunes(
            "tuneidx_new_dataset",
            self.new_caption_path,
            lambda stem: "new_data_new_dataset/" + stem + ".mid",
        )

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Load tune_in_idx from all three datasets and merge them.

        The corpora are loaded in the order Lakh, XMIDI, new. On identical
        captions a later corpus overwrites an earlier one in the merged dicts,
        while file_name_list is a plain concatenation.
        '''
        tune_in_idx_lakh, len_tunes_lakh, file_name_list_lakh = self._load_tune_in_idx_lakh()
        tune_in_idx_xmidi, len_tunes_xmidi, file_name_list_xmidi = self._load_tune_in_idx_xmidi()
        tune_in_idx_new, len_tunes_new, file_name_list_new = self._load_tune_in_idx_new()
        # Merge the three datasets
        tune_in_idx = {**tune_in_idx_lakh, **tune_in_idx_xmidi, **tune_in_idx_new}
        len_tunes = {**len_tunes_lakh, **len_tunes_xmidi, **len_tunes_new}
        file_name_list = file_name_list_lakh + file_name_list_xmidi + file_name_list_new
        print(f"number of loaded tunes: {len(tune_in_idx)}")
        return tune_in_idx, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        Split the loaded tunes into train / valid / test name lists.

        Names that differ only by a trailing ".<digits>" suffix are grouped and
        always placed in the same split. Entries of self.tune_in_idx whose value
        is None are dropped first (self.tune_in_idx is rebound to the filtered
        dict). ratio is the train fraction of the unique names; half of the
        remainder (truncated) goes to valid, the rest to test. seed reseeds the
        global random module.

        Returns (train_names, valid_names, test_names, all_names) with all_names
        in stored (unshuffled) order.
        '''
        # Drop tunes without data before splitting
        print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
        self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
        print("length of tune_in_idx after filtering:", len(self.tune_in_idx))

        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_dict = {}
        for orig_song in shuffled_tune_names:
            song = re.sub(r"\.\d+$", "", orig_song)
            song_dict.setdefault(song, []).append(orig_song)

        # Shuffle at group level so grouped names never straddle two splits
        unique_song_names = list(song_dict.keys())
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)

        train_names = []
        valid_names = []
        test_names = []
        for song_name in unique_song_names[:num_train]:
            train_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train:num_train+num_valid]:
            valid_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train+num_valid:]:
            test_names.extend(song_dict[song_name])
        return train_names, valid_names, test_names, shuffled_tune_names