first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions


@ -0,0 +1,650 @@
import argparse
import time
import itertools
import copy
from copy import deepcopy
from pathlib import Path
from multiprocessing import Pool, cpu_count
from collections import defaultdict
from fractions import Fraction
from typing import List
import os
import numpy as np
import pickle
from tqdm import tqdm
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument
from chorder import Dechorder
from constants import NUM2PITCH, PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP
'''
This script is designed to preprocess MIDI files and convert them into a structured corpus suitable for symbolic music analysis or model training.
It handles various tasks, including setting beat resolution, calculating duration, velocity, and tempo bins, and processing MIDI data into quantized musical events.
'''
def get_tempo_bin(max_tempo:int, ratio:float=1.1):
bpm = 30
regular_tempo_bins = [bpm]
while bpm < max_tempo:
bpm *= ratio
bpm = round(bpm)
if bpm > max_tempo:
break
regular_tempo_bins.append(bpm)
return np.array(regular_tempo_bins)
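# Illustration (a sketch, not used by the pipeline): with the performance-type settings
# below, get_tempo_bin(max_tempo=240, ratio=1.04) grows each bin by ~4% and rounds to an
# integer BPM, so slow tempi get a fine grid and fast tempi a coarser one:
#   bins = get_tempo_bin(max_tempo=240, ratio=1.04)
#   bins[:5]   # -> array([30, 31, 32, 33, 34])
#   bins[-1]   # largest rounded BPM that does not exceed 240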
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
'''
Keep only chord markers; markers whose first underscore-separated token is 'global'
or contains 'Boundary' (tempo / section labels) are dropped.
(An illustration follows this function.)
'''
chords = []
for marker in markers:
splitted_text = marker.text.split('_')
if splitted_text[0] != 'global' and 'Boundary' not in splitted_text[0]:
chords.append(marker)
return chords
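# Example (hypothetical marker texts, for illustration only):
#   markers = [Marker(time=0, text='C_M_C'),          # chord marker   -> kept
#              Marker(time=0, text='global_bpm_120'), # global info    -> dropped
#              Marker(time=960, text='Boundary_A')]   # section label  -> dropped
#   split_markers(markers)  # -> [Marker(time=0, text='C_M_C')]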
class CorpusMaker():
def __init__(
self,
dataset_name:str,
num_features:int,
in_dir:Path,
out_dir:Path,
debug:bool
):
'''
Initialize the CorpusMaker with dataset information and directory paths.
It sets up MIDI paths, output directories, and debug mode, then
retrieves the beat resolution, duration bins, velocity/tempo bins, and prepares the MIDI file list.
'''
self.dataset_name = dataset_name
self.num_features = num_features
self.midi_path = in_dir / f"{dataset_name}"
self.out_dir = out_dir
self.debug = debug
self._get_in_beat_resolution()
self._get_duration_bins()
self._get_velocity_tempo_bins()
self._get_min_max_last_time()
self._prepare_midi_list()
def _get_in_beat_resolution(self):
# Retrieve the per-quarter-note resolution for the dataset (e.g., 4 means the finest grid is a 16th note).
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
try:
self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
def _get_duration_bins(self):
# Set up regular duration bins for quantizing note lengths, based on the beat resolution (see the example after this method).
base_duration = {4:[1,2,3,4,5,6,8,10,12,16,20,24,28,32],
8:[1,2,3,4,6,8,10,12,14,16,20,24,28,32,36,40,48,56,64],
12:[1,2,3,4,6,9,12,15,18,24,30,36,42,48,54,60,72,84,96]}
base_duration_list = base_duration[self.in_beat_resolution]
self.regular_duration_bins = np.array(base_duration_list)
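# Worked example (illustration): with in_beat_resolution = 4, durations are measured in
# 16th-note steps and snapped to the nearest bin, e.g. a raw duration of 7 steps is
# equidistant from bins 6 and 8 and np.argmin picks the first minimum, so it becomes 6;
# a duration of 13 steps snaps to 12.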
def _get_velocity_tempo_bins(self):
# Define velocity and tempo bins based on whether the dataset is a performance or score type.
midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'SymphonyMIDI': 'score'}
try:
midi_type = midi_type_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
midi_type = midi_type_dict['LakhClean']
# For performance-type datasets, use finer-grained velocity and tempo bins.
if midi_type == 'perform':
self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
# For score-type datasets, use coarser velocity and tempo bins.
elif midi_type == 'score':
self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)
def _get_min_max_last_time(self):
'''
Set the minimum and maximum allowed length of a MIDI track, depending on the dataset.
a range of (0, 2000) seconds effectively means no limit
'''
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'SymphonyMIDI': (60, 1500)}
try:
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.min_last_time, self.max_last_time = last_time_dict['LakhClean']
def _prepare_midi_list(self):
midi_path = Path(self.midi_path)
# detect subdirectories and get all midi files
if not midi_path.exists():
raise ValueError(f"midi_path {midi_path} does not exist")
# walk through all subdirectories and collect every MIDI file
midi_files = []
for root, _, files in os.walk(midi_path):
for file in files:
if file.endswith('.mid'):
# print(Path(root) / file)
midi_files.append(Path(root) / file)
self.midi_list = midi_files
print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")
def make_corpus(self) -> None:
'''
Main method to process the MIDI files and create the corpus data.
It supports both single-processing (debug mode) and multi-processing for large datasets.
'''
print("preprocessing midi data to corpus data")
# create the output folders if they do not already exist
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
start_time = time.time()
if self.debug:
# single processing for debugging
broken_counter = 0
success_counter = 0
for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
message = self._mp_midi2corpus(file_path)
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
else:
# Multiprocess the remaining files for faster corpus generation.
broken_counter = 0
success_counter = 0
# filter out processed files
print(self.out_dir)
processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
processed_files = [x.name for x in processed_files]
print(f"processed files: {len(processed_files)}")
print("length of midi list: ", len(self.midi_list))
# Use set for faster lookup (O(1) per check)
processed_files_set = set(processed_files)
self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
# reverse the list to process the latest files first
self.midi_list.reverse()
print(f"length of midi list after filtering: ", len(self.midi_list))
with Pool(min(16, cpu_count())) as p:
for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, 1000), total=len(self.midi_list)):
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
# for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
# message = self._mp_midi2corpus(file_path)
# if message == "error":
# broken_counter += 1
# elif message == "success":
# success_counter += 1
print(f"Making corpus takes: {time.time() - start_time}s, success: {success_counter}, broken: {broken_counter}")
def _mp_midi2corpus(self, file_path: Path):
"""Convert MIDI to corpus format and save both corpus (.pkl) and MIDI (.mid)."""
try:
midi_obj = self._analyze(file_path)
corpus, midi_obj = self._midi2corpus(midi_obj)
# --- 1. Save corpus (.pkl) ---
relative_path = file_path.relative_to(self.midi_path) # Get relative path from input dir
safe_name = str(relative_path).replace("/", "_").replace("\\", "_").replace(".mid", ".pkl")
save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
save_path.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
with save_path.open("wb") as f:
pickle.dump(corpus, f)
# --- 2. Save MIDI (.mid) ---
midi_save_dir = Path("../dataset/represented_data/corpus") / f"midi_{self.dataset_name}"
midi_save_dir.mkdir(parents=True, exist_ok=True)
midi_save_path = midi_save_dir / file_path.name # Keep original MIDI filename
midi_obj.dump(midi_save_path)
del midi_obj, corpus
return "success"
except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
print(f"Error processing {file_path.name}: {e}")
return "error"
except Exception as e:
print(f"Unexpected error in {file_path.name}: {e}")
return "error"
def _check_length(self, last_time:float):
# reject tracks whose length in seconds falls outside [min_last_time, max_last_time]
if last_time < self.min_last_time or last_time > self.max_last_time:
raise ValueError(f"last time {last_time} is out of range")
def _analyze(self, midi_path:Path):
# Loads and analyzes a MIDI file, performing various checks and extracting chords.
midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
# check length
mapping = midi_obj.get_tick_to_time_mapping()
last_time = mapping[midi_obj.max_tick]
self._check_length(last_time)
# drop instruments that contain no notes; build a new list instead of calling
# remove() while iterating, which skips elements
midi_obj.instruments = [ins for ins in midi_obj.instruments if len(ins.notes) > 0]
for ins in midi_obj.instruments:
# sort notes in place (the original sorted() result was discarded)
ins.notes.sort(key=lambda x: (x.start, x.pitch))
# three steps to merge instruments
self._merge_percussion(midi_obj)
self._pruning_instrument(midi_obj)
self._limit_max_track(midi_obj)
if self.num_features == 7 or self.num_features == 8:
# in case of 7 or 8 features, we need to extract chords
new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
chords = Dechorder.dechord(new_midi_obj)
markers = []
for cidx, chord in enumerate(chords):
if chord.is_complete():
chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
else:
chord_text = 'N_N_N'
markers.append(Marker(time=int(cidx*new_midi_obj.ticks_per_beat), text=chord_text))
# de-duplicate consecutive identical chord markers (see the sketch after this method)
prev_chord = None
dedup_chords = []
for m in markers:
if m.text != prev_chord:
prev_chord = m.text
dedup_chords.append(m)
# return midi
midi_obj.markers = dedup_chords
return midi_obj
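# Sketch of the chord-marker de-duplication above (hypothetical values): per-beat chords
#   ['C_M_C', 'C_M_C', 'N_N_N', 'N_N_N', 'F_M_F']
# collapse to one marker per change:
#   ['C_M_C', 'N_N_N', 'F_M_F']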
def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
'''
When quantization groups notes with different original start times into the same
quant_time, unintentional chords can appear. Two pruning rules are applied
(see the sketch after this method):
rule 1: if two notes are a unison or a half step apart, keep only the longer one
rule 2: if two notes share less than 80% of the shorter note's duration, keep only the longer one
'''
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
for time in time_list:
notes = instr_grid[instr][time]
if len(notes) == 1:
continue
else:
new_notes = []
# sort in pitch with ascending order
notes.sort(key=lambda x: x.pitch)
for i in range(len(notes)-1):
# if start time is same add to new_notes
if notes[i].start == notes[i+1].start:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
continue
if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
# select longer note
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
# check how much duration they share
shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
# unless they share at least 80% of the shorter note's duration, keep only the longer note (prune the shorter one)
if shared_duration / shorter_duration < 0.8:
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
if len(new_notes) == 0:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
else:
new_notes.append(notes[i+1])
instr_grid[instr][time] = new_notes
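# Sketch of the two pruning rules (hypothetical notes quantized to the same grid time):
#   rule 1: pitches 60 and 61 (a half step apart) with different original starts -> keep only the longer note
#   rule 2: pitches 60 and 67 overlapping for less than 80% of the shorter note's duration
#           -> keep only the longer note; otherwise both notes are kept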
def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
# Checks if the ticks per beat in the MIDI file is lower than the expected resolution.
# If it is, raise an error.
if midi_obj.ticks_per_beat < self.in_beat_resolution:
raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')
# Ensure there is at least one time signature change in the MIDI file.
# if len(midi_obj.time_signature_changes) == 0:
# raise ValueError('[x] No time_signature_changes')
# Ensure there are no duplicated time signature changes.
# time_list = [ts.time for ts in midi_obj.time_signature_changes]
# if len(time_list) != len(set(time_list)):
# raise ValueError('[x] Duplicated time_signature_changes')
# If the dataset is 'LakhClean' or 'SymphonyMIDI', verify there are at least 4 tracks.
# if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
# if len(midi_obj.instruments) < 4:
# raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')
# Number of ticks per quantization step, kept as a Fraction to avoid rounding error.
in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)
# Extract the initial time signature (numerator and denominator) and calculate the number of ticks for the first bar.
if len(midi_obj.time_signature_changes) != 0:
initial_numerator = midi_obj.time_signature_changes[0].numerator
initial_denominator = midi_obj.time_signature_changes[0].denominator
else:
# If no time signature changes, set default values
initial_numerator = 4
initial_denominator = 4
first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))
# --- load notes --- #
instr_notes = self._make_instr_notes(midi_obj)
# --- load information --- #
# load chords, labels
chords = split_markers(midi_obj.markers)
chords.sort(key=lambda x: x.time)
# load tempos
tempos = midi_obj.tempo_changes if len(midi_obj.tempo_changes) > 0 else []
if len(tempos) == 0:
# if no tempo changes, set the default tempo to 120 BPM
tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
tempos.sort(key=lambda x: x.time)
# --- process items to grid --- #
# compute empty bar offset at head
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
offset = quant_time_first // first_bar_resol # empty bar
offset_by_resol = offset * first_bar_resol
# --- process notes --- #
instr_grid = dict()
for key in instr_notes.keys():
notes = instr_notes[key]
note_grid = defaultdict(list)
for note in notes:
# skip notes outside MIDI pitch 12-119 (roughly C0 to B8)
if note.pitch < 12 or note.pitch >= 120:
continue
# shift out the leading empty bars; clamp at 0 in case a note starts slightly before the offset
note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0
# relative duration
# skip note with 0 duration
note_duration = note.end - note.start
relative_duration = round(note_duration / in_beat_tick_resol)
if relative_duration == 0:
continue
if relative_duration > self.in_beat_resolution * 8: # 8 beats
relative_duration = self.in_beat_resolution * 8
# snap to the nearest regular duration bin (see the worked example after this block)
note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins-relative_duration))]
# quantize start time
quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol
# velocity
note.velocity = self.regular_velocity_bins[
np.argmin(abs(self.regular_velocity_bins-note.velocity))]
# append
note_grid[quant_time].append(note)
# set to track
instr_grid[key] = note_grid
# --- pruning grouped notes --- #
self._pruning_grouped_notes_from_quantization(instr_grid)
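# Worked example (illustration, assuming ticks_per_beat = 480, in_beat_resolution = 4 and
# no leading empty-bar offset, so in_beat_tick_resol = Fraction(480, 4) = 120 ticks per step):
#   a note with start=1265, end=1500, velocity=93 becomes
#     quant start  : round(1265 / 120) * 120 = 1320
#     duration     : 235 ticks -> round(235 / 120) = 2 steps -> nearest duration bin 2
#     velocity     : snapped to 100 with the score-type bins [40, 60, 80, 100, 120]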
# --- process chords --- #
chord_grid = defaultdict(list)
for chord in chords:
# quantize
chord.time = chord.time - offset_by_resol
chord.time = 0 if chord.time < 0 else chord.time
quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
chord_grid[quant_time].append(chord)
# --- process tempos --- #
first_notes_list = []
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
if len(time_list) == 0: # skip empty tracks
continue
first_quant_time = time_list[0]
first_notes_list.append(first_quant_time)
# handle the case where every track is empty
if not first_notes_list:
raise ValueError("[x] No valid notes found in any instrument track.")
quant_first_note_time = min(first_notes_list)
tempo_grid = defaultdict(list)
for tempo in tempos:
# quantize
tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
tempo.tempo = self.regular_tempo_bins[
np.argmin(abs(self.regular_tempo_bins-tempo.tempo))]
if quant_time < quant_first_note_time:
tempo_grid[quant_first_note_time].append(tempo)
else:
tempo_grid[quant_time].append(tempo)
if len(tempo_grid[quant_first_note_time]) > 1:
tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]
# --- process time signature --- #
quant_time_signature = deepcopy(midi_obj.time_signature_changes)
quant_time_signature.sort(key=lambda x: x.time)
for ts in quant_time_signature:
ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol
# --- make new midi object to check processed values --- #
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
for instr_idx in instr_grid.keys():
new_instrument = Instrument(program=instr_idx)
new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
new_midi_obj.instruments.append(new_instrument)
new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
# make corpus
song_data = {
'notes': instr_grid,
'chords': chord_grid,
'tempos': tempo_grid,
'metadata': {
'first_note': first_note_time,
'last_note': last_note_time,
'time_signature': quant_time_signature,
'ticks_per_beat': midi_obj.ticks_per_beat,
}
}
return song_data, new_midi_obj
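# Reading one saved corpus entry back (usage sketch; the path is hypothetical):
#   with open(out_dir / "corpus_SOD" / "some_song.pkl", "rb") as f:
#       song = pickle.load(f)
#   song['notes']                           # {program: {quantized onset time: [notes]}}
#   song['metadata']['ticks_per_beat']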
def _make_instr_notes(self, midi_obj):
'''
Instrument merging matters; three strategies are possible (see the sketch after this method):
option 1: compare note counts and keep only the tracks with more notes
option 2: merge all instruments that map to the same pruned program into one track
option 3: leave all instruments as they are, distinguishing them by track number
This version uses option 2, since it reduces both the number of tracks and the sequence length.
'''
instr_notes = defaultdict(list)
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name is None:
continue
new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
instr_notes[new_instr_idx].extend(instr.notes)
instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
return instr_notes
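# Sketch of option 2 (assuming PROGRAM_INSTRUMENT_MAP sends programs 0-2 to the same
# instrument name, and INSTRUMENT_PROGRAM_MAP sends that name back to program 0):
#   tracks with programs 0, 1 and 2 all end up in instr_notes[0];
#   tracks whose program has no mapping are skipped entirely.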
# adapted from SymphonyNet: https://github.com/symphonynet/SymphonyNet
def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
merge all drum tracks into a single percussion track
'''
drum_0_lst = []
new_instruments = []
for instrument in midi_obj.instruments:
if len(instrument.notes) == 0:
continue
if instrument.is_drum:
drum_0_lst.extend(instrument.notes)
else:
new_instruments.append(instrument)
if len(drum_0_lst) > 0:
drum_0_lst.sort(key=lambda x: x.start)
# remove duplicate
drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
drum_0_instrument.notes = drum_0_lst
new_instruments.append(drum_0_instrument)
midi_obj.instruments = new_instruments
# adapted from MMT: https://github.com/salu133445/mmt
def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
drop instruments whose program has no entry in PROGRAM_INSTRUMENT_MAP; similar programs
are later merged into one representative category,
e.g. 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, 2: Electric Grand Piano all become 0: Acoustic Grand Piano
'''
new_instruments = []
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name is not None:
new_instruments.append(instr)
midi_obj.instruments = new_instruments
# adapted from SymphonyNet: https://github.com/symphonynet/SymphonyNet
def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
'''
merge the tracks with the fewest notes into other tracks that share the same program,
limiting the total number of tracks to MAX_TRACK (16); see the sketch after this method
'''
if len(midi_obj.instruments) == 1:
if midi_obj.instruments[0].is_drum:
midi_obj.instruments[0].program = 114
midi_obj.instruments[0].is_drum = False
return midi_obj
good_instruments = midi_obj.instruments
good_instruments.sort(
key=lambda x: (not x.is_drum, -len(x.notes))) # put the drum track (if any) first, then sort by descending note count
assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
# assert good_instruments[0].is_drum == False, (, len(good_instruments[2]))
track_idx_lst = list(range(len(good_instruments)))
if len(good_instruments) > MAX_TRACK:
new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
# print(midi_file_path)
for id in track_idx_lst[MAX_TRACK:]:
cur_ins = good_instruments[id]
merged = False
new_good_instruments.sort(key=lambda x: len(x.notes))
for nid, ins in enumerate(new_good_instruments):
if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
new_good_instruments[nid].notes.extend(cur_ins.notes)
merged = True
break
if not merged:
pass # no kept track shares this program, so the overflow track's notes are dropped
good_instruments = new_good_instruments
assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
for idx, good_instrument in enumerate(good_instruments):
if good_instrument.is_drum:
good_instruments[idx].program = 114
good_instruments[idx].is_drum = False
midi_obj.instruments = good_instruments
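# Sketch (hypothetical 18-track file, MAX_TRACK = 16): the tracks are sorted so the drum
# track and the busiest tracks come first; the 17th and 18th tracks are merged into a kept
# track with the same program (or silently dropped if none matches); any remaining drum
# track is finally relabelled as program 114.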
def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
collect pitched notes (piano range, excluding drums) into a single track for chord extraction
'''
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
new_instruments = []
new_notes = []
for instrument in midi_obj.instruments:
if instrument.program == 114 or instrument.is_drum: # skip drum/percussion tracks
continue
valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
new_notes.extend(valid_notes)
new_notes.sort(key=lambda x: x.start)
new_instrument.notes = new_notes
new_instruments.append(new_instrument)
new_midi_obj.instruments = new_instruments
return new_midi_obj
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/corpus/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
parser = get_argument_parser()
args = parser.parse_args()
corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
corpus_maker.make_corpus()
if __name__ == "__main__":
main()
# python3 step1_midi2corpus.py --dataset SOD --num_features 5
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb