first commit
This commit is contained in:
147
data_representation/step2_corpus2event.py
Normal file
147
data_representation/step2_corpus2event.py
Normal file
@ -0,0 +1,147 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Pool
|
||||
|
||||
import encoding_utils
|
||||
|
||||
'''
|
||||
This script is for converting corpus data to event data.
|
||||
'''
|
||||
|
||||
class Corpus2Event():
|
||||
def __init__(
|
||||
self,
|
||||
dataset: str,
|
||||
encoding_scheme: str,
|
||||
num_features: int,
|
||||
in_dir: Path,
|
||||
out_dir: Path,
|
||||
debug: bool,
|
||||
cache: bool,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.encoding_name = encoding_scheme + str(num_features)
|
||||
self.in_dir = in_dir / f"corpus_{self.dataset}"
|
||||
self.out_dir = out_dir / f"events_{self.dataset}" / self.encoding_name
|
||||
self.debug = debug
|
||||
self.cache = cache
|
||||
self.encoding_function = getattr(encoding_utils, f'Corpus2event_{encoding_scheme}')(num_features)
|
||||
self._get_in_beat_resolution()
|
||||
|
||||
def _get_in_beat_resolution(self):
|
||||
# Retrieve the resolution of quarter note based on the dataset name (e.g., 4 means the minimum resolution sets to 16th note)
|
||||
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
|
||||
try:
|
||||
self.in_beat_resolution = in_beat_resolution_dict[self.dataset]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset} is not supported. use the setting of LakhClean")
|
||||
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
|
||||
|
||||
def make_events(self):
|
||||
'''
|
||||
Preprocess corpus data to events data.
|
||||
The process in each encoding scheme is different.
|
||||
Please refer to encoding_utils.py for more details.
|
||||
'''
|
||||
print("preprocessing corpus data to events data")
|
||||
# check output directory exists
|
||||
self.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
start_time = time.time()
|
||||
# single-processing
|
||||
broken_count = 0
|
||||
success_count = 0
|
||||
corpus_list = sorted(list(self.in_dir.rglob("*.pkl")))
|
||||
if corpus_list == []:
|
||||
print(f"No corpus files found in {self.in_dir}. Please check the directory.")
|
||||
corpus_list = sorted(list(self.in_dir.glob("*.pkli")))
|
||||
# remove the corpus files that are already in the out_dir
|
||||
# Use set for faster existence checks
|
||||
existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
|
||||
# corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
|
||||
for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
|
||||
if event is None:
|
||||
broken_count += 1
|
||||
continue
|
||||
# if using cache, check if the event file already exists
|
||||
if self.cache and (self.out_dir / filepath_name).exists():
|
||||
# print(f"event file {filepath_name} already exists, skipping")
|
||||
continue
|
||||
with open(self.out_dir / filepath_name, 'wb') as f:
|
||||
pickle.dump(event, f)
|
||||
success_count += 1
|
||||
del event
|
||||
print(f"taken time for making events is {time.time()-start_time}s, success: {success_count}, broken: {broken_count}")
|
||||
|
||||
def _load_single_corpus_and_make_event(self, file_path):
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
corpus = pickle.load(f)
|
||||
event = self.encoding_function(corpus, self.in_beat_resolution)
|
||||
except Exception as e:
|
||||
print(f"error in encoding {file_path}: {e}")
|
||||
event = None
|
||||
return file_path.name, event
|
||||
|
||||
def get_argument_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
required=True,
|
||||
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
|
||||
type=str,
|
||||
help="dataset names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--encoding",
|
||||
required=True,
|
||||
choices=("remi", "cp", "nb", "remi_pos"),
|
||||
type=str,
|
||||
help="encoding scheme",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--num_features",
|
||||
required=True,
|
||||
choices=(4, 5, 7, 8),
|
||||
type=int,
|
||||
help="number of features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--in_dir",
|
||||
default="../dataset/represented_data/corpus/",
|
||||
type=Path,
|
||||
help="input data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--out_dir",
|
||||
default="../dataset/represented_data/events/",
|
||||
type=Path,
|
||||
help="output data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="enable debug mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache",
|
||||
action="store_true",
|
||||
help="enable cache mode",
|
||||
)
|
||||
return parser
|
||||
|
||||
def main():
|
||||
args = get_argument_parser().parse_args()
|
||||
corpus2event = Corpus2Event(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug, args.cache)
|
||||
corpus2event.make_events()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user