148 lines
4.6 KiB
Python
148 lines
4.6 KiB
Python
import argparse
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pickle
|
|
from tqdm import tqdm
|
|
from multiprocessing import Pool
|
|
|
|
import encoding_utils
|
|
|
|
'''
This script converts corpus data into event data.
'''
|
|
|
|
class Corpus2Event():
    '''
    Convert the pickled corpus files of one dataset into encoded event files.

    Corpus files are read from ``<in_dir>/corpus_<dataset>`` and the resulting
    events are written, one pickle per input file, to
    ``<out_dir>/events_<dataset>/<encoding_scheme><num_features>``.
    '''

    def __init__(
        self,
        dataset: str,
        encoding_scheme: str,
        num_features: int,
        in_dir: Path,
        out_dir: Path,
        debug: bool,
        cache: bool,
    ):
        '''
        Args:
            dataset: dataset name; selects the input/output sub-directories
                and the in-beat resolution.
            encoding_scheme: suffix of the ``Corpus2event_<scheme>`` factory
                looked up in ``encoding_utils``.
            num_features: passed to the encoding factory and appended to the
                encoding name used for the output directory.
            in_dir: root directory that contains ``corpus_<dataset>``.
            out_dir: root directory under which the events are written.
            debug: stored on the instance; not read by any method of this class.
            cache: when true, a file whose output already exists at write
                time is skipped instead of being re-encoded.

        Raises:
            AttributeError: if ``encoding_utils`` has no
                ``Corpus2event_<encoding_scheme>`` attribute.
        '''
        self.dataset = dataset
        self.encoding_name = encoding_scheme + str(num_features)
        self.in_dir = in_dir / f"corpus_{self.dataset}"
        self.out_dir = out_dir / f"events_{self.dataset}" / self.encoding_name
        self.debug = debug
        self.cache = cache
        self.encoding_function = getattr(encoding_utils, f'Corpus2event_{encoding_scheme}')(num_features)
        self._get_in_beat_resolution()

    def _get_in_beat_resolution(self):
        '''
        Set ``self.in_beat_resolution`` from the dataset name.

        The value is the resolution of a quarter note (e.g., 4 means the
        minimum resolution is a 16th note). Unknown dataset names fall back
        to the LakhClean setting after printing a notice.
        '''
        in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
        try:
            self.in_beat_resolution = in_beat_resolution_dict[self.dataset]
        except KeyError:
            print(f"Dataset {self.dataset} is not supported. use the setting of LakhClean")
            self.in_beat_resolution = in_beat_resolution_dict['LakhClean']

    def _collect_pending_corpus_files(self):
        '''
        Return the sorted corpus files that do not yet have an output file.

        Files are matched against the output directory by file name only, and
        this filtering happens regardless of ``self.cache``.
        '''
        corpus_list = sorted(self.in_dir.rglob("*.pkl"))
        if not corpus_list:
            print(f"No corpus files found in {self.in_dir}. Please check the directory.")
            # NOTE(review): the fallback pattern "*.pkli" (non-recursive) is kept
            # exactly as in the original code - confirm the extension is intended.
            corpus_list = sorted(self.in_dir.glob("*.pkli"))
        # A set gives constant-time membership checks while filtering.
        existing_files = {f.name for f in self.out_dir.glob("*.pkl")}
        return [corpus for corpus in corpus_list if corpus.name not in existing_files]

    def make_events(self):
        '''
        Preprocess corpus data to events data.
        The process in each encoding scheme is different.
        Please refer to encoding_utils.py for more details.

        Files are processed one at a time in the current process. Files that
        fail to load or encode are counted as broken and skipped.
        '''
        print("preprocessing corpus data to events data")
        # make sure the output directory exists
        self.out_dir.mkdir(parents=True, exist_ok=True)
        start_time = time.time()
        broken_count = 0
        success_count = 0
        corpus_list = self._collect_pending_corpus_files()
        for file_path in tqdm(corpus_list, total=len(corpus_list)):
            # With caching enabled, check for an existing output *before*
            # encoding so no work is spent on a file that would be skipped.
            # Outputs are keyed by file name, so this can trigger when two
            # inputs in different sub-directories share a name.
            if self.cache and (self.out_dir / file_path.name).exists():
                continue
            filepath_name, event = self._load_single_corpus_and_make_event(file_path)
            if event is None:
                broken_count += 1
                continue
            with open(self.out_dir / filepath_name, 'wb') as f:
                pickle.dump(event, f)
            success_count += 1
            # drop the reference before the next file is loaded
            del event
        print(f"taken time for making events is {time.time()-start_time}s, success: {success_count}, broken: {broken_count}")

    def _load_single_corpus_and_make_event(self, file_path):
        '''
        Load one pickled corpus file and encode it.

        Args:
            file_path: path of the corpus pickle to read.

        Returns:
            ``(file_path.name, event)``; ``event`` is ``None`` when loading or
            encoding raised, in which case the error is printed.
        '''
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only run
            # this on corpus files from a trusted source.
            with open(file_path, 'rb') as f:
                corpus = pickle.load(f)
            event = self.encoding_function(corpus, self.in_beat_resolution)
        except Exception as e:
            # Deliberate best-effort: one bad file must not abort the whole run.
            print(f"error in encoding {file_path}: {e}")
            event = None
        return file_path.name, event
|
|
|
|
def get_argument_parser():
    '''
    Build the command-line parser for the corpus-to-event conversion script.

    Returns:
        An ``argparse.ArgumentParser`` with the required ``--dataset``,
        ``--encoding`` and ``--num_features`` options, the optional
        ``--in_dir`` / ``--out_dir`` paths, and the ``--debug`` / ``--cache``
        flags.
    '''
    # Each entry is (flag names, keyword options); the options are registered
    # in this order, which is also the order shown in the help output.
    option_specs = (
        # No `choices` restriction: any dataset name is accepted.
        (("-d", "--dataset"), {"required": True, "type": str, "help": "dataset names"}),
        (
            ("-e", "--encoding"),
            {
                "required": True,
                "choices": ("remi", "cp", "nb", "remi_pos"),
                "type": str,
                "help": "encoding scheme",
            },
        ),
        (
            ("-f", "--num_features"),
            {
                "required": True,
                "choices": (4, 5, 7, 8),
                "type": int,
                "help": "number of features",
            },
        ),
        (
            ("-i", "--in_dir"),
            {
                "default": "../dataset/represented_data/corpus/",
                "type": Path,
                "help": "input data directory",
            },
        ),
        (
            ("-o", "--out_dir"),
            {
                "default": "../dataset/represented_data/events/",
                "type": Path,
                "help": "output data directory",
            },
        ),
        (("--debug",), {"action": "store_true", "help": "enable debug mode"}),
        (("--cache",), {"action": "store_true", "help": "enable cache mode"}),
    )

    cli = argparse.ArgumentParser()
    for flag_names, options in option_specs:
        cli.add_argument(*flag_names, **options)
    return cli
|
|
|
|
def main():
    '''Parse the command-line options and run the corpus-to-event conversion.'''
    cli_args = get_argument_parser().parse_args()
    converter = Corpus2Event(
        dataset=cli_args.dataset,
        encoding_scheme=cli_args.encoding,
        num_features=cli_args.num_features,
        in_dir=cli_args.in_dir,
        out_dir=cli_args.out_dir,
        debug=cli_args.debug,
        cache=cli_args.cache,
    )
    converter.make_events()


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|