import argparse
import copy
import itertools
import os
import pickle
import time
from collections import defaultdict
from copy import deepcopy
from fractions import Fraction
from multiprocessing import Pool, cpu_count
from pathlib import Path
from typing import List

import miditoolkit
import numpy as np
from chorder import Dechorder
from miditoolkit.midi.containers import Marker, Instrument
from tqdm import tqdm

from constants import NUM2PITCH, FINED_PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP

'''
This script preprocesses MIDI files and converts them into a structured corpus
suitable for symbolic music analysis or model training. It handles setting the
beat resolution, building duration, velocity, and tempo bins, and quantizing
MIDI data into musical events.

We do not merge instrument categories here (coarse program remapping is disabled).
'''

def get_tempo_bin(max_tempo:int, ratio:float=1.1):
    # Build geometrically spaced tempo bins starting at 30 BPM, growing by
    # `ratio` per step (rounded to integers), up to `max_tempo`.
    bpm = 30
    regular_tempo_bins = [bpm]
    while bpm < max_tempo:
        bpm *= ratio
        bpm = round(bpm)
        if bpm > max_tempo:
            break
        regular_tempo_bins.append(bpm)
    return np.array(regular_tempo_bins)

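# A quick sanity check of the binning (illustrative, with a small max_tempo):
#   >>> get_tempo_bin(max_tempo=60, ratio=1.1)
#   array([30, 33, 36, 40, 44, 48, 53, 58])
# Each bin sits roughly 10% above the previous one, so tempo resolution is
# finer at slow tempi and coarser at fast tempi.
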
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
    '''
    filter the marker list down to chord markers, dropping 'global_*' and boundary markers
    '''
    chords = []
    for marker in markers:
        split_text = marker.text.split('_')
        if split_text[0] != 'global' and 'Boundary' not in split_text[0]:
            chords.append(marker)
    return chords

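# Example (illustrative marker texts; chord markers follow the root_quality_bass
# format written by CorpusMaker._analyze below):
#   >>> ms = [Marker(time=0, text='C_M_C'), Marker(time=0, text='global_tempo_120'),
#   ...       Marker(time=0, text='Boundary_phrase')]
#   >>> [m.text for m in split_markers(ms)]
#   ['C_M_C']
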
class CorpusMaker():
    def __init__(
        self,
        dataset_name:str,
        num_features:int,
        in_dir:Path,
        out_dir:Path,
        debug:bool
    ):
        '''
        Initialize the CorpusMaker with dataset information and directory paths.
        It sets up the MIDI path, output directory, and debug mode, then retrieves
        the beat resolution, duration bins, and velocity/tempo bins, and prepares
        the MIDI file list.
        '''
        self.dataset_name = dataset_name
        self.num_features = num_features
        self.midi_path = in_dir / f"{dataset_name}"
        self.out_dir = out_dir
        self.debug = debug
        self._get_in_beat_resolution()
        self._get_duration_bins()
        self._get_velocity_tempo_bins()
        self._get_min_max_last_time()
        self._prepare_midi_list()

    def _get_in_beat_resolution(self):
        # Resolution per quarter note for each dataset (e.g., 4 means the finest grid is a 16th note).
        in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
        try:
            self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
        except KeyError:
            print(f"Dataset {self.dataset_name} is not supported. Using the LakhClean setting.")
            self.in_beat_resolution = in_beat_resolution_dict['LakhClean']

    def _get_duration_bins(self):
        # Set up regular duration bins (in grid steps) for quantizing note lengths, based on the beat resolution.
        base_duration = {4: [1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 28, 32],
                         8: [1, 2, 3, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32, 36, 40, 48, 56, 64],
                         12: [1, 2, 3, 4, 6, 9, 12, 15, 18, 24, 30, 36, 42, 48, 54, 60, 72, 84, 96]}
        base_duration_list = base_duration[self.in_beat_resolution]
        self.regular_duration_bins = np.array(base_duration_list)

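    # How these bins are used (illustrative): a note duration is first expressed
    # as a multiple of the in-beat grid, then snapped to the nearest bin:
    #   >>> bins = np.array([1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 28, 32])
    #   >>> relative_duration = 7
    #   >>> bins[np.argmin(abs(bins - relative_duration))]
    #   6
    # Ties go to the lower bin because argmin returns the first minimal index.
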
    def _get_velocity_tempo_bins(self):
        # Define velocity and tempo bins based on whether the dataset contains performance or score MIDI.
        midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'SymphonyMIDI': 'score'}
        try:
            midi_type = midi_type_dict[self.dataset_name]
        except KeyError:
            print(f"Dataset {self.dataset_name} is not supported. Using the LakhClean setting.")
            midi_type = midi_type_dict['LakhClean']
        # Performance-type datasets get finer-grained velocity bins.
        if midi_type == 'perform':
            self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
            self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
        # Score-type datasets use coarser velocity bins and a wider tempo range.
        elif midi_type == 'score':
            self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
            self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)

    def _get_min_max_last_time(self):
        '''
        Set the minimum and maximum allowed length (in seconds) of a MIDI file, depending on the dataset.
        (0, 2000) effectively means no limitation.
        '''
        # previous setting: {'SOD': (60, 1000), 'LakhClean': (60, 600), ...}
        last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (8, 3000), 'SymphonyMIDI': (60, 1500)}
        try:
            self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
        except KeyError:
            print(f"Dataset {self.dataset_name} is not supported. Using the LakhClean setting.")
            self.min_last_time, self.max_last_time = last_time_dict['LakhClean']

    def _prepare_midi_list(self):
        midi_path = Path(self.midi_path)
        if not midi_path.exists():
            raise ValueError(f"midi_path {midi_path} does not exist")
        # go through all subdirectories and collect the MIDI files (case-insensitive extension check)
        midi_files = []
        for root, _, files in os.walk(midi_path):
            for file in files:
                if file.lower().endswith(('.mid', '.midi')):
                    midi_files.append(Path(root) / file)
        self.midi_list = midi_files
        print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")

    def make_corpus(self) -> None:
        '''
        Main method to process the MIDI files and create the corpus data.
        It supports both single processing (debug mode) and multiprocessing for large datasets.
        '''
        print("preprocessing midi data to corpus data")
        # create the output folders if they do not exist yet
        Path(self.out_dir).mkdir(parents=True, exist_ok=True)
        Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
        Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
        start_time = time.time()
        if self.debug:
            # single processing for debugging
            broken_counter = 0
            success_counter = 0
            for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
                message = self._mp_midi2corpus(file_path)
                if message == "error":
                    broken_counter += 1
                elif message == "success":
                    success_counter += 1
        else:
            # multiprocess pool for faster corpus generation
            broken_counter = 0
            success_counter = 0
            # filter out files that have already been processed
            print(self.out_dir)
            processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
            processed_files = [x.name for x in processed_files]
            print(f"processed files: {len(processed_files)}")
            print(f"length of midi list: {len(self.midi_list)}")
            # use a set for O(1) membership checks
            processed_files_set = set(processed_files)
            self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
            # reverse the list to process the most recently found files first
            self.midi_list.reverse()
            print(f"length of midi list after filtering: {len(self.midi_list)}")
            with Pool(min(16, cpu_count())) as p:
                for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, chunksize=500), total=len(self.midi_list)):
                    if message == "error":
                        broken_counter += 1
                    elif message == "success":
                        success_counter += 1
        print(f"Making corpus takes: {time.time() - start_time:.1f}s, success: {success_counter}, broken: {broken_counter}")

    def _mp_midi2corpus(self, file_path: Path):
        """Convert one MIDI file to corpus format and save both the corpus (.pkl) and the quantized MIDI (.mid)."""
        try:
            midi_obj = self._analyze(file_path)
            corpus, midi_obj = self._midi2corpus(midi_obj)

            # --- 1. Save corpus (.pkl) ---
            relative_path = file_path.relative_to(self.midi_path)  # path relative to the input dir
            safe_name = str(relative_path).replace("/", "_").replace("\\", "_")
            safe_name = str(Path(safe_name).with_suffix(".pkl"))  # handles both .mid and .midi
            save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with save_path.open("wb") as f:
                pickle.dump(corpus, f)

            # --- 2. Save MIDI (.mid) ---
            # use self.out_dir so make_corpus finds these files when filtering out processed ones
            midi_save_dir = Path(self.out_dir) / f"midi_{self.dataset_name}"
            midi_save_dir.mkdir(parents=True, exist_ok=True)
            midi_save_path = midi_save_dir / file_path.name  # keep the original MIDI filename
            midi_obj.dump(midi_save_path)

            del midi_obj, corpus
            return "success"

        except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
            print(f"Error processing {file_path.name}: {e}")
            return "error"
        except Exception as e:
            print(f"Unexpected error in {file_path.name}: {e}")
            return "error"

    def _check_length(self, last_time:float):
        # reject files whose duration (in seconds) falls outside the allowed range for this dataset
        if last_time < self.min_last_time or last_time > self.max_last_time:
            raise ValueError(f"last time {last_time} is out of range [{self.min_last_time}, {self.max_last_time}]")

    def _analyze(self, midi_path:Path):
        # Load and analyze a MIDI file: check its length, clean up the tracks, and extract chords if needed.
        midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)

        # check length (in seconds)
        mapping = midi_obj.get_tick_to_time_mapping()
        last_time = mapping[midi_obj.max_tick]
        self._check_length(last_time)

        # drop instruments with no notes (filter instead of removing while iterating),
        # then sort the remaining notes by (start, pitch)
        midi_obj.instruments = [ins for ins in midi_obj.instruments if len(ins.notes) > 0]
        for ins in midi_obj.instruments:
            ins.notes.sort(key=lambda x: (x.start, x.pitch))

        # two steps to clean up instruments (category pruning is disabled)
        self._merge_percussion(midi_obj)
        # self._pruning_instrument(midi_obj)
        self._limit_max_track(midi_obj)

        if self.num_features == 7 or self.num_features == 8:
            # with 7 or 8 features, we also need to extract chords
            new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
            chords = Dechorder.dechord(new_midi_obj)
            markers = []
            for cidx, chord in enumerate(chords):
                if chord.is_complete():
                    chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
                else:
                    chord_text = 'N_N_N'
                markers.append(Marker(time=int(cidx * new_midi_obj.ticks_per_beat), text=chord_text))

            # de-duplicate consecutive identical chords
            prev_chord = None
            dedup_chords = []
            for m in markers:
                if m.text != prev_chord:
                    prev_chord = m.text
                    dedup_chords.append(m)

            midi_obj.markers = dedup_chords
        return midi_obj

    def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
        '''
        When notes with different start times are grouped into the same quant_time,
        unintentional chords are created.
        rule 1: if two notes are within a half step of each other, keep the longer one
        rule 2: if two notes do not share at least 80% of the shorter note's duration, keep the longer one
        '''
        for instr in instr_grid.keys():
            time_list = sorted(list(instr_grid[instr].keys()))
            for time in time_list:
                notes = instr_grid[instr][time]
                if len(notes) == 1:
                    continue
                new_notes = []
                # sort by pitch in ascending order
                notes.sort(key=lambda x: x.pitch)
                for i in range(len(notes) - 1):
                    # notes that truly start together are kept as a chord
                    if notes[i].start == notes[i+1].start:
                        new_notes.append(notes[i])
                        new_notes.append(notes[i+1])
                        continue
                    if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
                        # rule 1: unison or half-step clash, keep the longer note
                        if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
                            new_notes.append(notes[i])
                        else:
                            new_notes.append(notes[i+1])
                    else:
                        # rule 2: check how much duration the two notes share
                        shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
                        shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
                        # unless they share at least 80% of the shorter duration, keep the longer note
                        if shared_duration / shorter_duration < 0.8:
                            if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
                                new_notes.append(notes[i])
                            else:
                                new_notes.append(notes[i+1])
                        else:
                            if len(new_notes) == 0:
                                new_notes.append(notes[i])
                                new_notes.append(notes[i+1])
                            else:
                                new_notes.append(notes[i+1])
                instr_grid[instr][time] = new_notes

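    # Illustrative walk-through (hedged, not executed): suppose two notes land on
    # the same quant_time, one C4 (pitch 60, ticks 0-480) and one C#4 (pitch 61,
    # ticks 40-200). They clash within a half step, so rule 1 keeps only the longer
    # C4. If instead the second note were G4 (pitch 67, ticks 40-200), the shared
    # span is 160 ticks and the shorter note lasts 160 ticks (ratio 1.0 >= 0.8),
    # so both notes survive as a chord.
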
    def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
        # Raise an error if the ticks per beat is lower than the expected resolution.
        if midi_obj.ticks_per_beat < self.in_beat_resolution:
            raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')

        # Optional validations, currently disabled:
        # ensure there is at least one time signature change
        # if len(midi_obj.time_signature_changes) == 0:
        #     raise ValueError('[x] No time_signature_changes')
        # ensure there are no duplicated time signature changes
        # time_list = [ts.time for ts in midi_obj.time_signature_changes]
        # if len(time_list) != len(set(time_list)):
        #     raise ValueError('[x] Duplicated time_signature_changes')
        # require at least 4 tracks for 'LakhClean' or 'SymphonyMIDI'
        # if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
        #     if len(midi_obj.instruments) < 4:
        #         raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')

        # Number of ticks in one grid step, as an exact fraction of a beat.
        in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)

        # Extract the initial time signature and compute the tick length of the first bar.
        if len(midi_obj.time_signature_changes) != 0:
            initial_numerator = midi_obj.time_signature_changes[0].numerator
            initial_denominator = midi_obj.time_signature_changes[0].denominator
        else:
            # default to 4/4 when no time signature is present
            initial_numerator = 4
            initial_denominator = 4
        first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))

        # --- load notes --- #
        instr_notes = self._make_instr_notes(midi_obj)

        # --- load information --- #
        # load chord markers
        chords = split_markers(midi_obj.markers)
        chords.sort(key=lambda x: x.time)

        # load tempos; default to 120 BPM when the file has no tempo changes
        tempos = list(midi_obj.tempo_changes)
        if len(tempos) == 0:
            tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
        tempos.sort(key=lambda x: x.time)

        # --- process items to grid --- #
        # compute the empty-bar offset at the head of the piece
        first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
        last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])

        quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
        offset = quant_time_first // first_bar_resol  # number of empty bars
        offset_by_resol = offset * first_bar_resol
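
        # Worked example (hedged): with ticks_per_beat=480 and in_beat_resolution=4,
        # in_beat_tick_resol = 120 ticks. In 4/4, first_bar_resol = 1920 ticks. If the
        # first note starts at tick 4000, quant_time_first = round(4000/120)*120 = 3960,
        # offset = 3960 // 1920 = 2 empty bars, and offset_by_resol = 3840 ticks is
        # subtracted from every event so the piece starts at bar 0.
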
        # --- process notes --- #
        instr_grid = dict()
        for key in instr_notes.keys():
            notes = instr_notes[key]
            note_grid = defaultdict(list)
            for note in notes:
                # keep pitches in [12, 120), i.e., C0 through B8
                if note.pitch < 12 or note.pitch >= 120:
                    continue

                # shift events left when the piece starts with empty bars
                note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
                note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0

                # relative duration on the in-beat grid; skip zero-length notes
                note_duration = note.end - note.start
                relative_duration = round(note_duration / in_beat_tick_resol)
                if relative_duration == 0:
                    continue
                # cap the duration at 8 beats
                if relative_duration > self.in_beat_resolution * 8:
                    relative_duration = self.in_beat_resolution * 8

                # snap to the nearest regular duration bin
                note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins - relative_duration))]

                # quantize the start time
                quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol

                # snap the velocity to the nearest velocity bin
                note.velocity = self.regular_velocity_bins[np.argmin(abs(self.regular_velocity_bins - note.velocity))]

                # append
                note_grid[quant_time].append(note)

            # set to track
            instr_grid[key] = note_grid

        # --- pruning grouped notes --- #
        self._pruning_grouped_notes_from_quantization(instr_grid)

        # --- process chords --- #
        chord_grid = defaultdict(list)
        for chord in chords:
            # shift and quantize the chord marker times
            chord.time = chord.time - offset_by_resol
            chord.time = 0 if chord.time < 0 else chord.time
            quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
            chord_grid[quant_time].append(chord)

        # --- process tempos --- #
        # find the earliest quantized note time across all tracks
        first_notes_list = []
        for instr in instr_grid.keys():
            time_list = sorted(list(instr_grid[instr].keys()))
            if len(time_list) == 0:  # skip empty tracks
                continue
            first_quant_time = time_list[0]
            first_notes_list.append(first_quant_time)

        # handle the case where every track is empty
        if not first_notes_list:
            raise ValueError("[x] No valid notes found in any instrument track.")
        quant_first_note_time = min(first_notes_list)
        tempo_grid = defaultdict(list)
        for tempo in tempos:
            # shift and quantize the tempo change times
            tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
            quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
            tempo.tempo = self.regular_tempo_bins[np.argmin(abs(self.regular_tempo_bins - tempo.tempo))]
            # clamp tempo changes that occur before the first note onto the first note time
            if quant_time < quant_first_note_time:
                tempo_grid[quant_first_note_time].append(tempo)
            else:
                tempo_grid[quant_time].append(tempo)
        # keep only the last tempo change clamped onto the first note time
        if len(tempo_grid[quant_first_note_time]) > 1:
            tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]

        # --- process time signature --- #
        quant_time_signature = deepcopy(midi_obj.time_signature_changes)
        quant_time_signature.sort(key=lambda x: x.time)
        for ts in quant_time_signature:
            ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
            ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol

        # --- make a new midi object to check the processed values --- #
        new_midi_obj = miditoolkit.midi.parser.MidiFile()
        new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
        new_midi_obj.max_tick = midi_obj.max_tick
        for instr_idx in instr_grid.keys():
            new_instrument = Instrument(program=instr_idx)
            new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
            new_midi_obj.instruments.append(new_instrument)
        new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
        new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
        new_midi_obj.time_signature_changes = midi_obj.time_signature_changes

        # make corpus
        song_data = {
            'notes': instr_grid,
            'chords': chord_grid,
            'tempos': tempo_grid,
            'metadata': {
                'first_note': first_note_time,
                'last_note': last_note_time,
                'time_signature': quant_time_signature,
                'ticks_per_beat': midi_obj.ticks_per_beat,
            }
        }
        return song_data, new_midi_obj

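    # The resulting .pkl can be inspected like this (illustrative sketch; the path
    # and filename are hypothetical, they are whatever _mp_midi2corpus wrote):
    #   >>> import pickle
    #   >>> with open('../dataset/represented_data/corpus/corpus_SOD/example.pkl', 'rb') as f:
    #   ...     song = pickle.load(f)
    #   >>> song.keys()
    #   dict_keys(['notes', 'chords', 'tempos', 'metadata'])
    #   >>> # song['notes'][program][quant_time] is a list of quantized Note objects
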
    def _make_instr_notes(self, midi_obj):
        '''
        This part is important; there are three possible ways to merge instruments:
        1st option: compare the number of notes and keep the tracks with more notes
        2nd option: merge all tracks that share the same program number
        3rd option: leave all tracks as they are, differentiating them by track number

        This version uses the 2nd option, since it reduces the number of tracks and the sequence length.
        '''
        instr_notes = defaultdict(list)
        for instr in midi_obj.instruments:
            instr_idx = instr.program
            # skip programs that have no entry in the fine-grained instrument map
            instr_name = FINED_PROGRAM_INSTRUMENT_MAP.get(instr_idx)
            if instr_name is None:
                continue
            # keep the raw program number (coarse remapping is disabled)
            # new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
            new_instr_idx = instr_idx
            instr_notes[new_instr_idx].extend(instr.notes)
            instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
        return instr_notes

    # referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
    def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
        '''
        merge all drum tracks into one track
        '''
        drum_0_lst = []
        new_instruments = []
        for instrument in midi_obj.instruments:
            if len(instrument.notes) == 0:
                continue
            if instrument.is_drum:
                drum_0_lst.extend(instrument.notes)
            else:
                new_instruments.append(instrument)
        if len(drum_0_lst) > 0:
            drum_0_lst.sort(key=lambda x: x.start)
            # remove consecutive duplicate notes
            drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
            drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
            drum_0_instrument.notes = drum_0_lst
            new_instruments.append(drum_0_instrument)
        midi_obj.instruments = new_instruments

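    # Note on the de-duplication above (hedged): itertools.groupby only collapses
    # *consecutive* equal items, which is why the list is sorted by start time first:
    #   >>> import itertools
    #   >>> [k for k, _ in itertools.groupby([1, 1, 2, 2, 1])]
    #   [1, 2, 1]
    # Whether two Note objects compare equal here depends on miditoolkit's Note equality.
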
    # referred to mmt "https://github.com/salu133445/mmt"
    def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
        '''
        merge program numbers within a similar instrument category,
        e.g., 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, and 2: Electric Grand Piano all map to 0: Acoustic Grand Piano
        '''
        # NOTE: currently unused; relies on a coarse PROGRAM_INSTRUMENT_MAP that is not imported above
        new_instruments = []
        for instr in midi_obj.instruments:
            instr_idx = instr.program
            instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
            if instr_name is not None:
                new_instruments.append(instr)
        midi_obj.instruments = new_instruments

    # referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
    def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
        '''
        merge the tracks with the fewest notes into other tracks with the same program,
        limiting the total number of tracks to MAX_TRACK
        '''
        if len(midi_obj.instruments) == 1:
            if midi_obj.instruments[0].is_drum:
                midi_obj.instruments[0].program = 114
                midi_obj.instruments[0].is_drum = False
            return midi_obj
        good_instruments = midi_obj.instruments
        # place the drum track (or the track with the most notes) first
        good_instruments.sort(key=lambda x: (not x.is_drum, -len(x.notes)))
        assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
            good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
        track_idx_lst = list(range(len(good_instruments)))
        if len(good_instruments) > MAX_TRACK:
            new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
            for id in track_idx_lst[MAX_TRACK:]:
                cur_ins = good_instruments[id]
                merged = False
                # try to merge into the kept track with the fewest notes first
                new_good_instruments.sort(key=lambda x: len(x.notes))
                for nid, ins in enumerate(new_good_instruments):
                    if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
                        new_good_instruments[nid].notes.extend(cur_ins.notes)
                        merged = True
                        break
                if not merged:
                    # overflow tracks with no matching program are dropped
                    pass
            good_instruments = new_good_instruments

        assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
        # assign drum tracks a fixed program number and clear the drum flag
        for idx, good_instrument in enumerate(good_instruments):
            if good_instrument.is_drum:
                good_instruments[idx].program = 114
                good_instruments[idx].is_drum = False
        midi_obj.instruments = good_instruments

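    # The sort key above orders tracks as (drums first, then by descending note count)
    # because False < True and -len(...) puts larger counts first (illustrative):
    #   >>> tracks = [('piano', 100), ('drum', 10), ('bass', 500)]
    #   >>> sorted(tracks, key=lambda t: (t[0] != 'drum', -t[1]))
    #   [('drum', 10), ('bass', 500), ('piano', 100)]
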
    def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
        '''
        collapse all non-drum notes in piano range (MIDI pitch 21-108) into a single track for chord extraction
        '''
        new_midi_obj = miditoolkit.midi.parser.MidiFile()
        new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
        new_midi_obj.max_tick = midi_obj.max_tick
        new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
        new_notes = []
        for instrument in midi_obj.instruments:
            if instrument.program == 114 or instrument.is_drum:  # skip drum tracks
                continue
            valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
            new_notes.extend(valid_notes)
        new_notes.sort(key=lambda x: x.start)
        new_instrument.notes = new_notes
        new_midi_obj.instruments = [new_instrument]
        return new_midi_obj

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset name",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../dataset/represented_data/corpus/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    return parser

def main():
    parser = get_argument_parser()
    args = parser.parse_args()
    corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
    corpus_maker.make_corpus()


if __name__ == "__main__":
    main()

# python3 step1_midi2corpus.py --dataset SOD --num_features 5
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb