first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions


@@ -0,0 +1,81 @@
# Dataset Download
Our model supports four different datasets:
- **Symbolic Orchestral Database (SOD)**: [Link](https://qsdfo.github.io/LOP/database.html)
- **Lakh MIDI Dataset (Clean version)**: [Link](https://colinraffel.com/projects/lmd/)
- **Pop1k7**: [Link](https://github.com/YatingMusic/compound-word-transformer)
- **Pop909**: [Link](https://github.com/music-x-lab/POP909-Dataset)
### Download Instructions
You can download the datasets via the command line:
```sh
# SOD
wget https://qsdfo.github.io/LOP/database/SOD.zip
# LakhClean
wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
```
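After downloading, extract the archives, for example:
```sh
# SOD
unzip SOD.zip
# LakhClean
tar -xzvf clean_midi.tar.gz
```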
For Pop1k7, the official repository link is currently unavailable. However, you can download it from this Google Drive link:
[Download Pop1k7](https://drive.google.com/file/d/1GnbELjE-kQ4WOkBmZ3XapFKIaltySRyV/view?usp=drive_link)
For Pop909, the dataset is available in the official GitHub repository: [Repository link](https://github.com/music-x-lab/POP909-Dataset)
### Using Your Own Dataset
If you plan to use your own dataset, you can modify the dataset class in the `data_utils.py` script under the `symbolic_encoding` folder inside the `nested_music_transformer` folder. Alternatively, for a simpler approach, rename your dataset to match one of the following options (a minimal placement example follows the list):
- SOD: Use this for score-based MIDI datasets that require finer-grained quantization (supports up to 16th note triplet level quantization; 24 samples per quarter note).
- LakhClean: Suitable for score-based MIDI datasets requiring coarse-grained quantization (supports up to 16th note level quantization; 4 samples per quarter note).
- Pop1k7, Pop909: Ideal for expressive-based MIDI datasets requiring coarse-grained quantization (supports up to 16th note level quantization; 4 samples per quarter note).
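For example, to reuse the `SOD` settings for your own score-based MIDI collection, place (or copy) your files into a folder named after the chosen dataset inside the MIDI dataset directory used in Step 1 below. The source path is a placeholder, and the commands assume the repository root as the working directory:
```sh
mkdir -p dataset/MIDI_dataset/SOD
cp /path/to/your_midi_files/*.mid dataset/MIDI_dataset/SOD/
```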
# Data Representation
<p align="center">
<img src="figure/Data_Representation_Pipeline.png" width="1000">
</p>
This document outlines our standard data processing pipeline. By following the instructions and running the corresponding Python scripts, you can generate a data representation suited to your specific needs.
We focus on symbolic music and limit the use of musical features to a select few. Each feature set size corresponds to specific musical attributes. Through various experiments, we decided to use **7 features** for the *Pop1k7* and *Pop909* datasets, which consist of pop piano music requiring velocity for expression, and **5 features** for the *Symbolic Orchestral Database (SOD)*, *Lakh MIDI*, and *SymphonyMIDI* datasets.
- **4 features**: `["type", "beat", "pitch", "duration"]`
- **5 features**: `["type", "beat", "instrument", "pitch", "duration"]`
- **7 features**: `["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"]`
- **8 features**: `["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]`
## Parse Argument
- `-d`, `--dataset`: This required argument specifies the dataset to be used. It takes one of the following values: `"BachChorale"`, `"Pop1k7"`, `"Pop909"`, `"SOD"`, `"LakhClean"`, or `"SymphonyMIDI"`.
- `-e`, `--encoding`: This required argument specifies the encoding scheme to use. It accepts one of the following: `"remi"`, `"cp"`, `"nb"`, or `"remi_pos"`.
- `-f`, `--num_features`: This required argument specifies the number of features. It can take one of the following values: `4`, `5`, `7`, or `8`.
- `-i`, `--in_dir`: This optional argument specifies the input data directory. It defaults to `../dataset/represented_data/corpus/` if not provided.
- `-o`, `--out_dir`: This optional argument specifies the output data directory. It defaults to `../dataset/represented_data/events/`.
- `--debug`: This optional flag enables debug mode when included; no additional value is needed. A minimal parser sketch based on these options is shown below.
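For reference, a minimal `argparse` sketch that mirrors the options above might look like this (the actual scripts define their own parsers; the helper name here is illustrative):
```python
import argparse

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", required=True,
                        choices=["BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"])
    parser.add_argument("-e", "--encoding", required=True,
                        choices=["remi", "cp", "nb", "remi_pos"])
    parser.add_argument("-f", "--num_features", required=True, type=int, choices=[4, 5, 7, 8])
    parser.add_argument("-i", "--in_dir", default="../dataset/represented_data/corpus/")
    parser.add_argument("-o", "--out_dir", default="../dataset/represented_data/events/")
    parser.add_argument("--debug", action="store_true")
    return parser
```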
## 1. MIDI to Corpus
In this step, we convert MIDI files into a set of events containing various musical information. The MIDI files should be aligned with the beat and contain accurate time signature information. Place the MIDI files in `<nmt/dataset/MIDI_dataset>` and refer to the example files provided. Navigate to the `<nmt/data_representation>` folder and run the script. The converted data will be stored in `<nmt/dataset/represented_data/corpus>`.
- Example usage: `python3 step1_midi2corpus.py --dataset SOD --num_features 5`
## 2. Corpus to Event
We provide three types of representations: **REMI**, **Compound Word (CP)**, and **Note-based Encoding (NB)**. The converted data will be stored in `<nmt/dataset/represented_data/events>`.
- Example usage: `python3 step2_corpus2event.py --dataset SOD --num_features 5 --encoding nb`
## 3. Creating Vocabulary
This script creates a vocabulary in the `<nmt/vocab>` folder. The vocabulary includes event-to-index pair information.
- Example usage: `python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb`
## 4. Event to Index
In this step, we convert events into indices for efficient model training. The converted data will be stored in `<nmt/dataset/represented_data/tuneidx>`.
- Example usage: `python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb`
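Putting the four steps together for the SOD dataset with the NB encoding and 5 features, the full pipeline (run from the `<nmt/data_representation>` folder) is:
```sh
python3 step1_midi2corpus.py --dataset SOD --num_features 5
python3 step2_corpus2event.py --dataset SOD --num_features 5 --encoding nb
python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb
```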


@@ -0,0 +1,422 @@
import numpy as np
# for chord analysis
NUM2PITCH = {
0: 'C',
1: 'C#',
2: 'D',
3: 'D#',
4: 'E',
5: 'F',
6: 'F#',
7: 'G',
8: 'G#',
9: 'A',
10: 'A#',
11: 'B',
}
# reference: MMT (https://github.com/salu133445/mmt)
PROGRAM_INSTRUMENT_MAP = {
# Pianos
0: "piano",
1: "piano",
2: "piano",
3: "piano",
4: "electric-piano",
5: "electric-piano",
6: "harpsichord",
7: "clavinet",
# Chromatic Percussion
8: "celesta",
9: "glockenspiel",
10: "music-box",
11: "vibraphone",
12: "marimba",
13: "xylophone",
14: "tubular-bells",
15: "dulcimer",
# Organs
16: "organ",
17: "organ",
18: "organ",
19: "church-organ",
20: "organ",
21: "accordion",
22: "harmonica",
23: "bandoneon",
# Guitars
24: "nylon-string-guitar",
25: "steel-string-guitar",
26: "electric-guitar",
27: "electric-guitar",
28: "electric-guitar",
29: "electric-guitar",
30: "electric-guitar",
31: "electric-guitar",
# Basses
32: "bass",
33: "electric-bass",
34: "electric-bass",
35: "electric-bass",
36: "slap-bass",
37: "slap-bass",
38: "synth-bass",
39: "synth-bass",
# Strings
40: "violin",
41: "viola",
42: "cello",
43: "contrabass",
44: "strings",
45: "strings",
46: "harp",
47: "timpani",
# Ensemble
48: "strings",
49: "strings",
50: "synth-strings",
51: "synth-strings",
52: "voices",
53: "voices",
54: "voices",
55: "orchestra-hit",
# Brass
56: "trumpet",
57: "trombone",
58: "tuba",
59: "trumpet",
60: "horn",
61: "brasses",
62: "synth-brasses",
63: "synth-brasses",
# Reed
64: "soprano-saxophone",
65: "alto-saxophone",
66: "tenor-saxophone",
67: "baritone-saxophone",
68: "oboe",
69: "english-horn",
70: "bassoon",
71: "clarinet",
# Pipe
72: "piccolo",
73: "flute",
74: "recorder",
75: "pan-flute",
76: None,
77: None,
78: None,
79: "ocarina",
# Synth Lead
80: "lead",
81: "lead",
82: "lead",
83: "lead",
84: "lead",
85: "lead",
86: "lead",
87: "lead",
# Synth Pad
88: "pad",
89: "pad",
90: "pad",
91: "pad",
92: "pad",
93: "pad",
94: "pad",
95: "pad",
# Synth Effects
96: None,
97: None,
98: None,
99: None,
100: None,
101: None,
102: None,
103: None,
# Ethnic
104: "sitar",
105: "banjo",
106: "shamisen",
107: "koto",
108: "kalimba",
109: "bag-pipe",
110: "violin",
111: "shehnai",
# Percussive
112: None,
113: None,
114: "steel-drums",
115: None,
116: None,
117: "melodic-tom",
118: "synth-drums",
119: "synth-drums",
# Sound effects
120: None,
121: None,
122: None,
123: None,
124: None,
125: None,
126: None,
127: None,
}
# reference: MMT (https://github.com/salu133445/mmt)
INSTRUMENT_PROGRAM_MAP = {
# Pianos
"piano": 0,
"electric-piano": 4,
"harpsichord": 6,
"clavinet": 7,
# Chromatic Percussion
"celesta": 8,
"glockenspiel": 9,
"music-box": 10,
"vibraphone": 11,
"marimba": 12,
"xylophone": 13,
"tubular-bells": 14,
"dulcimer": 15,
# Organs
"organ": 16,
"church-organ": 19,
"accordion": 21,
"harmonica": 22,
"bandoneon": 23,
# Guitars
"nylon-string-guitar": 24,
"steel-string-guitar": 25,
"electric-guitar": 26,
# Basses
"bass": 32,
"electric-bass": 33,
"slap-bass": 36,
"synth-bass": 38,
# Strings
"violin": 40,
"viola": 41,
"cello": 42,
"contrabass": 43,
"harp": 46,
"timpani": 47,
# Ensemble
"strings": 49,
"synth-strings": 50,
"voices": 52,
"orchestra-hit": 55,
# Brass
"trumpet": 56,
"trombone": 57,
"tuba": 58,
"horn": 60,
"brasses": 61,
"synth-brasses": 62,
# Reed
"soprano-saxophone": 64,
"alto-saxophone": 65,
"tenor-saxophone": 66,
"baritone-saxophone": 67,
"oboe": 68,
"english-horn": 69,
"bassoon": 70,
"clarinet": 71,
# Pipe
"piccolo": 72,
"flute": 73,
"recorder": 74,
"pan-flute": 75,
"ocarina": 79,
# Synth Lead
"lead": 80,
# Synth Pad
"pad": 88,
# Ethnic
"sitar": 104,
"banjo": 105,
"shamisen": 106,
"koto": 107,
"kalimba": 108,
"bag-pipe": 109,
"shehnai": 111,
# Percussive
"steel-drums": 114,
"melodic-tom": 117,
"synth-drums": 118,
}
FINED_PROGRAM_INSTRUMENT_MAP = {
# Pianos
0: "Acoustic-Grand-Piano",
1: "Bright-Acoustic-Piano",
2: "Electric-Grand-Piano",
3: "Honky-Tonk-Piano",
4: "Electric-Piano-1",
5: "Electric-Piano-2",
6: "Harpsichord",
7: "Clavinet",
# Chromatic Percussion
8: "Celesta",
9: "Glockenspiel",
10: "Music-Box",
11: "Vibraphone",
12: "Marimba",
13: "Xylophone",
14: "Tubular-Bells",
15: "Dulcimer",
# Organs
16: "Drawbar-Organ",
17: "Percussive-Organ",
18: "Rock-Organ",
19: "Church-Organ",
20: "Reed-Organ",
21: "Accordion",
22: "Harmonica",
23: "Tango-Accordion",
# Guitars
24: "Acoustic-Guitar-nylon",
25: "Acoustic-Guitar-steel",
26: "Electric-Guitar-jazz",
27: "Electric-Guitar-clean",
28: "Electric-Guitar-muted",
29: "Overdriven-Guitar",
30: "Distortion-Guitar",
31: "Guitar-harmonics",
# Basses
32: "Acoustic-Bass",
33: "Electric-Bass-finger",
34: "Electric-Bass-pick",
35: "Fretless-Bass",
36: "Slap-Bass-1",
37: "Slap-Bass-2",
38: "Synth-Bass-1",
39: "Synth-Bass-2",
# Strings & Orchestral
40: "Violin",
41: "Viola",
42: "Cello",
43: "Contrabass",
44: "Tremolo-Strings",
45: "Pizzicato-Strings",
46: "Orchestral-Harp",
47: "Timpani",
# Ensemble
48: "String-Ensemble-1",
49: "String-Ensemble-2",
50: "Synth-Strings-1",
51: "Synth-Strings-2",
52: "Choir-Aahs",
53: "Voice-Oohs",
54: "Synth-Voice",
55: "Orchestra-Hit",
# Brass
56: "Trumpet",
57: "Trombone",
58: "Tuba",
59: "Muted-Trumpet",
60: "French-Horn",
61: "Brass-Section",
62: "Synth-Brass-1",
63: "Synth-Brass-2",
# Reeds
64: "Soprano-Sax",
65: "Alto-Sax",
66: "Tenor-Sax",
67: "Baritone-Sax",
68: "Oboe",
69: "English-Horn",
70: "Bassoon",
71: "Clarinet",
# Pipes
72: "Piccolo",
73: "Flute",
74: "Recorder",
75: "Pan-Flute",
76: "Blown-Bottle",
77: "Shakuhachi",
78: "Whistle",
79: "Ocarina",
# Synth Lead
80: "Lead-1-square",
81: "Lead-2-sawtooth",
82: "Lead-3-calliope",
83: "Lead-4-chiff",
84: "Lead-5-charang",
85: "Lead-6-voice",
86: "Lead-7-fifths",
87: "Lead-8-bass+lead",
# Synth Pad
88: "Pad-1-new-age",
89: "Pad-2-warm",
90: "Pad-3-polysynth",
91: "Pad-4-choir",
92: "Pad-5-bowed",
93: "Pad-6-metallic",
94: "Pad-7-halo",
95: "Pad-8-sweep",
# Effects
96: "FX-1-rain",
97: "FX-2-soundtrack",
98: "FX-3-crystal",
99: "FX-4-atmosphere",
100: "FX-5-brightness",
101: "FX-6-goblins",
102: "FX-7-echoes",
103: "FX-8-sci-fi",
# Ethnic & Percussion
104: "Sitar",
105: "Banjo",
106: "Shamisen",
107: "Koto",
108: "Kalimba",
109: "Bag-pipe",
110: "Fiddle",
111: "Shanai",
# Percussive
112: "Tinkle-Bell",
113: "Agogo",
114: "Steel-Drums",
115: "Woodblock",
116: "Taiko-Drum",
117: "Melodic-Tom",
118: "Synth-Drum",
119: "Reverse-Cymbal",
# Sound Effects
120: "Guitar-Fret-Noise",
121: "Breath-Noise",
122: "Seashore",
123: "Bird-Tweet",
124: "Telephone-Ring",
125: "Helicopter",
126: "Applause",
127: "Gunshot"
}
REGULAR_NUM_DENOM = [(1, 1), (1, 2), (2, 2), (3, 2), (4, 2),
(1, 4), (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (7, 4), (8, 4),
(1, 8), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8), (9, 8), (11, 8), (12, 8)]
CORE_NUM_DENOM = [(1, 1), (1, 2), (2, 2), (4, 2),
(1, 4), (2, 4), (3, 4), (4, 4), (5, 4),
(1, 8), (2, 8), (3, 8), (6, 8), (9, 8), (12, 8)]
VALID_TIME_SIGNATURES = ['time_signature_' + str(x[0]) + '/' + str(x[1]) for x in REGULAR_NUM_DENOM]
# commonly encountered ticks-per-beat resolutions in MIDI files
REGULAR_TICKS_PER_BEAT = [48, 96, 192, 384, 120, 240, 480, 960, 256, 512, 1024]
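# Example lookups (illustrative): the two maps above translate between General MIDI
# program numbers and the coarse instrument labels used in this project, e.g.
#   PROGRAM_INSTRUMENT_MAP[41]      -> "viola"
#   INSTRUMENT_PROGRAM_MAP["viola"] -> 41
# and VALID_TIME_SIGNATURES contains strings such as 'time_signature_3/4'.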


@@ -0,0 +1,879 @@
from typing import Any
from fractions import Fraction
from collections import defaultdict
from miditoolkit import TimeSignature
from constants import *
'''
This script contains specific encoding functions for different encoding schemes.
'''
def frange(start, stop, step):
while start < stop:
yield start
start += step
################################# for REMI style encoding #################################
class Corpus2event_remi():
def __init__(self, num_features:int):
self.num_features = num_features
def _create_event(self, name, value):
event = dict()
event['name'] = name
event['value'] = value
return event
def _break_down_numerator(self, numerator, possible_time_signatures):
"""Break down a numerator into smaller time signatures.
Args:
numerator: Target numerator to decompose (must be > 0).
possible_time_signatures: List of (numerator, denominator) tuples,
sorted in descending order (e.g., [(4,4), (3,4)]).
Returns:
List of decomposed time signatures (e.g., [(4,4), (3,4)]).
Raises:
ValueError: If decomposition is impossible.
"""
if numerator <= 0:
raise ValueError("Numerator must be positive.")
if not possible_time_signatures:
raise ValueError("No possible time signatures provided.")
result = []
original_numerator = numerator # For error message
# Sort signatures in descending order to prioritize larger chunks
possible_time_signatures = sorted(possible_time_signatures, key=lambda x: -x[0])
while numerator > 0:
subtracted = False # Track if any subtraction occurred in this iteration
for sig in possible_time_signatures:
sig_numerator, _ = sig
if sig_numerator <= 0:
continue # Skip invalid signatures
while numerator >= sig_numerator:
result.append(sig)
numerator -= sig_numerator
subtracted = True
# If no progress was made, decomposition failed
if not subtracted:
raise ValueError(
f"Cannot decompose numerator {original_numerator} "
f"with given time signatures {possible_time_signatures}. "
f"Remaining: {numerator}"
)
return result
def _normalize_time_signature(self, time_signature, ticks_per_beat, next_change_point):
"""
Normalize irregular time signatures to standard ones by breaking them down
into common time signatures, and adjusting their durations to fit the given
musical structure.
Parameters:
- time_signature: TimeSignature object with numerator, denominator, and start time.
- ticks_per_beat: Number of ticks per beat, representing the resolution of the timing.
- next_change_point: Tick position where the next time signature change occurs.
Returns:
- A list of TimeSignature objects, normalized to fit within regular time signatures.
Procedure:
1. If the time signature is already a standard one (in REGULAR_NUM_DENOM), return it.
2. For non-standard signatures, break them down into simpler, well-known signatures.
- For unusual denominators (e.g., 16th, 32nd, or 64th notes), normalize to 4/4.
- For 6/4 signatures, break it into two 3/4 measures.
3. If the time signature has a non-standard numerator and denominator, break it down
into the largest possible numerators that still fit within the denominator.
This ensures that the final measure fits within the regular time signature format.
4. Calculate the resolution (duration in ticks) for each bar and ensure the bars
fit within the time until the next change point.
- Adjust the number of bars if they exceed the available space.
- If the total length is too short, repeat the first (largest) bar to fill the gap.
5. Convert the breakdown into TimeSignature objects and return the normalized result.
"""
# Check if the time signature is a regular one, return it if so
if (time_signature.numerator, time_signature.denominator) in REGULAR_NUM_DENOM:
return [time_signature]
# Extract time signature components
numerator, denominator, bar_start_tick = time_signature.numerator, time_signature.denominator, time_signature.time
# Normalize time signatures with 16th, 32nd, or 64th note denominators to 4/4
if denominator in [16, 32, 64]:
return [TimeSignature(4, 4, time_signature.time)]
# Special case for 6/4, break it into two 3/4 bars
elif denominator == 6 and numerator == 4:
return [TimeSignature(3, 4, time_signature.time), TimeSignature(3, 4, time_signature.time)]
# Determine possible regular signatures for the given denominator
possible_time_signatures = [sig for sig in CORE_NUM_DENOM if sig[1] == denominator]
# Sort by numerator in descending order to prioritize larger numerators
possible_time_signatures.sort(key=lambda x: x[0], reverse=True)
result = []
# Break down the numerator into smaller regular numerators
max_iterations = 100 # Prevent infinite loops
original_numerator = numerator # Store original for error message
iteration_count = 0
while numerator > 0:
iteration_count += 1
if iteration_count > max_iterations:
raise ValueError(
f"Failed to normalize time signature {original_numerator}/{denominator}. "
f"Could not break down numerator {original_numerator} with available signatures: "
f"{possible_time_signatures}"
)
for sig in possible_time_signatures:
# Subtract numerators and add to the result
while numerator >= sig[0]:
result.append(sig)
numerator -= sig[0]
# Calculate the resolution (length in ticks) of each bar
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
# Adjust bars to fit within the remaining ticks before the next change point
total_length = 0
for idx, bar_resol in enumerate(bar_resol_list):
total_length += bar_resol
if total_length > next_change_point - bar_start_tick:
result = result[:idx+1]
break
# If the total length is too short, repeat the first (largest) bar until the gap is filled
while total_length < next_change_point - bar_start_tick:
result.append(result[0])
total_length += int(ticks_per_beat * result[0][0] * (4 / result[0][1]))
# Recalculate bar resolutions for the final result
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
# Insert a starting resolution of 0 and calculate absolute tick positions for each TimeSignature
bar_resol_list.insert(0, 0)
total_length = bar_start_tick
normalized_result = []
for sig, length in zip(result, bar_resol_list):
total_length += length
normalized_result.append(TimeSignature(sig[0], sig[1], total_length))
return normalized_result
def _process_time_signature(self, time_signature_changes, ticks_per_beat, first_note_tick, global_end):
"""
Process and normalize time signature changes for a given musical piece.
Parameters:
- time_signature_changes: A list of TimeSignature objects representing time signature changes in the music.
- ticks_per_beat: The resolution of timing in ticks per beat.
- first_note_tick: The tick position of the first note in the piece.
- global_end: The tick position where the piece ends.
Returns:
- A list of processed and normalized time signature changes. If no valid time signature
changes are found, returns None.
Procedure:
1. Check the validity of the time signature changes:
- Ensure there is at least one time signature change.
- Ensure the first time signature change occurs at the beginning (before the first note).
2. Remove duplicate consecutive time signatures:
- Only add time signatures that differ from the previous one (de-duplication).
3. Normalize the time signatures:
- For each time signature, determine its duration by calculating the time until the
next change point or the end of the piece.
- Use the _normalize_time_signature method to break down non-standard signatures into
simpler, well-known signatures that fit within the musical structure.
4. Return the processed and normalized time signature changes.
"""
# Check if there are any time signature changes
if len(time_signature_changes) == 0:
print("No time signature change in this tune, default to 4/4 time signature")
# default to 4/4 time signature if none are found
return [TimeSignature(4, 4, 0)]
# Ensure the first time signature change is at the start of the piece (before the first note)
if time_signature_changes[0].time != 0 and time_signature_changes[0].time > first_note_tick:
print("The first time signature change is not at the beginning of the tune")
return None
# Remove consecutive duplicate time signatures (de-duplication)
processed_time_signature_changes = []
for idx, time_sig in enumerate(time_signature_changes):
if idx == 0:
processed_time_signature_changes.append(time_sig)
else:
prev_time_sig = time_signature_changes[idx-1]
# Only add time signature if it's different from the previous one
if not (prev_time_sig.numerator == time_sig.numerator and prev_time_sig.denominator == time_sig.denominator):
processed_time_signature_changes.append(time_sig)
# Normalize the time signatures to standard formats
normalized_time_signature_changes = []
for idx, time_signature in enumerate(processed_time_signature_changes):
if idx == len(processed_time_signature_changes) - 1:
# If it's the last time signature change, set the next change point as the end of the piece
next_change_point = global_end
else:
# Otherwise, set the next change point as the next time signature's start time
next_change_point = processed_time_signature_changes[idx+1].time
# Normalize the current time signature and extend the result
normalized_time_signature_changes.extend(self._normalize_time_signature(time_signature, ticks_per_beat, next_change_point))
# Return the list of processed and normalized time signatures
time_signature_changes = normalized_time_signature_changes
return time_signature_changes
def _half_step_interval_gap_check_across_instruments(self, instrument_note_dict):
'''
This function checks for half-step interval gaps between notes across different instruments.
It will avoid half-step intervals by keeping one note from any pair of notes that are a half-step apart,
regardless of which instrument they belong to.
'''
# order instrument_note_dict by instrument index (ascending)
instrument_note_dict = dict(sorted(instrument_note_dict.items()))
# Create a dictionary to store all pitches across instruments
all_pitches = {}
# Collect all pitches from each instrument and sort them in descending order
for instrument, notes in instrument_note_dict.items():
for pitch, durations in notes.items():
all_pitches[pitch] = all_pitches.get(pitch, []) + [(instrument, durations)]
# Sort the pitches in descending order
sorted_pitches = sorted(all_pitches.keys(), reverse=True)
# Create a new list to store the final pitches after comparison
final_pitch_list = []
# Use an index pointer to control the sliding window
idx = 0
while idx < len(sorted_pitches) - 1:
current_pitch = sorted_pitches[idx]
next_pitch = sorted_pitches[idx + 1]
if current_pitch - next_pitch == 1: # Check for a half-step interval gap
current_max_duration = max(duration for _, durations in all_pitches[current_pitch] for duration, _ in durations)
next_max_duration = max(duration for _, durations in all_pitches[next_pitch] for duration, _ in durations)
if current_max_duration < next_max_duration:
# Keep the higher pitch (next_pitch) and skip the current_pitch
final_pitch_list.append(next_pitch)
else:
# Keep the lower pitch (current_pitch) and skip the next_pitch
final_pitch_list.append(current_pitch)
# Skip the next pitch because we already handled it
idx += 2
else:
# No half-step gap, keep the current pitch and move to the next one
final_pitch_list.append(current_pitch)
idx += 1
# Ensure the last pitch is added if it's not part of a half-step interval
if idx == len(sorted_pitches) - 1:
final_pitch_list.append(sorted_pitches[-1])
# Filter out notes not in the final pitch list and update the instrument_note_dict
for instrument in instrument_note_dict.keys():
instrument_note_dict[instrument] = {
pitch: instrument_note_dict[instrument][pitch]
for pitch in sorted(instrument_note_dict[instrument].keys(), reverse=True) if pitch in final_pitch_list
}
return instrument_note_dict
def __call__(self, song_data, in_beat_resolution):
'''
Process a song's data to generate a sequence of musical events, including bars, chords, tempo,
and notes, similar to the approach used in the CP paper (corpus2event_remi_v2).
Parameters:
- song_data: A dictionary containing metadata, notes, chords, and tempos of the song.
- in_beat_resolution: The resolution of timing in beats (how many divisions per beat).
Returns:
- A sequence of musical events including start (SOS), bars, chords, tempo, instruments, notes,
and an end (EOS) event. If the time signature is invalid, returns None.
Procedure:
1. **Global Setup**:
- Extract global metadata like first and last note ticks, time signature changes, and ticks
per beat.
- Compute `in_beat_tick_resol`, the ratio of ticks per beat to the input beat resolution,
to assist in dividing bars later.
- Get a sorted list of instruments in the song.
2. **Time Signature Processing**:
- Call `_process_time_signature` to clean up and normalize the time signatures in the song.
- If the time signatures are invalid (e.g., no time signature changes or missing at the
start), the function exits early with None.
3. **Sequence Generation**:
- Initialize the sequence with a start token (SOS) and prepare variables for tracking
previous chord, tempo, and instrument states.
- Loop through each time signature change, dividing the song into measures based on the
current time signature's numerator and denominator.
- For each measure, append "Bar" tokens to mark measure boundaries, while ensuring that no
more than four consecutive empty bars are added.
- For each step within a measure, process the following:
- **Chords**: If there is a chord change, add a corresponding chord event.
- **Tempo**: If the tempo changes, add a tempo event.
- **Notes**: Iterate over each instrument, adding notes and checking for half-step
intervals, deduplicating notes, and choosing the longest duration for each pitch.
- Append a "Beat" event for each step with musical events.
4. **End Sequence**:
- Conclude the sequence by appending a final "Bar" token followed by an end token (EOS).
'''
# --- global tag --- #
first_note_tick = song_data['metadata']['first_note'] # Starting tick of the first note
global_end = song_data['metadata']['last_note'] # Ending tick of the last note
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat resolution
# Resolution for dividing beats within measures, expressed as a fraction
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Example: 1024/12 -> (256, 3)
instrument_list = sorted(list(song_data['notes'].keys())) # Get a sorted list of instruments in the song
# --- process time signature --- #
# Normalize and process the time signatures in the song
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
if time_signature_changes is None:
return None # Exit if time signature is invalid
# --- create sequence --- #
prev_instr_idx = None # Track the previously processed instrument
final_sequence = []
final_sequence.append(self._create_event('SOS', None)) # Add Start of Sequence (SOS) token
prev_chord = None # Track the previous chord
prev_tempo = None # Track the previous tempo
chord_value = None
tempo_value = None
# Process each time signature change
for idx in range(len(time_signature_changes)):
time_sig_change_flag = True # Flag to indicate a time signature change
# Calculate bar resolution based on the current time signature
numerator = time_signature_changes[idx].numerator
denominator = time_signature_changes[idx].denominator
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format time signature name
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate bar resolution in ticks
bar_start_tick = time_signature_changes[idx].time # Start tick of the current bar
# Determine the next time signature change point or the end of the song
if idx == len(time_signature_changes) - 1:
next_change_point = global_end
else:
next_change_point = time_signature_changes[idx+1].time
# Process each measure within the current time signature
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
empty_bar_token = self._create_event('Bar', None) # Token for empty bars
# Ensure no more than 4 consecutive empty bars are added
if len(final_sequence) >= 4:
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and
final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
if time_sig_change_flag:
final_sequence.append(self._create_event('Bar', time_sig_name)) # Mark new bar with time signature
else:
final_sequence.append(self._create_event('Bar', None))
else:
if time_sig_change_flag:
final_sequence.append(self._create_event('Bar', time_sig_name))
else:
if time_sig_change_flag:
final_sequence.append(self._create_event('Bar', time_sig_name))
else:
final_sequence.append(self._create_event('Bar', None))
time_sig_change_flag = False # Reset time signature change flag
# Process events within each beat
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
events_list = []
# Retrieve chords and tempos at the current beat step
t_chords = song_data['chords'].get(beat_step)
t_tempos = song_data['tempos'].get(beat_step)
# Process chord and tempo if the number of features allows for it
if self.num_features in {8, 7}:
if t_chords is not None:
root, quality, _ = t_chords[-1].text.split('_') # Extract chord info
chord_value = root + '_' + quality
if t_tempos is not None:
tempo_value = t_tempos[-1].tempo # Extract tempo value
# Dictionary to track notes for each instrument to avoid duplicates
instrument_note_dict = defaultdict(dict)
# Process notes for each instrument at the current beat step
for instrument_idx in instrument_list:
t_notes = song_data['notes'][instrument_idx].get(beat_step)
# If there are notes at this beat step, process them.
if t_notes is not None:
# Track notes to avoid duplicates and check for half-step intervals
for note in t_notes:
if note.pitch not in instrument_note_dict[instrument_idx]:
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
else:
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
if len(instrument_note_dict) == 0:
continue
# Check for half-step interval gaps and handle them across instruments
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
# add chord and tempo
if self.num_features in {7, 8}:
if prev_chord != chord_value:
events_list.append(self._create_event('Chord', chord_value))
prev_chord = chord_value
if prev_tempo != tempo_value:
events_list.append(self._create_event('Tempo', tempo_value))
prev_tempo = tempo_value
# add instrument and note
for instrument in pruned_instrument_note_dict:
if self.num_features in {5, 8}:
events_list.append(self._create_event('Instrument', instrument))
for pitch in pruned_instrument_note_dict[instrument]:
max_duration = max(pruned_instrument_note_dict[instrument][pitch], key=lambda x: x[0])
note_event = [
self._create_event('Note_Pitch', pitch),
self._create_event('Note_Duration', max_duration[0])
]
if self.num_features in {7, 8}:
note_event.append(self._create_event('Note_Velocity', max_duration[1]))
events_list.extend(note_event)
# If there are events in this step, add a "Beat" event and the collected events
if len(events_list):
final_sequence.append(self._create_event('Beat', in_beat_off_idx))
final_sequence.extend(events_list)
# --- end with BAR & EOS --- #
final_sequence.append(self._create_event('Bar', None)) # Add final bar token
final_sequence.append(self._create_event('EOS', None)) # Add End of Sequence (EOS) token
return final_sequence
################################# for CP style encoding #################################
class Corpus2event_cp(Corpus2event_remi):
def __init__(self, num_features):
super().__init__(num_features)
self.num_features = num_features
self._init_event_template()
def _init_event_template(self):
'''
The order of musical features is Type, Beat, Chord, Tempo, Instrument, Pitch, Duration, Velocity
'''
self.event_template = {}
if self.num_features == 8:
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
elif self.num_features == 7:
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
elif self.num_features == 5:
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
elif self.num_features == 4:
feature_names = ['type', 'beat', 'pitch', 'duration']
for feature_name in feature_names:
self.event_template[feature_name] = 0
def create_cp_sos_event(self):
total_event = self.event_template.copy()
total_event['type'] = 'SOS'
return total_event
def create_cp_eos_event(self):
total_event = self.event_template.copy()
total_event['type'] = 'EOS'
return total_event
def create_cp_metrical_event(self, pos, chord, tempo):
'''
when the compound token is related to metrical information
'''
meter_event = self.event_template.copy()
meter_event['type'] = 'Metrical'
meter_event['beat'] = pos
if self.num_features == 7 or self.num_features == 8:
meter_event['chord'] = chord
meter_event['tempo'] = tempo
return meter_event
def create_cp_note_event(self, instrument_name, pitch, duration, velocity):
'''
when the compound token is related to note information
'''
note_event = self.event_template.copy()
note_event['type'] = 'Note'
note_event['pitch'] = pitch
note_event['duration'] = duration
if self.num_features == 5 or self.num_features == 8:
note_event['instrument'] = instrument_name
if self.num_features == 7 or self.num_features == 8:
note_event['velocity'] = velocity
return note_event
def create_cp_bar_event(self, time_sig_change_flag=False, time_sig_name=None):
meter_event = self.event_template.copy()
if time_sig_change_flag:
meter_event['type'] = 'Metrical'
meter_event['beat'] = f'Bar_{time_sig_name}'
else:
meter_event['type'] = 'Metrical'
meter_event['beat'] = 'Bar'
return meter_event
def __call__(self, song_data, in_beat_resolution):
# --- global tag --- #
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
# --- process time signature --- #
# Process time signature changes and adjust them for the given song structure
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
if time_signature_changes is None:
return None # Exit if no valid time signature changes found
# --- create sequence --- #
final_sequence = [] # Initialize the final sequence to store the events
final_sequence.append(self.create_cp_sos_event()) # Add the Start-of-Sequence (SOS) event
chord_text = None # Placeholder for the current chord
tempo_text = None # Placeholder for the current tempo
# Loop through each time signature change and process the corresponding measures
for idx in range(len(time_signature_changes)):
time_sig_change_flag = True # Flag to track when time signature changes
# Calculate bar resolution (number of ticks per bar based on the time signature)
numerator = time_signature_changes[idx].numerator
denominator = time_signature_changes[idx].denominator
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
# Determine the point for the next time signature change or the end of the song
if idx == len(time_signature_changes) - 1:
next_change_point = global_end
else:
next_change_point = time_signature_changes[idx + 1].time
# Iterate over each measure (bar) between the current and next time signature change
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
empty_bar_token = self.create_cp_bar_event() # Create an empty bar event
# Check if the last four events in the sequence are consecutive empty bars
if len(final_sequence) >= 4:
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
else:
if time_sig_change_flag:
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
else:
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
# Reset the time signature change flag after handling the bar event
time_sig_change_flag = False
# Loop through beats in each measure based on the in-beat resolution
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
chord_tempo_flag = False # Flag to track if chord and tempo events are added
events_list = [] # List to hold events for the current beat
pos_text = 'Beat_' + str(in_beat_off_idx) # Create a beat event label
# --- chord & tempo processing --- #
# Unpack chords and tempos for the current beat step
t_chords = song_data['chords'].get(beat_step)
t_tempos = song_data['tempos'].get(beat_step)
# If a chord is present, extract its root, quality, and bass
if self.num_features in {7, 8}:
if t_chords is not None:
root, quality, _ = t_chords[-1].text.split('_')
chord_text = 'Chord_' + root + '_' + quality
# If a tempo is present, format it as a string
if t_tempos is not None:
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
# Dictionary to track notes for each instrument to avoid duplicates
instrument_note_dict = defaultdict(dict)
# --- instrument & note processing --- #
# Loop through each instrument and process its notes at the current beat step
for instrument_idx in instrument_list:
t_notes = song_data['notes'][instrument_idx].get(beat_step)
# If notes are present, process them
if t_notes is not None:
# Track notes and their properties (duration and velocity) for the current instrument
for note in t_notes:
if note.pitch not in instrument_note_dict[instrument_idx]:
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
else:
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
if len(instrument_note_dict) == 0:
continue
# Check for half-step interval gaps and handle them across instruments
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
# add chord and tempo
if self.num_features in {7, 8}:
if not chord_tempo_flag:
if chord_text is None:
chord_text = 'Chord_N_N'
if tempo_text is None:
tempo_text = 'Tempo_N_N'
chord_tempo_flag = True
events_list.append(self.create_cp_metrical_event(pos_text, chord_text, tempo_text))
# add instrument and note
for instrument_idx in pruned_instrument_note_dict:
instrument_name = 'Instrument_' + str(instrument_idx)
for pitch in pruned_instrument_note_dict[instrument_idx]:
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
note_pitch_text = 'Note_Pitch_' + str(pitch)
note_duration_text = 'Note_Duration_' + str(max_duration[0])
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
events_list.append(self.create_cp_note_event(instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
# If there are any events for this beat, add them to the final sequence
if len(events_list) > 0:
final_sequence.extend(events_list)
# --- end with BAR & EOS --- #
final_sequence.append(self.create_cp_bar_event()) # Add the final bar event
final_sequence.append(self.create_cp_eos_event()) # Add the End-of-Sequence (EOS) event
return final_sequence # Return the final sequence of events
################################# for NB style encoding #################################
class Corpus2event_nb(Corpus2event_cp):
def __init__(self, num_features):
'''
For convenience in logging, we keep the key name "type" for the "metric" sub-token so it can be compared easily with the other encoding schemes
'''
super().__init__(num_features)
self.num_features = num_features
self._init_event_template()
def _init_event_template(self):
self.event_template = {}
if self.num_features == 8:
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
elif self.num_features == 7:
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
elif self.num_features == 5:
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
elif self.num_features == 4:
feature_names = ['type', 'beat', 'pitch', 'duration']
for feature_name in feature_names:
self.event_template[feature_name] = 0
def create_nb_sos_event(self):
total_event = self.event_template.copy()
total_event['type'] = 'SOS'
return total_event
def create_nb_eos_event(self):
total_event = self.event_template.copy()
total_event['type'] = 'EOS'
return total_event
def create_nb_event(self, bar_beat_type, pos, chord, tempo, instrument_name, pitch, duration, velocity):
total_event = self.event_template.copy()
total_event['type'] = bar_beat_type
total_event['beat'] = pos
total_event['pitch'] = pitch
total_event['duration'] = duration
if self.num_features in {5, 8}:
total_event['instrument'] = instrument_name
if self.num_features in {7, 8}:
total_event['chord'] = chord
total_event['tempo'] = tempo
total_event['velocity'] = velocity
return total_event
def create_nb_empty_bar_event(self):
total_event = self.event_template.copy()
total_event['type'] = 'Empty_Bar'
return total_event
def get_bar_beat_idx(self, bar_flag, beat_flag, time_sig_name, time_sig_change_flag):
'''
Get the metric information for the current bar and beat.
There are four types of metric information: NNN, SNN, SSN, SSS.
Each letter indicates, in order, whether the time signature, bar, and beat are new (N) or the same (S).
'''
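# Example: the first note after a time signature change yields "NNN_time_signature_4/4",
# the first note of a new bar yields "SNN", the first note of a new beat in the same bar
# yields "SSN", and further notes in the same beat yield "SSS".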
if time_sig_change_flag: # new time signature
return "NNN_" + time_sig_name
else:
if bar_flag and beat_flag: # same time sig & new bar & new beat
return "SNN"
elif not bar_flag and beat_flag: # same time sig & same bar & new beat
return "SSN"
elif not bar_flag and not beat_flag: # same time sig & same bar & same beat
return "SSS"
def __call__(self, song_data, in_beat_resolution:int):
# --- global tag --- #
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
# --- process time signature --- #
# Process time signature changes and adjust them for the given song structure
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
if time_signature_changes is None:
return None # Exit if no valid time signature changes found
# --- create sequence --- #
final_sequence = [] # Initialize the final sequence to store the events
final_sequence.append(self.create_nb_sos_event()) # Add the Start-of-Sequence (SOS) event
chord_text = None # Placeholder for the current chord
tempo_text = None # Placeholder for the current tempo
# Loop through each time signature change and process the corresponding measures
for idx in range(len(time_signature_changes)):
time_sig_change_flag = True # Flag to track when time signature changes
# Calculate bar resolution (number of ticks per bar based on the time signature)
numerator = time_signature_changes[idx].numerator
denominator = time_signature_changes[idx].denominator
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
# Determine the point for the next time signature change or the end of the song
if idx == len(time_signature_changes) - 1:
next_change_point = global_end
else:
next_change_point = time_signature_changes[idx + 1].time
# Iterate over each measure (bar) between the current and next time signature change
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
bar_flag = True
note_flag = False
# Loop through beats in each measure based on the in-beat resolution
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
beat_flag = True
events_list = []
pos_text = 'Beat_' + str(in_beat_off_idx)
# --- chord & tempo processing --- #
# Unpack chords and tempos for the current beat step
t_chords = song_data['chords'].get(beat_step)
t_tempos = song_data['tempos'].get(beat_step)
# If a chord is present, extract its root, quality, and bass
if self.num_features == 8 or self.num_features == 7:
if t_chords is not None:
root, quality, _ = t_chords[-1].text.split('_')
chord_text = 'Chord_' + root + '_' + quality
# If a tempo is present, format it as a string
if t_tempos is not None:
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
# Dictionary to track notes for each instrument to avoid duplicates
instrument_note_dict = defaultdict(dict)
# --- instrument & note processing --- #
# Loop through each instrument and process its notes at the current beat step
for instrument_idx in instrument_list:
t_notes = song_data['notes'][instrument_idx].get(beat_step)
# If notes are present, process them
if t_notes is not None:
note_flag = True
# Track notes and their properties (duration and velocity) for the current instrument
for note in t_notes:
if note.pitch not in instrument_note_dict[instrument_idx]:
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
else:
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
# # Check for half-step interval gaps and handle them accordingly
# self._half_step_interval_gap_check(instrument_note_dict, instrument_idx)
if len(instrument_note_dict) == 0:
continue
# Check for half-step interval gaps and handle them across instruments
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
# add chord and tempo
if self.num_features in {7, 8}:
if chord_text is None:
chord_text = 'Chord_N_N'
if tempo_text is None:
tempo_text = 'Tempo_N_N'
# add instrument and note
for instrument_idx in pruned_instrument_note_dict:
instrument_name = 'Instrument_' + str(instrument_idx)
for pitch in pruned_instrument_note_dict[instrument_idx]:
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
note_pitch_text = 'Note_Pitch_' + str(pitch)
note_duration_text = 'Note_Duration_' + str(max_duration[0])
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
bar_beat_type = self.get_bar_beat_idx(bar_flag, beat_flag, time_sig_name, time_sig_change_flag)
events_list.append(self.create_nb_event(bar_beat_type, pos_text, chord_text, tempo_text, instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
bar_flag = False
beat_flag = False
time_sig_change_flag = False
# If there are any events for this beat, add them to the final sequence
if events_list:
final_sequence.extend(events_list)
# when there is no note in this bar
if not note_flag:
# avoid consecutive empty bars (more than 4 is not allowed)
empty_bar_token = self.create_nb_empty_bar_event()
if len(final_sequence) >= 4:
if final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token:
continue
final_sequence.append(empty_bar_token)
# --- end with BAR & EOS --- #
final_sequence.append(self.create_nb_eos_event())
return final_sequence
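# Illustrative usage sketch (assumptions: a corpus .pkl produced by step1_midi2corpus.py
# exists at the placeholder path below; SOD uses a 12-per-quarter-note resolution and 5 features):
if __name__ == "__main__":
    import pickle
    with open("../dataset/represented_data/corpus/corpus_SOD/example.pkl", "rb") as f:
        song_data = pickle.load(f)
    encoder = Corpus2event_nb(num_features=5)
    events = encoder(song_data, in_beat_resolution=12)
    # `events` is a list of compound-token dicts, or None if the time signatures could not be normalized
    print("number of encoded events:", None if events is None else len(events))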


@@ -0,0 +1,650 @@
import argparse
import time
import itertools
import copy
from copy import deepcopy
from pathlib import Path
from multiprocessing import Pool, cpu_count
from collections import defaultdict
from fractions import Fraction
from typing import List
import os
from muspy import sort
import numpy as np
import pickle
from tqdm import tqdm
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument
from chorder import Dechorder
from constants import NUM2PITCH, PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP
'''
This script is designed to preprocess MIDI files and convert them into a structured corpus suitable for symbolic music analysis or model training.
It handles various tasks, including setting beat resolution, calculating duration, velocity, and tempo bins, and processing MIDI data into quantized musical events.
'''
def get_tempo_bin(max_tempo:int, ratio:float=1.1):
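# Example: get_tempo_bin(240, 1.04) yields tempo bins from 30 BPM up to at most 240 BPM,
# with each bin roughly 4% larger than the previous one.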
bpm = 30
regular_tempo_bins = [bpm]
while bpm < max_tempo:
bpm *= ratio
bpm = round(bpm)
if bpm > max_tempo:
break
regular_tempo_bins.append(bpm)
return np.array(regular_tempo_bins)
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
'''
keep only chord markers (skip markers whose text starts with 'global' or contains 'Boundary')
'''
chords = []
for marker in markers:
splitted_text = marker.text.split('_')
if splitted_text[0] != 'global' and 'Boundary' not in splitted_text[0]:
chords.append(marker)
return chords
class CorpusMaker():
def __init__(
self,
dataset_name:str,
num_features:int,
in_dir:Path,
out_dir:Path,
debug:bool
):
'''
Initialize the CorpusMaker with dataset information and directory paths.
It sets up MIDI paths, output directories, and debug mode, then
retrieves the beat resolution, duration bins, velocity/tempo bins, and prepares the MIDI file list.
'''
self.dataset_name = dataset_name
self.num_features = num_features
self.midi_path = in_dir / f"{dataset_name}"
self.out_dir = out_dir
self.debug = debug
self._get_in_beat_resolution()
self._get_duration_bins()
self._get_velocity_tempo_bins()
self._get_min_max_last_time()
self._prepare_midi_list()
def _get_in_beat_resolution(self):
# Retrieve the per-quarter-note resolution for the dataset (e.g., 4 means the finest grid is a 16th note)
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
try:
self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
def _get_duration_bins(self):
# Set up regular duration bins for quantizing note lengths, based on the beat resolution.
base_duration = {4:[1,2,3,4,5,6,8,10,12,16,20,24,28,32],
8:[1,2,3,4,6,8,10,12,14,16,20,24,28,32,36,40,48,56,64],
12:[1,2,3,4,6,9,12,15,18,24,30,36,42,48,54,60,72,84,96]}
base_duration_list = base_duration[self.in_beat_resolution]
self.regular_duration_bins = np.array(base_duration_list)
def _get_velocity_tempo_bins(self):
# Define velocity and tempo bins based on whether the dataset is a performance or score type.
midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'SymphonyMIDI': 'score'}
try:
midi_type = midi_type_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
midi_type = midi_type_dict['LakhClean']
# For performance-type datasets, set finer granularity of velocity and tempo bins.
if midi_type == 'perform':
self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
# For score-type datasets, use coarser velocity and tempo bins.
elif midi_type == 'score':
self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)
def _get_min_max_last_time(self):
'''
Set the minimum and maximum allowed length of a MIDI track, depending on the dataset.
0 to 2000 means no limitation
'''
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'SymphonyMIDI': (60, 1500)}
try:
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.min_last_time, self.max_last_time = last_time_dict['LakhClean']
def _prepare_midi_list(self):
midi_path = Path(self.midi_path)
# detect subdirectories and get all midi files
if not midi_path.exists():
raise ValueError(f"midi_path {midi_path} does not exist")
# go through all subdirectories and get all midi files
midi_files = []
for root, _, files in os.walk(midi_path):
for file in files:
if file.endswith('.mid'):
# print(Path(root) / file)
midi_files.append(Path(root) / file)
self.midi_list = midi_files
print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")
def make_corpus(self) -> None:
'''
Main method to process the MIDI files and create the corpus data.
It supports both single-processing (debug mode) and multi-processing for large datasets.
'''
print("preprocessing midi data to corpus data")
# check whether the output folders already exist and create them if not
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
start_time = time.time()
if self.debug:
# single processing for debugging
broken_counter = 0
success_counter = 0
for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
message = self._mp_midi2corpus(file_path)
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
else:
# Multi-threaded processing for faster corpus generation.
broken_counter = 0
success_counter = 0
# filter out processed files
print(self.out_dir)
processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
processed_files = [x.name for x in processed_files]
print(f"processed files: {len(processed_files)}")
print("length of midi list: ", len(self.midi_list))
# Use set for faster lookup (O(1) per check)
processed_files_set = set(processed_files)
self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
# reverse the list to process the latest files first
self.midi_list.reverse()
print(f"length of midi list after filtering: ", len(self.midi_list))
with Pool(16) as p:
for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, 1000), total=len(self.midi_list)):
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
# for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
# message = self._mp_midi2corpus(file_path)
# if message == "error":
# broken_counter += 1
# elif message == "success":
# success_counter += 1
print(f"Making corpus takes: {time.time() - start_time}s, success: {success_counter}, broken: {broken_counter}")
def _mp_midi2corpus(self, file_path: Path):
"""Convert MIDI to corpus format and save both corpus (.pkl) and MIDI (.mid)."""
try:
midi_obj = self._analyze(file_path)
corpus, midi_obj = self._midi2corpus(midi_obj)
# --- 1. Save corpus (.pkl) ---
relative_path = file_path.relative_to(self.midi_path) # Get relative path from input dir
safe_name = str(relative_path).replace("/", "_").replace("\\", "_").replace(".mid", ".pkl")
save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
save_path.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
with save_path.open("wb") as f:
pickle.dump(corpus, f)
# --- 2. Save MIDI (.mid) ---
midi_save_dir = Path("../dataset/represented_data/corpus") / f"midi_{self.dataset_name}"
midi_save_dir.mkdir(parents=True, exist_ok=True)
midi_save_path = midi_save_dir / file_path.name # Keep original MIDI filename
midi_obj.dump(midi_save_path)
del midi_obj, corpus
return "success"
except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
print(f"Error processing {file_path.name}: {e}")
return "error"
except Exception as e:
print(f"Unexpected error in {file_path.name}: {e}")
return "error"
def _check_length(self, last_time:float):
if last_time < self.min_last_time:
raise ValueError(f"last time {last_time} is out of range")
def _analyze(self, midi_path:Path):
# Loads and analyzes a MIDI file, performing various checks and extracting chords.
midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
# check length
mapping = midi_obj.get_tick_to_time_mapping()
last_time = mapping[midi_obj.max_tick]
self._check_length(last_time)
# delete instruments with no notes and sort the remaining notes in place
# (iterate over a copy so that removing items does not skip elements)
for ins in list(midi_obj.instruments):
if len(ins.notes) == 0:
midi_obj.instruments.remove(ins)
continue
ins.notes = sorted(ins.notes, key=lambda x: (x.start, x.pitch))
# three steps to merge instruments
self._merge_percussion(midi_obj)
self._pruning_instrument(midi_obj)
self._limit_max_track(midi_obj)
if self.num_features == 7 or self.num_features == 8:
# in case of 7 or 8 features, we need to extract chords
new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
chords = Dechorder.dechord(new_midi_obj)
markers = []
for cidx, chord in enumerate(chords):
if chord.is_complete():
chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
else:
chord_text = 'N_N_N'
markers.append(Marker(time=int(cidx*new_midi_obj.ticks_per_beat), text=chord_text))
# de-duplication
prev_chord = None
dedup_chords = []
for m in markers:
if m.text != prev_chord:
prev_chord = m.text
dedup_chords.append(m)
# return midi
midi_obj.markers = dedup_chords
return midi_obj
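# Illustrative note (not part of the pipeline): with 7 or 8 features, _analyze attaches one chord marker
# per beat whose text is "<root>_<quality>_<bass>" built from NUM2PITCH and the chorder quality string
# (for example something like "C_M_G", depending on how NUM2PITCH names pitch classes), or "N_N_N" when
# chorder cannot find a complete chord; consecutive identical markers are then de-duplicated.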
def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
'''
When notes with different original start times are quantized onto the same grid position, unintended chords can be created.
rule 1: if two adjacent notes are a unison or a half step apart, keep only the longer one
rule 2: if two notes share less than 80% of the shorter note's duration, keep only the longer one
'''
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
for time in time_list:
notes = instr_grid[instr][time]
if len(notes) == 1:
continue
else:
new_notes = []
# sort in pitch with ascending order
notes.sort(key=lambda x: x.pitch)
for i in range(len(notes)-1):
# if start time is same add to new_notes
if notes[i].start == notes[i+1].start:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
continue
if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
# select longer note
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
# check how much duration they share
shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
# unless they share more than 80% of duration, select longer note (pruning shorter note)
if shared_duration / shorter_duration < 0.8:
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
if len(new_notes) == 0:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
else:
new_notes.append(notes[i+1])
instr_grid[instr][time] = new_notes
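# Worked example (assumed values, for illustration only): suppose a long C4 (pitch 60) and a short B3
# (pitch 59) land on the same quantized grid time although they started at different ticks. They are a
# half step apart, so rule 1 keeps only the longer C4; two unrelated pitches that overlap for less than
# 80% of the shorter note's duration would likewise be reduced to the longer note by rule 2.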
def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
# Checks if the ticks per beat in the MIDI file is lower than the expected resolution.
# If it is, raise an error.
if midi_obj.ticks_per_beat < self.in_beat_resolution:
raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')
# Ensure there is at least one time signature change in the MIDI file.
# if len(midi_obj.time_signature_changes) == 0:
# raise ValueError('[x] No time_signature_changes')
# Ensure there are no duplicated time signature changes.
# time_list = [ts.time for ts in midi_obj.time_signature_changes]
# if len(time_list) != len(set(time_list)):
# raise ValueError('[x] Duplicated time_signature_changes')
# If the dataset is 'LakhClean' or 'SymphonyMIDI', verify there are at least 4 tracks.
# if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
# if len(midi_obj.instruments) < 4:
# raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')
# Calculate the resolution of ticks per beat as a fraction.
in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)
# Extract the initial time signature (numerator and denominator) and calculate the number of ticks for the first bar.
if len(midi_obj.time_signature_changes) != 0:
initial_numerator = midi_obj.time_signature_changes[0].numerator
initial_denominator = midi_obj.time_signature_changes[0].denominator
else:
# If no time signature changes, set default values
initial_numerator = 4
initial_denominator = 4
first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))
# --- load notes --- #
instr_notes = self._make_instr_notes(midi_obj)
# --- load information --- #
# load chords, labels
chords = split_markers(midi_obj.markers)
chords.sort(key=lambda x: x.time)
# load tempos
tempos = midi_obj.tempo_changes if len(midi_obj.tempo_changes) > 0 else []
if len(tempos) == 0:
# if no tempo changes, set the default tempo to 120 BPM
tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
tempos.sort(key=lambda x: x.time)
# --- process items to grid --- #
# compute empty bar offset at head
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
offset = quant_time_first // first_bar_resol # empty bar
offset_by_resol = offset * first_bar_resol
# --- process notes --- #
instr_grid = dict()
for key in instr_notes.keys():
notes = instr_notes[key]
note_grid = defaultdict(list)
for note in notes:
# skip notes out of range, below C-1 and above C8
if note.pitch < 12 or note.pitch >= 120:
continue
# in case the first note starts slightly before the first bar
note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0
# relative duration
# skip note with 0 duration
note_duration = note.end - note.start
relative_duration = round(note_duration / in_beat_tick_resol)
if relative_duration == 0:
continue
if relative_duration > self.in_beat_resolution * 8: # 8 beats
relative_duration = self.in_beat_resolution * 8
# use regular duration bins
note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins-relative_duration))]
# quantize start time
quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol
# velocity
note.velocity = self.regular_velocity_bins[
np.argmin(abs(self.regular_velocity_bins-note.velocity))]
# append
note_grid[quant_time].append(note)
# set to track
instr_grid[key] = note_grid
# --- pruning grouped notes --- #
self._pruning_grouped_notes_from_quantization(instr_grid)
# --- process chords --- #
chord_grid = defaultdict(list)
for chord in chords:
# quantize
chord.time = chord.time - offset_by_resol
chord.time = 0 if chord.time < 0 else chord.time
quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
chord_grid[quant_time].append(chord)
# --- process tempos --- #
first_notes_list = []
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
if len(time_list) == 0: # skip empty tracks
continue
first_quant_time = time_list[0]
first_notes_list.append(first_quant_time)
# handle the case where every track is empty
if not first_notes_list:
raise ValueError("[x] No valid notes found in any instrument track.")
quant_first_note_time = min(first_notes_list)
tempo_grid = defaultdict(list)
for tempo in tempos:
# quantize
tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
tempo.tempo = self.regular_tempo_bins[
np.argmin(abs(self.regular_tempo_bins-tempo.tempo))]
if quant_time < quant_first_note_time:
tempo_grid[quant_first_note_time].append(tempo)
else:
tempo_grid[quant_time].append(tempo)
if len(tempo_grid[quant_first_note_time]) > 1:
tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]
# --- process time signature --- #
quant_time_signature = deepcopy(midi_obj.time_signature_changes)
quant_time_signature.sort(key=lambda x: x.time)
for ts in quant_time_signature:
ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol
# --- make new midi object to check processed values --- #
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
for instr_idx in instr_grid.keys():
new_instrument = Instrument(program=instr_idx)
new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
new_midi_obj.instruments.append(new_instrument)
new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
# make corpus
song_data = {
'notes': instr_grid,
'chords': chord_grid,
'tempos': tempo_grid,
'metadata': {
'first_note': first_note_time,
'last_note': last_note_time,
'time_signature': quant_time_signature,
'ticks_per_beat': midi_obj.ticks_per_beat,
}
}
return song_data, new_midi_obj
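# Worked example of the quantization above (assumed values): with ticks_per_beat = 480 and
# in_beat_resolution = 4, in_beat_tick_resol = Fraction(480, 4) = 120 ticks, so a note starting at
# tick 250 snaps to round(250 / 120) * 120 = 240, and a 200-tick duration becomes
# round(200 / 120) = 2 grid units, which is then mapped to the nearest regular duration bin.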
def _make_instr_notes(self, midi_obj):
'''
This part is important; there are three possible ways to merge instruments:
1st option: compare the number of notes and keep the tracks with more notes
2nd option: merge all instruments that share the same program into one track
3rd option: leave all instruments as they are, differentiating tracks by their track number
In this version we use the 2nd option, since it reduces both the number of tracks and the sequence length
'''
instr_notes = defaultdict(list)
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name is None:
continue
new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
instr_notes[new_instr_idx].extend(instr.notes)
instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
return instr_notes
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
merge all drum tracks into a single track
'''
drum_0_lst = []
new_instruments = []
for instrument in midi_obj.instruments:
if len(instrument.notes) == 0:
continue
if instrument.is_drum:
drum_0_lst.extend(instrument.notes)
else:
new_instruments.append(instrument)
if len(drum_0_lst) > 0:
drum_0_lst.sort(key=lambda x: x.start)
# remove duplicate
drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
drum_0_instrument.notes = drum_0_lst
new_instruments.append(drum_0_instrument)
midi_obj.instruments = new_instruments
# referred to mmt "https://github.com/salu133445/mmt"
def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
merge instruments with similar instrument categories into one program number
e.g. 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, and 2: Electric Grand Piano are all mapped to 0: Acoustic Grand Piano
'''
new_instruments = []
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name != None:
new_instruments.append(instr)
midi_obj.instruments = new_instruments
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
'''
merge tracks with the fewest notes into other tracks with the same program
and limit the maximum number of tracks to MAX_TRACK (16 by default)
'''
if len(midi_obj.instruments) == 1:
if midi_obj.instruments[0].is_drum:
midi_obj.instruments[0].program = 114
midi_obj.instruments[0].is_drum = False
return midi_obj
good_instruments = midi_obj.instruments
good_instruments.sort(
key=lambda x: (not x.is_drum, -len(x.notes))) # place drum tracks first, then the remaining tracks by descending note count
assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
# assert good_instruments[0].is_drum == False, (, len(good_instruments[2]))
track_idx_lst = list(range(len(good_instruments)))
if len(good_instruments) > MAX_TRACK:
new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
# print(midi_file_path)
for id in track_idx_lst[MAX_TRACK:]:
cur_ins = good_instruments[id]
merged = False
new_good_instruments.sort(key=lambda x: len(x.notes))
for nid, ins in enumerate(new_good_instruments):
if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
new_good_instruments[nid].notes.extend(cur_ins.notes)
merged = True
break
if not merged:
pass
good_instruments = new_good_instruments
assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
for idx, good_instrument in enumerate(good_instruments):
if good_instrument.is_drum:
good_instruments[idx].program = 114
good_instruments[idx].is_drum = False
midi_obj.instruments = good_instruments
def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
collect pitched notes into a single temporary track for chord extraction
'''
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
new_instruments = []
new_notes = []
for instrument in midi_obj.instruments:
if instrument.program == 114 or instrument.is_drum: # skip drum tracks
continue
valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
new_notes.extend(valid_notes)
new_notes.sort(key=lambda x: x.start)
new_instrument.notes = new_notes
new_instruments.append(new_instrument)
new_midi_obj.instruments = new_instruments
return new_midi_obj
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/corpus/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
parser = get_argument_parser()
args = parser.parse_args()
corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
corpus_maker.make_corpus()
if __name__ == "__main__":
main()
# python3 step1_midi2corpus.py --dataset SOD --num_features 5
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb

View File

@ -0,0 +1,654 @@
import argparse
import time
import itertools
import copy
from copy import deepcopy
from pathlib import Path
from multiprocessing import Pool, cpu_count
from collections import defaultdict
from fractions import Fraction
from typing import List
import os
from muspy import sort
import numpy as np
import pickle
from tqdm import tqdm
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument
from chorder import Dechorder
from constants import NUM2PITCH, FINED_PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP, PROGRAM_INSTRUMENT_MAP  # PROGRAM_INSTRUMENT_MAP is referenced by _pruning_instrument below (assumed to be defined in constants)
'''
This script is designed to preprocess MIDI files and convert them into a structured corpus suitable for symbolic music analysis or model training.
It handles various tasks, including setting beat resolution, calculating duration, velocity, and tempo bins, and processing MIDI data into quantized musical events.
We don't do instrument merging here.
'''
def get_tempo_bin(max_tempo:int, ratio:float=1.1):
bpm = 30
regular_tempo_bins = [bpm]
while bpm < max_tempo:
bpm *= ratio
bpm = round(bpm)
if bpm > max_tempo:
break
regular_tempo_bins.append(bpm)
return np.array(regular_tempo_bins)
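# Rough sketch of the resulting bins (assuming max_tempo=240, ratio=1.04 as used below): the values grow
# by about 4% per step and are rounded to integers, starting 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 42, ...
# and stopping at or below max_tempo.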
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
'''
split markers into chord, tempo, label
'''
chords = []
for marker in markers:
splitted_text = marker.text.split('_')
if splitted_text[0] != 'global' and 'Boundary' not in splitted_text[0]:
chords.append(marker)
return chords
class CorpusMaker():
def __init__(
self,
dataset_name:str,
num_features:int,
in_dir:Path,
out_dir:Path,
debug:bool
):
'''
Initialize the CorpusMaker with dataset information and directory paths.
It sets up MIDI paths, output directories, and debug mode, then
retrieves the beat resolution, duration bins, velocity/tempo bins, and prepares the MIDI file list.
'''
self.dataset_name = dataset_name
self.num_features = num_features
self.midi_path = in_dir / f"{dataset_name}"
self.out_dir = out_dir
self.debug = debug
self._get_in_beat_resolution()
self._get_duration_bins()
self._get_velocity_tempo_bins()
self._get_min_max_last_time()
self._prepare_midi_list()
def _get_in_beat_resolution(self):
# Retrieve the per-quarter-note resolution based on the dataset name (e.g., 4 means the minimum resolution is a 16th note)
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
try:
self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
def _get_duration_bins(self):
# Set up regular duration bins for quantizing note lengths, based on the beat resolution.
base_duration = {4:[1,2,3,4,5,6,8,10,12,16,20,24,28,32],
8:[1,2,3,4,6,8,10,12,14,16,20,24,28,32,36,40,48,56,64],
12:[1,2,3,4,6,9,12,15,18,24,30,36,42,48,54,60,72,84,96]}
base_duration_list = base_duration[self.in_beat_resolution]
self.regular_duration_bins = np.array(base_duration_list)
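# For intuition (assumed interpretation): with in_beat_resolution = 4 one duration unit is a 16th note,
# so the bin value 4 corresponds to a quarter note, 16 to a whole note, and 32 to the 8-beat cap applied
# in _midi2corpus; the 12-per-beat resolution can represent both straight 16ths (3 units) and triplet
# subdivisions (e.g. 2 units for a 16th-note triplet).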
def _get_velocity_tempo_bins(self):
# Define velocity and tempo bins based on whether the dataset is a performance or score type.
midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'Symphony': 'score'}
try:
midi_type = midi_type_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
midi_type = midi_type_dict['LakhClean']
# For performance-type datasets, set finer granularity of velocity and tempo bins.
if midi_type == 'perform':
self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
# For score-type datasets, use coarser velocity and tempo bins.
elif midi_type == 'score':
self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)
def _get_min_max_last_time(self):
'''
Set the minimum and maximum allowed length of a MIDI track, depending on the dataset.
a range of 0 to 2000 seconds effectively means no length restriction
'''
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'Symphony': (60, 1500)}
try:
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
except KeyError:
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
self.min_last_time, self.max_last_time = last_time_dict['LakhClean']
def _prepare_midi_list(self):
midi_path = Path(self.midi_path)
# detect subdirectories and get all midi files
if not midi_path.exists():
raise ValueError(f"midi_path {midi_path} does not exist")
# go through all subdirectories and collect all MIDI files
midi_files = []
for root, _, files in os.walk(midi_path):
for file in files:
if file.endswith('.mid') or file.endswith('.midi') or file.endswith('.MID'):
# print(Path(root) / file)
midi_files.append(Path(root) / file)
self.midi_list = midi_files
print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")
def make_corpus(self) -> None:
'''
Main method to process the MIDI files and create the corpus data.
It supports both single-processing (debug mode) and multi-processing for large datasets.
'''
print("preprocessing midi data to corpus data")
# create the corpus output folders if they do not already exist
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
start_time = time.time()
if self.debug:
# single processing for debugging
broken_counter = 0
success_counter = 0
for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
message = self._mp_midi2corpus(file_path)
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
else:
# Multiprocess processing (worker pool) for faster corpus generation.
broken_counter = 0
success_counter = 0
# filter out processed files
print(self.out_dir)
processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
processed_files = [x.name for x in processed_files]
print(f"processed files: {len(processed_files)}")
print("length of midi list: ", len(self.midi_list))
# Use set for faster lookup (O(1) per check)
processed_files_set = set(processed_files)
# self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
# reverse the list to process the latest files first
self.midi_list.reverse()
print(f"length of midi list after filtering: ", len(self.midi_list))
with Pool(16) as p:
for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, 500), total=len(self.midi_list)):
if message == "error":
broken_counter += 1
elif message == "success":
success_counter += 1
# for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
# message = self._mp_midi2corpus(file_path)
# if message == "error":
# broken_counter += 1
# elif message == "success":
# success_counter += 1
print(f"Making corpus takes: {time.time() - start_time}s, success: {success_counter}, broken: {broken_counter}")
def _mp_midi2corpus(self, file_path: Path):
"""Convert MIDI to corpus format and save both corpus (.pkl) and MIDI (.mid)."""
try:
midi_obj = self._analyze(file_path)
corpus, midi_obj = self._midi2corpus(midi_obj)
# --- 1. Save corpus (.pkl) ---
relative_path = file_path.relative_to(self.midi_path) # Get relative path from input dir
safe_name = str(relative_path).replace("/", "_").replace("\\", "_").replace(".mid", ".pkl")
save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
save_path.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
with save_path.open("wb") as f:
pickle.dump(corpus, f)
# --- 2. Save MIDI (.mid) ---
midi_save_dir = Path("../dataset/represented_data/corpus") / f"midi_{self.dataset_name}"
midi_save_dir.mkdir(parents=True, exist_ok=True)
midi_save_path = midi_save_dir / file_path.name # Keep original MIDI filename
midi_obj.dump(midi_save_path)
del midi_obj, corpus
return "success"
except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
print(f"Error processing {file_path.name}: {e}")
return "error"
except Exception as e:
print(f"Unexpected error in {file_path.name}: {e}")
return "error"
def _check_length(self, last_time:float):
if last_time < self.min_last_time:
raise ValueError(f"last time {last_time} is out of range")
def _analyze(self, midi_path:Path):
# Loads and analyzes a MIDI file, performing various checks and extracting chords.
midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
# check length
mapping = midi_obj.get_tick_to_time_mapping()
last_time = mapping[midi_obj.max_tick]
self._check_length(last_time)
# delete instruments with no notes and sort the remaining notes in place
# (iterate over a copy so that removing items does not skip elements)
for ins in list(midi_obj.instruments):
if len(ins.notes) == 0:
midi_obj.instruments.remove(ins)
continue
ins.notes = sorted(ins.notes, key=lambda x: (x.start, x.pitch))
# three steps to merge instruments
self._merge_percussion(midi_obj)
# self._pruning_instrument(midi_obj)
self._limit_max_track(midi_obj)
if self.num_features == 7 or self.num_features == 8:
# in case of 7 or 8 features, we need to extract chords
new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
chords = Dechorder.dechord(new_midi_obj)
markers = []
for cidx, chord in enumerate(chords):
if chord.is_complete():
chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
else:
chord_text = 'N_N_N'
markers.append(Marker(time=int(cidx*new_midi_obj.ticks_per_beat), text=chord_text))
# de-duplication
prev_chord = None
dedup_chords = []
for m in markers:
if m.text != prev_chord:
prev_chord = m.text
dedup_chords.append(m)
# return midi
midi_obj.markers = dedup_chords
return midi_obj
def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
'''
When notes with different original start times are quantized onto the same grid position, unintended chords can be created.
rule 1: if two adjacent notes are a unison or a half step apart, keep only the longer one
rule 2: if two notes share less than 80% of the shorter note's duration, keep only the longer one
'''
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
for time in time_list:
notes = instr_grid[instr][time]
if len(notes) == 1:
continue
else:
new_notes = []
# sort in pitch with ascending order
notes.sort(key=lambda x: x.pitch)
for i in range(len(notes)-1):
# if start time is same add to new_notes
if notes[i].start == notes[i+1].start:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
continue
if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
# select longer note
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
# check how much duration they share
shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
# unless they share more than 80% of duration, select longer note (pruning shorter note)
if shared_duration / shorter_duration < 0.8:
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
new_notes.append(notes[i])
else:
new_notes.append(notes[i+1])
else:
if len(new_notes) == 0:
new_notes.append(notes[i])
new_notes.append(notes[i+1])
else:
new_notes.append(notes[i+1])
instr_grid[instr][time] = new_notes
def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
# Checks if the ticks per beat in the MIDI file is lower than the expected resolution.
# If it is, raise an error.
if midi_obj.ticks_per_beat < self.in_beat_resolution:
raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')
# Ensure there is at least one time signature change in the MIDI file.
# if len(midi_obj.time_signature_changes) == 0:
# raise ValueError('[x] No time_signature_changes')
# Ensure there are no duplicated time signature changes.
# time_list = [ts.time for ts in midi_obj.time_signature_changes]
# if len(time_list) != len(set(time_list)):
# raise ValueError('[x] Duplicated time_signature_changes')
# If the dataset is 'LakhClean' or 'SymphonyMIDI', verify there are at least 4 tracks.
# if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
# if len(midi_obj.instruments) < 4:
# raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')
# Calculate the resolution of ticks per beat as a fraction.
in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)
# Extract the initial time signature (numerator and denominator) and calculate the number of ticks for the first bar.
if len(midi_obj.time_signature_changes) != 0:
initial_numerator = midi_obj.time_signature_changes[0].numerator
initial_denominator = midi_obj.time_signature_changes[0].denominator
else:
# If no time signature changes, set default values
initial_numerator = 4
initial_denominator = 4
first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))
# --- load notes --- #
instr_notes = self._make_instr_notes(midi_obj)
# --- load information --- #
# load chords, labels
chords = split_markers(midi_obj.markers)
chords.sort(key=lambda x: x.time)
# load tempos
tempos = midi_obj.tempo_changes if len(midi_obj.tempo_changes) > 0 else []
if len(tempos) == 0:
# if no tempo changes, set the default tempo to 120 BPM
tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
tempos.sort(key=lambda x: x.time)
# --- process items to grid --- #
# compute empty bar offset at head
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
offset = quant_time_first // first_bar_resol # empty bar
offset_by_resol = offset * first_bar_resol
# --- process notes --- #
instr_grid = dict()
for key in instr_notes.keys():
notes = instr_notes[key]
note_grid = defaultdict(list)
for note in notes:
# skip notes out of range, below C-1 and above C8
if note.pitch < 12 or note.pitch >= 120:
continue
# in case the first note starts slightly before the first bar
note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0
# relative duration
# skip note with 0 duration
note_duration = note.end - note.start
relative_duration = round(note_duration / in_beat_tick_resol)
if relative_duration == 0:
continue
if relative_duration > self.in_beat_resolution * 8: # 8 beats
relative_duration = self.in_beat_resolution * 8
# use regular duration bins
note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins-relative_duration))]
# quantize start time
quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol
# velocity
note.velocity = self.regular_velocity_bins[
np.argmin(abs(self.regular_velocity_bins-note.velocity))]
# append
note_grid[quant_time].append(note)
# set to track
instr_grid[key] = note_grid
# --- pruning grouped notes --- #
self._pruning_grouped_notes_from_quantization(instr_grid)
# --- process chords --- #
chord_grid = defaultdict(list)
for chord in chords:
# quantize
chord.time = chord.time - offset_by_resol
chord.time = 0 if chord.time < 0 else chord.time
quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
chord_grid[quant_time].append(chord)
# --- process tempos --- #
first_notes_list = []
for instr in instr_grid.keys():
time_list = sorted(list(instr_grid[instr].keys()))
if len(time_list) == 0: # skip empty tracks
continue
first_quant_time = time_list[0]
first_notes_list.append(first_quant_time)
# handle the case where every track is empty
if not first_notes_list:
raise ValueError("[x] No valid notes found in any instrument track.")
quant_first_note_time = min(first_notes_list)
tempo_grid = defaultdict(list)
for tempo in tempos:
# quantize
tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
tempo.tempo = self.regular_tempo_bins[
np.argmin(abs(self.regular_tempo_bins-tempo.tempo))]
if quant_time < quant_first_note_time:
tempo_grid[quant_first_note_time].append(tempo)
else:
tempo_grid[quant_time].append(tempo)
if len(tempo_grid[quant_first_note_time]) > 1:
tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]
# --- process time signature --- #
quant_time_signature = deepcopy(midi_obj.time_signature_changes)
quant_time_signature.sort(key=lambda x: x.time)
for ts in quant_time_signature:
ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol
# --- make new midi object to check processed values --- #
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
for instr_idx in instr_grid.keys():
new_instrument = Instrument(program=instr_idx)
new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
new_midi_obj.instruments.append(new_instrument)
new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
# make corpus
song_data = {
'notes': instr_grid,
'chords': chord_grid,
'tempos': tempo_grid,
'metadata': {
'first_note': first_note_time,
'last_note': last_note_time,
'time_signature': quant_time_signature,
'ticks_per_beat': midi_obj.ticks_per_beat,
}
}
return song_data, new_midi_obj
def _make_instr_notes(self, midi_obj):
'''
This part is important; there are three possible ways to merge instruments:
1st option: compare the number of notes and keep the tracks with more notes
2nd option: merge all instruments that share the same program into one track
3rd option: leave all instruments as they are, differentiating tracks by their track number
In this version we use the 2nd option, since it reduces both the number of tracks and the sequence length
'''
instr_notes = defaultdict(list)
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = FINED_PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name is None:
continue
# new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
new_instr_idx = instr_idx
if new_instr_idx not in instr_notes:
instr_notes[new_instr_idx] = []
instr_notes[new_instr_idx].extend(instr.notes)
instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
return instr_notes
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
merge all drum tracks into a single track
'''
drum_0_lst = []
new_instruments = []
for instrument in midi_obj.instruments:
if len(instrument.notes) == 0:
continue
if instrument.is_drum:
drum_0_lst.extend(instrument.notes)
else:
new_instruments.append(instrument)
if len(drum_0_lst) > 0:
drum_0_lst.sort(key=lambda x: x.start)
# remove duplicate
drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
drum_0_instrument.notes = drum_0_lst
new_instruments.append(drum_0_instrument)
midi_obj.instruments = new_instruments
# referred to mmt "https://github.com/salu133445/mmt"
def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
merge instruments with similar instrument categories into one program number
e.g. 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, and 2: Electric Grand Piano are all mapped to 0: Acoustic Grand Piano
'''
new_instruments = []
for instr in midi_obj.instruments:
instr_idx = instr.program
# change instrument idx
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
if instr_name != None:
new_instruments.append(instr)
midi_obj.instruments = new_instruments
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
'''
merge tracks with the fewest notes into other tracks with the same program
and limit the maximum number of tracks to MAX_TRACK (16 by default)
'''
if len(midi_obj.instruments) == 1:
if midi_obj.instruments[0].is_drum:
midi_obj.instruments[0].program = 114
midi_obj.instruments[0].is_drum = False
return midi_obj
good_instruments = midi_obj.instruments
good_instruments.sort(
key=lambda x: (not x.is_drum, -len(x.notes))) # place drum tracks first, then the remaining tracks by descending note count
assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
# assert good_instruments[0].is_drum == False, (, len(good_instruments[2]))
track_idx_lst = list(range(len(good_instruments)))
if len(good_instruments) > MAX_TRACK:
new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
# print(midi_file_path)
for id in track_idx_lst[MAX_TRACK:]:
cur_ins = good_instruments[id]
merged = False
new_good_instruments.sort(key=lambda x: len(x.notes))
for nid, ins in enumerate(new_good_instruments):
if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
new_good_instruments[nid].notes.extend(cur_ins.notes)
merged = True
break
if not merged:
pass
good_instruments = new_good_instruments
assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
for idx, good_instrument in enumerate(good_instruments):
if good_instrument.is_drum:
good_instruments[idx].program = 114
good_instruments[idx].is_drum = False
midi_obj.instruments = good_instruments
def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
'''
collect pitched notes into a single temporary track for chord extraction
'''
new_midi_obj = miditoolkit.midi.parser.MidiFile()
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
new_midi_obj.max_tick = midi_obj.max_tick
new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
new_instruments = []
new_notes = []
for instrument in midi_obj.instruments:
if instrument.program == 114 or instrument.is_drum: # skip drum tracks
continue
valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
new_notes.extend(valid_notes)
new_notes.sort(key=lambda x: x.start)
new_instrument.notes = new_notes
new_instruments.append(new_instrument)
new_midi_obj.instruments = new_instruments
return new_midi_obj
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/corpus/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
parser = get_argument_parser()
args = parser.parse_args()
corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
corpus_maker.make_corpus()
if __name__ == "__main__":
main()
# python3 step1_midi2corpus.py --dataset SOD --num_features 5
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb

View File

@ -0,0 +1,147 @@
import argparse
import time
from pathlib import Path
import pickle
from tqdm import tqdm
from multiprocessing import Pool
import encoding_utils
'''
This script is for converting corpus data to event data.
'''
class Corpus2Event():
def __init__(
self,
dataset: str,
encoding_scheme: str,
num_features: int,
in_dir: Path,
out_dir: Path,
debug: bool,
cache: bool,
):
self.dataset = dataset
self.encoding_name = encoding_scheme + str(num_features)
self.in_dir = in_dir / f"corpus_{self.dataset}"
self.out_dir = out_dir / f"events_{self.dataset}" / self.encoding_name
self.debug = debug
self.cache = cache
self.encoding_function = getattr(encoding_utils, f'Corpus2event_{encoding_scheme}')(num_features)
self._get_in_beat_resolution()
def _get_in_beat_resolution(self):
# Retrieve the per-quarter-note resolution based on the dataset name (e.g., 4 means the minimum resolution is a 16th note)
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
try:
self.in_beat_resolution = in_beat_resolution_dict[self.dataset]
except KeyError:
print(f"Dataset {self.dataset} is not supported. use the setting of LakhClean")
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
def make_events(self):
'''
Preprocess corpus data into event data.
The process differs for each encoding scheme.
Please refer to encoding_utils.py for more details.
'''
print("preprocessing corpus data to events data")
# check output directory exists
self.out_dir.mkdir(parents=True, exist_ok=True)
start_time = time.time()
# single-processing
broken_count = 0
success_count = 0
corpus_list = sorted(list(self.in_dir.rglob("*.pkl")))
if corpus_list == []:
print(f"No corpus files found in {self.in_dir}. Please check the directory.")
corpus_list = sorted(list(self.in_dir.glob("*.pkli")))
# remove the corpus files that are already in the out_dir
# Use set for faster existence checks
existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
# corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
if event is None:
broken_count += 1
continue
# if using cache, check if the event file already exists
if self.cache and (self.out_dir / filepath_name).exists():
# print(f"event file {filepath_name} already exists, skipping")
continue
with open(self.out_dir / filepath_name, 'wb') as f:
pickle.dump(event, f)
success_count += 1
del event
print(f"taken time for making events is {time.time()-start_time}s, success: {success_count}, broken: {broken_count}")
def _load_single_corpus_and_make_event(self, file_path):
try:
with open(file_path, 'rb') as f:
corpus = pickle.load(f)
event = self.encoding_function(corpus, self.in_beat_resolution)
except Exception as e:
print(f"error in encoding {file_path}: {e}")
event = None
return file_path.name, event
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-e",
"--encoding",
required=True,
choices=("remi", "cp", "nb", "remi_pos"),
type=str,
help="encoding scheme",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/represented_data/corpus/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/events/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
parser.add_argument(
"--cache",
action="store_true",
help="enable cache mode",
)
return parser
def main():
args = get_argument_parser().parse_args()
corpus2event = Corpus2Event(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug, args.cache)
corpus2event.make_events()
if __name__ == "__main__":
main()
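# Example invocation, mirroring the usage comments at the end of step1_midi2corpus.py (paths assume the default directory layout):
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# add --cache to skip event files that already exist in the output directory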

View File

@ -0,0 +1,84 @@
import argparse
from pathlib import Path
import vocab_utils
'''
This script is for creating the vocabulary file.
'''
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-e",
"--encoding",
required=True,
choices=("remi", "cp", "nb"),
type=str,
help="encoding scheme",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/represented_data/events/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../vocab/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
args = get_argument_parser().parse_args()
encoding_scheme = args.encoding
num_features = args.num_features
dataset = args.dataset
out_vocab_path = args.out_dir / f"vocab_{dataset}"
out_vocab_path.mkdir(parents=True, exist_ok=True)
out_vocab_file_path = out_vocab_path / f"vocab_{dataset}_{encoding_scheme}{num_features}.json"
events_path = Path(args.in_dir / f"events_{dataset}" / f"{encoding_scheme}{num_features}")
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
selected_vocab_name = vocab_name[encoding_scheme]
event_data = sorted(list(events_path.rglob("*.pkl")))
if event_data == []:
print(f"No event files found in {events_path}. Please check the directory.")
event_data = sorted(list(events_path.glob("*.pkli")))
vocab = getattr(vocab_utils, selected_vocab_name)(
in_vocab_file_path=None,
event_data=event_data,
encoding_scheme=encoding_scheme,
num_features=num_features
)
vocab.save_vocab(out_vocab_file_path)
print(f"Vocab file saved at {out_vocab_file_path}")
if __name__ == "__main__":
main()
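# Example invocation (assuming events were generated by the previous step with matching arguments):
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# the resulting file is written to ../vocab/vocab_SOD/vocab_SOD_nb5.json by default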

View File

@ -0,0 +1,127 @@
import argparse
import time
from pathlib import Path
import numpy as np
import pickle
from tqdm import tqdm
import vocab_utils
class Event2tuneidx():
def __init__(
self,
dataset: str,
encoding_scheme: str,
num_features: int,
in_dir: Path,
out_dir: Path,
debug: bool
):
self.dataset = dataset
self.encoding_scheme = encoding_scheme
self.encoding_name = encoding_scheme + str(num_features)
self.in_dir = in_dir / f"events_{self.dataset}" / self.encoding_name
self.out_dir = out_dir / f"tuneidx_{self.dataset}" / self.encoding_name
self.debug = debug
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
selected_vocab_name = vocab_name[encoding_scheme]
in_vocab_file_path = Path(f"../vocab/vocab_{dataset}/vocab_{dataset}_{encoding_scheme}{num_features}.json")
self.vocab = getattr(vocab_utils, selected_vocab_name)(in_vocab_file_path=in_vocab_file_path, event_data=None,
encoding_scheme=encoding_scheme, num_features=num_features)
def _convert_event_to_tune_in_idx(self, tune_in_event):
tune_in_idx = []
for event in tune_in_event:
event_in_idx = self.vocab(event)
if event_in_idx != None:
tune_in_idx.append(event_in_idx)
return tune_in_idx
def _load_single_event_and_make_tune_in_idx(self, file_path):
with open(file_path, 'rb') as f:
tune_in_event = pickle.load(f)
tune_in_idx = self._convert_event_to_tune_in_idx(tune_in_event)
return file_path.name, tune_in_idx
def make_tune_in_idx(self):
print("preprocessing events data to tune_in_idx data")
# check output directory exists
self.out_dir.mkdir(parents=True, exist_ok=True)
start_time = time.time()
event_list = sorted(list(self.in_dir.rglob("*.pkl")))
if event_list == []:
event_list = sorted(list(self.in_dir.glob("*.pkli")))
for filepath_name, tune_in_idx in tqdm(map(self._load_single_event_and_make_tune_in_idx, event_list), total=len(event_list)):
# save tune_in_idx as a compressed npz file; remi is stored as int16 because its vocabulary has more than 256 tokens, other schemes are downcast to uint8 when all indices fit in one byte
if self.encoding_scheme == 'remi':
tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
else:
tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
if np.max(tune_in_idx) < 256:
tune_in_idx = np.array(tune_in_idx, dtype=np.uint8)
if filepath_name.endswith('.pkli'):
file_name = filepath_name.replace('.pkli', '.npz')
else:
file_name = filepath_name.replace('.pkl', '.npz')
np.savez_compressed(self.out_dir / file_name, tune_in_idx)
del tune_in_idx
print(f"taken time for making tune_in_idx is {time.time()-start_time}")
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-e",
"--encoding",
required=True,
choices=("remi", "cp", "nb"),
type=str,
help="encoding scheme",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/represented_data/events/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/tuneidx/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
parser = get_argument_parser()
args = parser.parse_args()
event2tuneidx = Event2tuneidx(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug)
event2tuneidx.make_tune_in_idx()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,122 @@
import argparse
import time
from pathlib import Path
import numpy as np
import pickle
from tqdm import tqdm
import vocab_utils
class Event2tuneidx():
def __init__(
self,
dataset: str,
encoding_scheme: str,
num_features: int,
in_dir: Path,
out_dir: Path,
debug: bool
):
self.dataset = dataset
self.encoding_scheme = encoding_scheme
self.encoding_name = encoding_scheme + str(num_features)
self.in_dir = in_dir / f"events_{self.dataset}" / self.encoding_name
self.out_dir = out_dir / f"tuneidx_{self.dataset}" / self.encoding_name
self.debug = debug
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
selected_vocab_name = vocab_name[encoding_scheme]
in_vocab_file_path = Path(f"../vocab/vocab_{dataset}/vocab_{dataset}_{encoding_scheme}{num_features}.json")
self.vocab = getattr(vocab_utils, selected_vocab_name)(in_vocab_file_path=in_vocab_file_path, event_data=None,
encoding_scheme=encoding_scheme, num_features=num_features)
def _convert_event_to_tune_in_idx(self, tune_in_event):
tune_in_idx = []
for event in tune_in_event:
event_in_idx = self.vocab(event)
if event_in_idx != None:
tune_in_idx.append(event_in_idx)
return tune_in_idx
def _load_single_event_and_make_tune_in_idx(self, file_path):
with open(file_path, 'rb') as f:
tune_in_event = pickle.load(f)
tune_in_idx = self._convert_event_to_tune_in_idx(tune_in_event)
return file_path.name, tune_in_idx
def make_tune_in_idx(self):
print("preprocessing events data to tune_in_idx data")
# check output directory exists
self.out_dir.mkdir(parents=True, exist_ok=True)
start_time = time.time()
event_list = sorted(list(self.in_dir.rglob("*.pkl")))
for filepath_name, tune_in_idx in tqdm(map(self._load_single_event_and_make_tune_in_idx, event_list), total=len(event_list)):
# save tune_in_idx as a compressed npz file; remi is stored as int16 because its vocabulary has more than 256 tokens, other schemes are downcast to uint8 when all indices fit in one byte
if self.encoding_scheme == 'remi':
tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
else:
tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
if np.max(tune_in_idx) < 256:
tune_in_idx = np.array(tune_in_idx, dtype=np.uint8)
file_name = filepath_name.replace('.pkl', '.npz')
np.savez_compressed(self.out_dir / file_name, tune_in_idx)
del tune_in_idx
print(f"taken time for making tune_in_idx is {time.time()-start_time}")
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--dataset",
required=True,
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
type=str,
help="dataset names",
)
parser.add_argument(
"-e",
"--encoding",
required=True,
choices=("remi", "cp", "nb"),
type=str,
help="encoding scheme",
)
parser.add_argument(
"-f",
"--num_features",
required=True,
choices=(4, 5, 7, 8),
type=int,
help="number of features",
)
parser.add_argument(
"-i",
"--in_dir",
default="../dataset/represented_data/events/",
type=Path,
help="input data directory",
)
parser.add_argument(
"-o",
"--out_dir",
default="../dataset/represented_data/tuneidx_withcaption/",
type=Path,
help="output data directory",
)
parser.add_argument(
"--debug",
action="store_true",
help="enable debug mode",
)
return parser
def main():
parser = get_argument_parser()
args = parser.parse_args()
event2tuneidx = Event2tuneidx(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug)
event2tuneidx.make_tune_in_idx()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,395 @@
import pickle
from pathlib import Path
from typing import Union
from multiprocessing import Pool, cpu_count
from collections import defaultdict
from fractions import Fraction
import torch
import json
from tqdm import tqdm
def sort_key(s):
fraction_part = s.split('_')[-1]
numerator, denominator = map(int, fraction_part.split('/'))
# Return a tuple with denominator first, then numerator, both negated so the sort is in descending order
return (-denominator, -numerator)
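# Example (hypothetical token name): sort_key('TimeSignature_6/8') returns (-8, -6), so sorting with this
# key lists larger denominators first and, within a denominator, larger numerators first.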
class LangTokenVocab:
def __init__(
self,
in_vocab_file_path:Union[Path, None],
event_data: list,
encoding_scheme: str,
num_features: int
):
'''
Initializes the LangTokenVocab class.
Args:
in_vocab_file_path (Union[Path, None]): Path to the pre-made vocabulary file (optional).
event_data (list): List of event data to create a vocabulary if no pre-made vocab is provided.
encoding_scheme (str): Encoding scheme to be used (e.g., 'remi', 'cp', 'nb').
num_features (int): Number of features to be used (e.g., 4, 5, 7, 8).
Summary:
This class is responsible for handling vocabularies used in language models, especially for REMI encoding.
It supports multiple encoding schemes, creates vocabularies based on event data, handles special tokens (e.g.,
start/end of sequence), and manages feature-specific masks. It provides methods for saving, loading, and decoding
vocabularies. It also supports vocabulary augmentation for pitch, instrument, beat, and chord features, ensuring
that these are arranged and ordered appropriately.
For all encoding schemes, the metric or special tokens are named as 'type',
so that we can easily handle and compare among different encoding schemes.
'''
self.encoding_scheme = encoding_scheme
self.num_features = num_features
self._prepare_in_vocab(in_vocab_file_path, event_data) # Prepares initial vocab based on the input file or event data
self._get_features() # Extracts relevant features based on the num_features
self.idx2event, self.event2idx = self._get_vocab(event_data, unique_vocabs=self.idx2event) # Creates vocab or loads premade vocab
if self.encoding_scheme == 'remi':
self._make_mask() # Generates masks for 'remi' encoding scheme
self._get_sos_eos_token() # Retrieves special tokens (Start of Sequence, End of Sequence)
# Prepares vocabulary if a pre-made vocab file exists or handles cases with no input file.
def _prepare_in_vocab(self, in_vocab_file_path, event_data):
if in_vocab_file_path is not None:
with open(in_vocab_file_path, 'r') as f:
idx2event_temp = json.load(f)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
for key in idx2event_temp.keys():
idx2event_temp[key] = {int(idx):tok for idx, tok in idx2event_temp[key].items()}
elif self.encoding_scheme == 'remi':
idx2event_temp = {int(idx):tok for idx, tok in idx2event_temp.items()}
self.idx2event = idx2event_temp
elif in_vocab_file_path is None and event_data is None:
raise NotImplementedError('either premade vocab or event_data should be given')
else:
self.idx2event = None
# Extracts features depending on the number of features chosen (4, 5, 7, 8).
def _get_features(self):
feature_args = {
4: ["type", "beat", "pitch", "duration"],
5: ["type", "beat", "instrument", "pitch", "duration"],
7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
self.feature_list = feature_args[self.num_features]
# Saves the current vocabulary to a specified JSON path.
def save_vocab(self, json_path):
with open(json_path, 'w') as f:
json.dump(self.idx2event, f, indent=2, ensure_ascii=False)
# Returns the size of the current vocabulary.
def get_vocab_size(self):
return len(self.idx2event)
# Handles Start of Sequence (SOS) and End of Sequence (EOS) tokens based on the encoding scheme.
def _get_sos_eos_token(self):
if self.encoding_scheme == 'remi':
self.sos_token = [self.event2idx['SOS_None']]
self.eos_token = [[self.event2idx['EOS_None']]]
else:
self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]
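# Shape note (derived from the code above): for 'remi' self.sos_token is a flat [index] while
# self.eos_token is wrapped once more as [[index]]; for 'cp'/'nb' both are full sub-token rows, e.g. with
# num_features=5 the SOS token is [[<idx of 'SOS' in the type vocab>, 0, 0, 0, 0]].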
# Generates vocabularies by either loading from a file or creating them based on the event data.
def _get_vocab(self, event_data, unique_vocabs=None):
# make new vocab from given event_data
if event_data is not None:
unique_char_list = set()
for tune_path in event_data:
    with open(tune_path, 'rb') as f:
        unique_char_list.update(f'{event["name"]}_{event["value"]}' for event in pickle.load(f))
unique_vocabs = sorted(unique_char_list)
unique_vocabs.remove('SOS_None')
unique_vocabs.remove('EOS_None')
unique_vocabs.remove('Bar_None')
new_unique_vocab = self._augment_pitch_vocab(unique_vocabs)
if self.num_features == 5 or self.num_features == 8:
new_unique_vocab = self._arange_instrument_vocab(new_unique_vocab)
if self.num_features == 7 or self.num_features == 8:
new_unique_vocab = self._arange_chord_vocab(new_unique_vocab)
new_unique_vocab = self._arange_beat_vocab(new_unique_vocab)
new_unique_vocab.insert(0, 'SOS_None')
new_unique_vocab.insert(1, 'EOS_None')
new_unique_vocab.insert(2, 'Bar_None')
idx2event = {int(idx) : tok for idx, tok in enumerate(new_unique_vocab)}
event2idx = {tok : int(idx) for idx, tok in idx2event.items()}
# load premade vocab
else:
idx2event = unique_vocabs
event2idx = {tok : int(idx) for idx, tok in unique_vocabs.items()}
return idx2event, event2idx
# Augments the pitch vocabulary by expanding the range of pitch values.
def _augment_pitch_vocab(self, unique_vocabs):
pitch_vocab = [x for x in unique_vocabs if 'Note_Pitch_' in x]
pitch_int = [int(x.replace('Note_Pitch_', '')) for x in pitch_vocab if x.replace('Note_Pitch_', '').isdigit()]
min_pitch = min(pitch_int)
max_pitch = max(pitch_int)
min_pitch_margin = max(min_pitch-6, 0)
max_pitch_margin = min(max_pitch+7, 127)
new_pitch_vocab = [f'Note_Pitch_{x}' for x in range(min_pitch_margin, max_pitch_margin + 1)]  # range() already yields ascending pitch order
new_unique_vocab = [x for x in unique_vocabs if x not in new_pitch_vocab] + new_pitch_vocab
return new_unique_vocab
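# Illustrative numbers: if the corpus only contains pitches 36-96, the vocabulary is widened
# to Note_Pitch_30 ... Note_Pitch_103, clamped to the MIDI range 0-127. The extra margin
# presumably leaves headroom for pitch-transposition augmentation without falling outside
# the vocabulary.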
# Orders and arranges the instrument vocabulary.
def _arange_instrument_vocab(self, unique_vocabs):
instrument_vocab = [x for x in unique_vocabs if 'Instrument_' in x]
new_instrument_vocab = sorted(instrument_vocab, key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
new_unique_vocab = [x for x in unique_vocabs if x not in new_instrument_vocab] + new_instrument_vocab
return new_unique_vocab
# Orders and arranges the chord vocabulary, ensuring 'Chord_N_N' is the last token.
def _arange_chord_vocab(self, unique_vocabs):
'''
for chord augmentation
Chord_N_N should be the last token in the list for an easy implementation of chord augmentation
'''
chord_vocab = [x for x in unique_vocabs if 'Chord_' in x]
chord_vocab.remove('Chord_N_N')
new_chord_vocab = sorted(chord_vocab, key=lambda x: (not isinstance(x, int), x.split('_')[-1] if isinstance(x, str) else x, x.split('_')[1] if isinstance(x, str) else x))
new_chord_vocab.append('Chord_N_N')
new_unique_vocab = [x for x in unique_vocabs if x not in new_chord_vocab] + new_chord_vocab
return new_unique_vocab
# Orders and arranges the beat vocabulary.
def _arange_beat_vocab(self, unique_vocabs):
beat_vocab = [x for x in unique_vocabs if 'Beat_' in x]
new_beat_vocab = sorted(beat_vocab, key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
count = 0
for idx, token in enumerate(unique_vocabs):
if 'Beat_' in token:
unique_vocabs[idx] = new_beat_vocab[count]
count += 1
return unique_vocabs
# Generates masks for the 'remi' encoding scheme.
def _make_mask(self):
'''
Builds a per-feature 0/1 mask over the vocabulary so that, during validation, predictions can be restricted to (or evaluated on) a single musical feature such as pitch.
'''
idx2feature = {}
for idx, feature in self.idx2event.items():
if feature.startswith('SOS') or feature.startswith('EOS') or feature.startswith('Bar'):
idx2feature[idx] = 'type'
elif feature.startswith('Beat'):
idx2feature[idx] = 'beat'
elif feature.startswith('Chord'):
idx2feature[idx] = 'chord'
elif feature.startswith('Tempo'):
idx2feature[idx] = 'tempo'
elif feature.startswith('Note_Pitch'):
idx2feature[idx] = 'pitch'
elif feature.startswith('Note_Duration'):
idx2feature[idx] = 'duration'
elif feature.startswith('Note_Velocity'):
idx2feature[idx] = 'velocity'
elif feature.startswith('Instrument'):
idx2feature[idx] = 'instrument'
self.total_mask = {}
self.remi_vocab_boundaries_by_key = {}
for target in self.feature_list:
mask = [0] * len(idx2feature) # Initialize all-zero list of length equal to dictionary
for key, value in idx2feature.items():
if value == target:
mask[int(key)] = 1 # If value equals target, set corresponding position in mask to 1
mask = torch.LongTensor(mask)
self.total_mask[target] = mask
nonzero_indices = torch.argwhere(mask == 1).flatten().tolist()
self.remi_vocab_boundaries_by_key[target] = (nonzero_indices[0], nonzero_indices[-1] + 1)
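# Illustrative use of the masks built above (hypothetical variable names): given `logits` of
# shape (vocab_size,), predictions can be restricted to a single feature, e.g.
#   pitch_logits = logits.masked_fill(self.total_mask['pitch'] == 0, float('-inf'))
# or, assuming each feature occupies a contiguous index block (which the vocab arrangement
# above aims to guarantee), sliced directly via
#   start, end = self.remi_vocab_boundaries_by_key['pitch']
#   pitch_logits = logits[start:end]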
def decode(self, events:torch.Tensor):
'''
Used to inspect decoded events during evaluation.
events: 1d tensor
'''
decoded_list = []
for event in events:
decoded_list.append(self.idx2event[event.item()])
return decoded_list
def __call__(self, word):
'''
For REMI-style encoding: maps a single event dict to its vocabulary index.
'''
return self.event2idx[f"{word['name']}_{word['value']}"]
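# Minimal usage sketch (illustrative only, not called by the pipeline). It assumes a pre-made
# REMI vocabulary file such as 'remi_vocab.json' (hypothetical path) that contains the usual
# special tokens and a 'Note_Pitch_60' entry.
def _example_remi_vocab_usage(vocab_path='remi_vocab.json'):
    vocab = LangTokenVocab(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme='remi',
        num_features=5,
    )
    # Encode a single corpus-style event dict into its vocabulary index.
    token_idx = vocab({'name': 'Note_Pitch', 'value': 60})
    # Decode a 1-D tensor of indices back into readable event strings.
    decoded = vocab.decode(torch.LongTensor([token_idx]))
    return token_idx, decoded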
class MusicTokenVocabCP(LangTokenVocab):
def __init__(
self,
in_vocab_file_path:Union[Path, None],
event_data: list,
encoding_scheme: str,
num_features: int
):
# Initialize the vocabulary class with vocab file path, event data, encoding scheme, and feature count
super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
def _augment_pitch_vocab(self, unique_vocabs):
# Extract pitch-related vocabularies and adjust pitch range
pitch_total_vocab = unique_vocabs['pitch']
pitch_vocab = [x for x in pitch_total_vocab if 'Note_Pitch_' in str(x)]
pitch_int = [int(x.replace('Note_Pitch_', '')) for x in pitch_vocab if x.replace('Note_Pitch_', '').isdigit()]
# Determine the min and max pitch values and extend the pitch range slightly
min_pitch = min(pitch_int)
max_pitch = max(pitch_int)
min_pitch_margin = max(min_pitch - 6, 0)
max_pitch_margin = min(max_pitch + 7, 127)
# Create new pitch vocab and ensure new entries do not overlap with existing ones
new_pitch_vocab = [f'Note_Pitch_{x}' for x in range(min_pitch_margin, max_pitch_margin + 1)]
new_pitch_vocab = [x for x in pitch_total_vocab if str(x) not in new_pitch_vocab] + new_pitch_vocab
unique_vocabs['pitch'] = new_pitch_vocab
return unique_vocabs
def _mp_get_unique_vocab(self, tune, features):
# Read event data from a file and collect unique vocabularies for specified features
with open(tune, 'rb') as f:
events_list = pickle.load(f)
unique_vocabs = defaultdict(set)
for event in events_list:
for key in features:
unique_vocabs[key].add(event[key])
return unique_vocabs
def _get_chord_vocab(self):
'''
Manually define the chord vocabulary by combining roots and qualities
from a predefined list. This is used for chord augmentation.
'''
root_list = ['A', 'A#', 'B', 'C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#']
quality_list = ['+', '/o7', '7', 'M', 'M7', 'm', 'm7', 'o', 'o7', 'sus2', 'sus4']
chord_vocab = [f'Chord_{root}_{quality}' for root in root_list for quality in quality_list]
# Sort the chord vocabulary by quality, then by root
chord_vocab = sorted(chord_vocab, key=lambda x: (x.split('_')[-1], x.split('_')[1]))
return chord_vocab
def _cp_sort_type(self, unique_vocabs):
# Similar to _nb_sort_type but used for the 'cp' encoding scheme, sorting vocabularies in a different order
unique_vocabs.remove('SOS')
unique_vocabs.remove('EOS')
unique_vocabs.remove('Metrical')
unique_vocabs.remove('Note')
vocab_list = list(unique_vocabs)
unique_vocabs = sorted(vocab_list, key=sort_key)
unique_vocabs.insert(0, 'SOS')
unique_vocabs.insert(1, 'EOS')
unique_vocabs.insert(2, 'Metrical')
unique_vocabs.insert(3, 'Note')
return unique_vocabs
# Custom sort key for the 'cp' beat vocabulary: 0 first, then 'Bar' tokens, then 'Beat_x' tokens ordered by beat number
def sort_type_cp(self, item):
if item == 0:
return (0, 0) # Move 0 to the beginning
elif isinstance(item, str):
if item.startswith("Bar"):
return (1, item) # "Bar" items come next, sorted lexicographically
elif item.startswith("Beat"):
# Extract numeric part of "Beat_x" to sort numerically
beat_number = int(item.split('_')[1])
return (2, beat_number) # "Beat" items come last, sorted by number
return (3, item) # Catch-all for anything unexpected (shouldn't be necessary here)
def _get_vocab(self, event_data, unique_vocabs=None):
if event_data is not None:
# Create vocab mappings (event2idx, idx2event) from the provided event data
print('Collecting unique vocab from event data...')
event2idx = {}
idx2event = {}
unique_vocabs = defaultdict(set)
# Use multiprocessing to extract unique vocabularies for each event
with Pool(16) as p:
results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
# Combine results from different processes
for result in results:
for key in self.feature_list:
if key == 'chord': # Chords are handled separately
continue
unique_vocabs[key].update(result[key])
# Augment pitch vocab and add manually defined chord vocab
unique_vocabs = self._augment_pitch_vocab(unique_vocabs)
unique_vocabs['chord'] = self._get_chord_vocab()
# Process each feature type, handling special cases like 'tempo' and 'chord'
for key in self.feature_list:
if key == 'tempo':
remove_nn_flag = False
if 'Tempo_N_N' in unique_vocabs[key]:
unique_vocabs[key].remove('Tempo_N_N')
remove_nn_flag = True
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
if remove_nn_flag:
unique_vocabs[key].insert(1, 'Tempo_N_N')
elif key == 'chord':
unique_vocabs[key].insert(0, 0)
unique_vocabs[key].insert(1, 'Chord_N_N')
elif key == 'type': # Sort 'type' vocab depending on the encoding scheme
if self.encoding_scheme == 'cp':
unique_vocabs[key] = self._cp_sort_type(unique_vocabs[key])
else: # NB encoding scheme
unique_vocabs[key] = self._nb_sort_type(unique_vocabs[key])
elif key == 'beat' and self.encoding_scheme == 'cp': # Handle 'beat' vocab with 'cp' scheme
unique_vocabs[key] = sorted(unique_vocabs[key], key = self.sort_type_cp)
elif key == 'beat' and self.encoding_scheme == 'nb': # Handle 'beat' vocab with 'nb' scheme
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
elif key == 'instrument': # Sort 'instrument' vocab by integer values
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
else: # Default case: sort by integer values for other keys
unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
# Create event2idx and idx2event mappings for each feature
event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
return idx2event, event2idx
else:
# If no event data, simply map unique vocab to indexes
event2idx = {}
for key in self.feature_list:
event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
return unique_vocabs, event2idx
def get_vocab_size(self):
# Return the size of the vocabulary for each feature
return {key: len(self.idx2event[key]) for key in self.feature_list}
def __call__(self, event):
# Convert an event to its corresponding indices
return [self.event2idx[key][event[key]] for key in self.feature_list]
def decode(self, events:torch.Tensor):
decoded_list = []
for event in events:
decoded_list.append([self.idx2event[key][event[idx].item()] for idx, key in enumerate(self.feature_list)])
return decoded_list
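# Illustrative round trip for the compound ('cp'/'nb') vocabularies, with hypothetical token
# values: given vocab = MusicTokenVocabCP(Path('cp_vocab.json'), None, 'cp', 7), calling
#   vocab({'type': 'Note', 'beat': 'Beat_0', 'chord': 'Chord_C_M', 'tempo': 'Tempo_120',
#          'pitch': 'Note_Pitch_60', 'duration': 'Note_Duration_4', 'velocity': 'Note_Velocity_60'})
# returns one index per feature (a list of 7 ints), and decode() maps a (length, 7) tensor of
# such rows back to the readable token strings, feature by feature.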
class MusicTokenVocabNB(MusicTokenVocabCP):
def __init__(
self,
in_vocab_file_path:Union[Path, None],
event_data: list,
encoding_scheme: str,
num_features: int
):
super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)
def _nb_sort_type(self, unique_vocabs):
# Remove special tokens and sort the remaining vocab list, then re-insert the special tokens in order
unique_vocabs.remove('SOS')
unique_vocabs.remove('EOS')
unique_vocabs.remove('Empty_Bar')
unique_vocabs.remove('SSS')
unique_vocabs.remove('SSN')
unique_vocabs.remove('SNN')
vocab_list = list(unique_vocabs)
unique_vocabs = sorted(vocab_list, key=sort_key)
unique_vocabs.insert(0, 'SOS')
unique_vocabs.insert(1, 'EOS')
unique_vocabs.insert(2, 'Empty_Bar')
unique_vocabs.insert(3, 'SSS')
unique_vocabs.insert(4, 'SSN')
unique_vocabs.insert(5, 'SNN')
return unique_vocabs