first commit

data_representation/README.md (new file, 81 lines)
@@ -0,0 +1,81 @@
# Dataset Download

Our model supports four different datasets:

- **Symbolic Orchestral Database (SOD)**: [Link](https://qsdfo.github.io/LOP/database.html)
- **Lakh MIDI Dataset (Clean version)**: [Link](https://colinraffel.com/projects/lmd/)
- **Pop1k7**: [Link](https://github.com/YatingMusic/compound-word-transformer)
- **Pop909**: [Link](https://github.com/music-x-lab/POP909-Dataset)

### Download Instructions

You can download the SOD and LakhClean datasets from the command line:

```sh
# SOD
wget https://qsdfo.github.io/LOP/database/SOD.zip

# LakhClean
wget http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz
```

For Pop1k7, the official repository link is currently unavailable; however, you can download the dataset from this Google Drive link: [Download Pop1k7](https://drive.google.com/file/d/1GnbELjE-kQ4WOkBmZ3XapFKIaltySRyV/view?usp=drive_link)

For Pop909, the dataset is available in the official GitHub repository: [Repository link](https://github.com/music-x-lab/POP909-Dataset)
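
The archives can then be unpacked into the MIDI dataset folder. Below is a minimal sketch using Python's standard library; the `../dataset/MIDI_dataset/<DatasetName>` layout and the working directory (`data_representation`) are assumptions based on the folder conventions described later in this README.

```python
# Minimal sketch: unpack the downloaded archives into the MIDI dataset folder.
# The per-dataset subfolder layout is an assumption, not something the scripts enforce.
import tarfile
import zipfile
from pathlib import Path

midi_root = Path("../dataset/MIDI_dataset")
midi_root.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile("SOD.zip") as zf:
    zf.extractall(midi_root / "SOD")

with tarfile.open("clean_midi.tar.gz", "r:gz") as tf:
    tf.extractall(midi_root / "LakhClean")
```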

### Using Your Own Dataset

If you plan to use your own dataset, you can modify the dataset class in the `data_utils.py` script under the `symbolic_encoding` folder inside the `nested_music_transformer` folder. Alternatively, for a simpler approach, rename your dataset to match one of the following options (see the sketch after this list):

- **SOD**: Use this for score-based MIDI datasets that require finer-grained quantization (supports up to 16th-note-triplet quantization; 24 samples per quarter note).
- **LakhClean**: Suitable for score-based MIDI datasets requiring coarse-grained quantization (supports up to 16th-note quantization; 4 samples per quarter note).
- **Pop1k7, Pop909**: Ideal for expressive MIDI datasets requiring coarse-grained quantization (supports up to 16th-note quantization; 4 samples per quarter note).
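
For example, to reuse the SOD settings for your own score-based MIDI files, you could simply copy them into the corresponding input folder. A minimal sketch, assuming the `<nmt/dataset/MIDI_dataset>/<DatasetName>` layout used in step 1 below; the source path is hypothetical.

```python
# Minimal sketch: treat your own MIDI files as the "SOD" dataset instead of
# modifying data_utils.py. The source path is hypothetical; adjust as needed.
import shutil
from pathlib import Path

my_midis = Path("/path/to/my_midi_files")       # hypothetical source folder
target = Path("../dataset/MIDI_dataset/SOD")    # reuse an existing dataset name
target.mkdir(parents=True, exist_ok=True)

for midi_path in my_midis.rglob("*.mid"):       # add "*.midi" if your files use that extension
    shutil.copy(midi_path, target / midi_path.name)
```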

# Data Representation

<p align="center">
<img src="figure/Data_Representation_Pipeline.png" width="1000">
</p>

This document outlines our standard data processing pipeline. By following the instructions and running the corresponding Python scripts, you can generate a data representation suited to your specific needs.

We focus on symbolic music and limit the musical features to a select few; each feature-set size corresponds to a specific set of musical attributes. Based on our experiments, we use **7 features** for the *Pop1k7* and *Pop909* datasets, which consist of pop piano music and require velocity for expression, and **5 features** for the *Symbolic Orchestral Database (SOD)*, *Lakh MIDI*, and *SymphonyMIDI* datasets. The available feature sets are listed below (an illustrative event layout follows the list):

- **4 features**: `["type", "beat", "pitch", "duration"]`
- **5 features**: `["type", "beat", "instrument", "pitch", "duration"]`
- **7 features**: `["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"]`
- **8 features**: `["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]`
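
To make the mapping concrete, here is an illustrative example of how these feature lists translate into per-event fields in the compound-token encodings (shown in the style of the NB encoding described below). The token strings are examples only; the exact values are produced by the encoding scripts in `data_representation/`.

```python
# Illustrative only: one note event under the 5-feature and 7-feature settings.
note_event_5_features = {
    "type": "SNN",                  # metric sub-token: same time signature, new bar, new beat
    "beat": "Beat_0",
    "instrument": "Instrument_0",
    "pitch": "Note_Pitch_60",
    "duration": "Note_Duration_12",
}

note_event_7_features = {
    "type": "SSS",                  # same time signature, same bar, same beat
    "beat": "Beat_4",
    "chord": "Chord_C_M",
    "tempo": "Tempo_120",
    "pitch": "Note_Pitch_64",
    "duration": "Note_Duration_12",
    "velocity": "Note_Velocity_80",
}
```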

## Parse Arguments

The scripts in this folder accept the following command-line arguments (a sketch of the corresponding `argparse` setup follows the list):

- `-d`, `--dataset`: Required. Specifies the dataset to use. One of `"BachChorale"`, `"Pop1k7"`, `"Pop909"`, `"SOD"`, `"LakhClean"`, or `"SymphonyMIDI"`.
- `-e`, `--encoding`: Required. Specifies the encoding scheme. One of `"remi"`, `"cp"`, `"nb"`, or `"remi_pos"`.
- `-f`, `--num_features`: Required. Specifies the number of features. One of `4`, `5`, `7`, or `8`.
- `-i`, `--in_dir`: Optional. Input data directory; defaults to `../dataset/represented_data/corpus/`.
- `-o`, `--out_dir`: Optional. Output data directory; defaults to `../dataset/represented_data/events/`.
- `--debug`: Flag that enables debug mode; no additional value is needed.
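
A sketch of the corresponding `argparse` setup, reconstructed from the descriptions above rather than copied from the repository's scripts (individual step scripts may use only a subset of these arguments):

```python
# Reconstructed from the argument descriptions above; not the repository's actual parser.
import argparse

parser = argparse.ArgumentParser(description="Symbolic music data representation pipeline")
parser.add_argument("-d", "--dataset", required=True,
                    choices=["BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"])
parser.add_argument("-e", "--encoding", required=True,
                    choices=["remi", "cp", "nb", "remi_pos"])
parser.add_argument("-f", "--num_features", required=True, type=int,
                    choices=[4, 5, 7, 8])
parser.add_argument("-i", "--in_dir", default="../dataset/represented_data/corpus/")
parser.add_argument("-o", "--out_dir", default="../dataset/represented_data/events/")
parser.add_argument("--debug", action="store_true")

args = parser.parse_args()
```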

## 1. MIDI to Corpus

In this step, we convert MIDI files into a set of events containing various musical information. The MIDI files should be aligned with the beat and contain accurate time signature information. Place the MIDI files in `<nmt/dataset/MIDI_dataset>` and refer to the example files provided. Navigate to the `<nmt/data_representation>` folder and run the script. The converted data will be stored in `<nmt/dataset/represented_data/corpus>`.

- Example usage: `python3 step1_midi2corpus.py --dataset SOD --num_features 5`
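
After this step you can sanity-check one converted file. The snippet below is a sketch: it assumes the corpus files are pickled (step 1 imports `pickle`), and the file name and extension are hypothetical, but the top-level keys mirror how the encoding step consumes the corpus.

```python
# Sketch: inspect one corpus file produced by step 1 (file path is hypothetical).
import pickle

with open("../dataset/represented_data/corpus/SOD/example.pkl", "rb") as f:
    song_data = pickle.load(f)

print(song_data["metadata"]["ticks_per_beat"])      # timing resolution of the source MIDI
print(song_data["metadata"]["time_signature"])      # list of time signature changes
print(sorted(song_data["notes"].keys()))            # instrument (program) indices
```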

## 2. Corpus to Event

We provide three types of representations: **REMI**, **Compound Word (CP)**, and **Note-based Encoding (NB)**. The converted data will be stored in `<nmt/dataset/represented_data/events>`.

- Example usage: `python3 step2_corpus2event.py --dataset SOD --num_features 5 --encoding nb`
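
To give a feel for the difference between the schemes, here is an illustrative (not exhaustive) comparison of the event structures built by `encoding_utils.py`: REMI produces a flat stream of small `name`/`value` tokens, while CP and NB group several sub-tokens into one compound event. The concrete values below are examples only.

```python
# REMI: one token per attribute, emitted as a flat sequence.
remi_events = [
    {"name": "Bar", "value": "time_signature_4/4"},
    {"name": "Beat", "value": 0},
    {"name": "Instrument", "value": 0},
    {"name": "Note_Pitch", "value": 60},
    {"name": "Note_Duration", "value": 12},
]

# NB: one compound token per note (5-feature setting), carrying the metric context.
nb_event = {
    "type": "NNN_time_signature_4/4",   # new time signature, new bar, new beat
    "beat": "Beat_0",
    "instrument": "Instrument_0",
    "pitch": "Note_Pitch_60",
    "duration": "Note_Duration_12",
}
```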

## 3. Creating Vocabulary

This script creates a vocabulary in the `<nmt/vocab>` folder. The vocabulary contains the event-to-index mappings.

- Example usage: `python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb`
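
Conceptually, the vocabulary assigns an integer index to every distinct event (or sub-token) string and keeps the inverse mapping for decoding. A minimal sketch, assuming plain string tokens; the actual file format written to `<nmt/vocab>` may differ.

```python
# Minimal sketch of an event-to-index vocabulary (not the repository's actual format).
def build_vocab(token_sequences):
    """Assign a stable integer index to every distinct token string."""
    vocab = {}
    for tokens in token_sequences:
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)
    return vocab

event2idx = build_vocab([["SOS", "Bar", "Beat_0", "Note_Pitch_60", "EOS"]])
idx2event = {idx: token for token, idx in event2idx.items()}
print(event2idx["Note_Pitch_60"], idx2event[3])   # -> 3 Note_Pitch_60
```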

## 4. Event to Index

In this step, we convert events into indices for efficient model training. The converted data will be stored in `<nmt/dataset/represented_data/tuneidx>`.

- Example usage: `python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb`
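
The four steps can be chained for a single dataset/encoding/feature configuration. The script names and flags below are taken from the examples above; the sketch assumes it is run from `<nmt/data_representation>`.

```python
# Sketch: run the full pipeline for one configuration.
import subprocess

dataset, num_features, encoding = "SOD", "5", "nb"

subprocess.run(["python3", "step1_midi2corpus.py",
                "--dataset", dataset, "--num_features", num_features], check=True)
for script in ["step2_corpus2event.py", "step3_creating_vocab.py", "step4_event2tuneidx.py"]:
    subprocess.run(["python3", script, "--dataset", dataset,
                    "--num_features", num_features, "--encoding", encoding], check=True)
```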

data_representation/__init__.py (new file, empty)

data_representation/__pycache__/__init__.cpython-310.pyc (new binary file, not shown)

data_representation/__pycache__/constants.cpython-310.pyc (new binary file, not shown)

data_representation/__pycache__/vocab_utils.cpython-310.pyc (new binary file, not shown)

data_representation/constants.py (new file, 422 lines)
@@ -0,0 +1,422 @@
import numpy as np
|
||||
|
||||
# for chord analysis
|
||||
NUM2PITCH = {
|
||||
0: 'C',
|
||||
1: 'C#',
|
||||
2: 'D',
|
||||
3: 'D#',
|
||||
4: 'E',
|
||||
5: 'F',
|
||||
6: 'F#',
|
||||
7: 'G',
|
||||
8: 'G#',
|
||||
9: 'A',
|
||||
10: 'A#',
|
||||
11: 'B',
|
||||
}
|
||||
|
||||
# Referenced from MMT: https://github.com/salu133445/mmt
|
||||
PROGRAM_INSTRUMENT_MAP = {
|
||||
# Pianos
|
||||
0: "piano",
|
||||
1: "piano",
|
||||
2: "piano",
|
||||
3: "piano",
|
||||
4: "electric-piano",
|
||||
5: "electric-piano",
|
||||
6: "harpsichord",
|
||||
7: "clavinet",
|
||||
# Chromatic Percussion
|
||||
8: "celesta",
|
||||
9: "glockenspiel",
|
||||
10: "music-box",
|
||||
11: "vibraphone",
|
||||
12: "marimba",
|
||||
13: "xylophone",
|
||||
14: "tubular-bells",
|
||||
15: "dulcimer",
|
||||
# Organs
|
||||
16: "organ",
|
||||
17: "organ",
|
||||
18: "organ",
|
||||
19: "church-organ",
|
||||
20: "organ",
|
||||
21: "accordion",
|
||||
22: "harmonica",
|
||||
23: "bandoneon",
|
||||
# Guitars
|
||||
24: "nylon-string-guitar",
|
||||
25: "steel-string-guitar",
|
||||
26: "electric-guitar",
|
||||
27: "electric-guitar",
|
||||
28: "electric-guitar",
|
||||
29: "electric-guitar",
|
||||
30: "electric-guitar",
|
||||
31: "electric-guitar",
|
||||
# Basses
|
||||
32: "bass",
|
||||
33: "electric-bass",
|
||||
34: "electric-bass",
|
||||
35: "electric-bass",
|
||||
36: "slap-bass",
|
||||
37: "slap-bass",
|
||||
38: "synth-bass",
|
||||
39: "synth-bass",
|
||||
# Strings
|
||||
40: "violin",
|
||||
41: "viola",
|
||||
42: "cello",
|
||||
43: "contrabass",
|
||||
44: "strings",
|
||||
45: "strings",
|
||||
46: "harp",
|
||||
47: "timpani",
|
||||
# Ensemble
|
||||
48: "strings",
|
||||
49: "strings",
|
||||
50: "synth-strings",
|
||||
51: "synth-strings",
|
||||
52: "voices",
|
||||
53: "voices",
|
||||
54: "voices",
|
||||
55: "orchestra-hit",
|
||||
# Brass
|
||||
56: "trumpet",
|
||||
57: "trombone",
|
||||
58: "tuba",
|
||||
59: "trumpet",
|
||||
60: "horn",
|
||||
61: "brasses",
|
||||
62: "synth-brasses",
|
||||
63: "synth-brasses",
|
||||
# Reed
|
||||
64: "soprano-saxophone",
|
||||
65: "alto-saxophone",
|
||||
66: "tenor-saxophone",
|
||||
67: "baritone-saxophone",
|
||||
68: "oboe",
|
||||
69: "english-horn",
|
||||
70: "bassoon",
|
||||
71: "clarinet",
|
||||
# Pipe
|
||||
72: "piccolo",
|
||||
73: "flute",
|
||||
74: "recorder",
|
||||
75: "pan-flute",
|
||||
76: None,
|
||||
77: None,
|
||||
78: None,
|
||||
79: "ocarina",
|
||||
# Synth Lead
|
||||
80: "lead",
|
||||
81: "lead",
|
||||
82: "lead",
|
||||
83: "lead",
|
||||
84: "lead",
|
||||
85: "lead",
|
||||
86: "lead",
|
||||
87: "lead",
|
||||
# Synth Pad
|
||||
88: "pad",
|
||||
89: "pad",
|
||||
90: "pad",
|
||||
91: "pad",
|
||||
92: "pad",
|
||||
93: "pad",
|
||||
94: "pad",
|
||||
95: "pad",
|
||||
# Synth Effects
|
||||
96: None,
|
||||
97: None,
|
||||
98: None,
|
||||
99: None,
|
||||
100: None,
|
||||
101: None,
|
||||
102: None,
|
||||
103: None,
|
||||
# Ethnic
|
||||
104: "sitar",
|
||||
105: "banjo",
|
||||
106: "shamisen",
|
||||
107: "koto",
|
||||
108: "kalimba",
|
||||
109: "bag-pipe",
|
||||
110: "violin",
|
||||
111: "shehnai",
|
||||
# Percussive
|
||||
112: None,
|
||||
113: None,
|
||||
114: "steel-drums",
|
||||
115: None,
|
||||
116: None,
|
||||
117: "melodic-tom",
|
||||
118: "synth-drums",
|
||||
119: "synth-drums",
|
||||
# Sound effects
|
||||
120: None,
|
||||
121: None,
|
||||
122: None,
|
||||
123: None,
|
||||
124: None,
|
||||
125: None,
|
||||
126: None,
|
||||
127: None,
|
||||
}
|
||||
|
||||
# Referenced from MMT: https://github.com/salu133445/mmt
|
||||
INSTRUMENT_PROGRAM_MAP = {
|
||||
# Pianos
|
||||
"piano": 0,
|
||||
"electric-piano": 4,
|
||||
"harpsichord": 6,
|
||||
"clavinet": 7,
|
||||
# Chromatic Percussion
|
||||
"celesta": 8,
|
||||
"glockenspiel": 9,
|
||||
"music-box": 10,
|
||||
"vibraphone": 11,
|
||||
"marimba": 12,
|
||||
"xylophone": 13,
|
||||
"tubular-bells": 14,
|
||||
"dulcimer": 15,
|
||||
# Organs
|
||||
"organ": 16,
|
||||
"church-organ": 19,
|
||||
"accordion": 21,
|
||||
"harmonica": 22,
|
||||
"bandoneon": 23,
|
||||
# Guitars
|
||||
"nylon-string-guitar": 24,
|
||||
"steel-string-guitar": 25,
|
||||
"electric-guitar": 26,
|
||||
# Basses
|
||||
"bass": 32,
|
||||
"electric-bass": 33,
|
||||
"slap-bass": 36,
|
||||
"synth-bass": 38,
|
||||
# Strings
|
||||
"violin": 40,
|
||||
"viola": 41,
|
||||
"cello": 42,
|
||||
"contrabass": 43,
|
||||
"harp": 46,
|
||||
"timpani": 47,
|
||||
# Ensemble
|
||||
"strings": 49,
|
||||
"synth-strings": 50,
|
||||
"voices": 52,
|
||||
"orchestra-hit": 55,
|
||||
# Brass
|
||||
"trumpet": 56,
|
||||
"trombone": 57,
|
||||
"tuba": 58,
|
||||
"horn": 60,
|
||||
"brasses": 61,
|
||||
"synth-brasses": 62,
|
||||
# Reed
|
||||
"soprano-saxophone": 64,
|
||||
"alto-saxophone": 65,
|
||||
"tenor-saxophone": 66,
|
||||
"baritone-saxophone": 67,
|
||||
"oboe": 68,
|
||||
"english-horn": 69,
|
||||
"bassoon": 70,
|
||||
"clarinet": 71,
|
||||
# Pipe
|
||||
"piccolo": 72,
|
||||
"flute": 73,
|
||||
"recorder": 74,
|
||||
"pan-flute": 75,
|
||||
"ocarina": 79,
|
||||
# Synth Lead
|
||||
"lead": 80,
|
||||
# Synth Pad
|
||||
"pad": 88,
|
||||
# Ethnic
|
||||
"sitar": 104,
|
||||
"banjo": 105,
|
||||
"shamisen": 106,
|
||||
"koto": 107,
|
||||
"kalimba": 108,
|
||||
"bag-pipe": 109,
|
||||
"shehnai": 111,
|
||||
# Percussive
|
||||
"steel-drums": 114,
|
||||
"melodic-tom": 117,
|
||||
"synth-drums": 118,
|
||||
}
|
||||
|
||||
FINED_PROGRAM_INSTRUMENT_MAP = {
|
||||
# Pianos
|
||||
0: "Acoustic-Grand-Piano",
|
||||
1: "Bright-Acoustic-Piano",
|
||||
2: "Electric-Grand-Piano",
|
||||
3: "Honky-Tonk-Piano",
|
||||
4: "Electric-Piano-1",
|
||||
5: "Electric-Piano-2",
|
||||
6: "Harpsichord",
|
||||
7: "Clavinet",
|
||||
|
||||
# Chromatic Percussion
|
||||
8: "Celesta",
|
||||
9: "Glockenspiel",
|
||||
10: "Music-Box",
|
||||
11: "Vibraphone",
|
||||
12: "Marimba",
|
||||
13: "Xylophone",
|
||||
14: "Tubular-Bells",
|
||||
15: "Dulcimer",
|
||||
|
||||
# Organs
|
||||
16: "Drawbar-Organ",
|
||||
17: "Percussive-Organ",
|
||||
18: "Rock-Organ",
|
||||
19: "Church-Organ",
|
||||
20: "Reed-Organ",
|
||||
21: "Accordion",
|
||||
22: "Harmonica",
|
||||
23: "Tango-Accordion",
|
||||
|
||||
# Guitars
|
||||
24: "Acoustic-Guitar-nylon",
|
||||
25: "Acoustic-Guitar-steel",
|
||||
26: "Electric-Guitar-jazz",
|
||||
27: "Electric-Guitar-clean",
|
||||
28: "Electric-Guitar-muted",
|
||||
29: "Overdriven-Guitar",
|
||||
30: "Distortion-Guitar",
|
||||
31: "Guitar-harmonics",
|
||||
|
||||
# Basses
|
||||
32: "Acoustic-Bass",
|
||||
33: "Electric-Bass-finger",
|
||||
34: "Electric-Bass-pick",
|
||||
35: "Fretless-Bass",
|
||||
36: "Slap-Bass-1",
|
||||
37: "Slap-Bass-2",
|
||||
38: "Synth-Bass-1",
|
||||
39: "Synth-Bass-2",
|
||||
|
||||
# Strings & Orchestral
|
||||
40: "Violin",
|
||||
41: "Viola",
|
||||
42: "Cello",
|
||||
43: "Contrabass",
|
||||
44: "Tremolo-Strings",
|
||||
45: "Pizzicato-Strings",
|
||||
46: "Orchestral-Harp",
|
||||
47: "Timpani",
|
||||
|
||||
# Ensemble
|
||||
48: "String-Ensemble-1",
|
||||
49: "String-Ensemble-2",
|
||||
50: "Synth-Strings-1",
|
||||
51: "Synth-Strings-2",
|
||||
52: "Choir-Aahs",
|
||||
53: "Voice-Oohs",
|
||||
54: "Synth-Voice",
|
||||
55: "Orchestra-Hit",
|
||||
|
||||
# Brass
|
||||
56: "Trumpet",
|
||||
57: "Trombone",
|
||||
58: "Tuba",
|
||||
59: "Muted-Trumpet",
|
||||
60: "French-Horn",
|
||||
61: "Brass-Section",
|
||||
62: "Synth-Brass-1",
|
||||
63: "Synth-Brass-2",
|
||||
|
||||
# Reeds
|
||||
64: "Soprano-Sax",
|
||||
65: "Alto-Sax",
|
||||
66: "Tenor-Sax",
|
||||
67: "Baritone-Sax",
|
||||
68: "Oboe",
|
||||
69: "English-Horn",
|
||||
70: "Bassoon",
|
||||
71: "Clarinet",
|
||||
|
||||
# Pipes
|
||||
72: "Piccolo",
|
||||
73: "Flute",
|
||||
74: "Recorder",
|
||||
75: "Pan-Flute",
|
||||
76: "Blown-Bottle",
|
||||
77: "Shakuhachi",
|
||||
78: "Whistle",
|
||||
79: "Ocarina",
|
||||
|
||||
# Synth Lead
|
||||
80: "Lead-1-square",
|
||||
81: "Lead-2-sawtooth",
|
||||
82: "Lead-3-calliope",
|
||||
83: "Lead-4-chiff",
|
||||
84: "Lead-5-charang",
|
||||
85: "Lead-6-voice",
|
||||
86: "Lead-7-fifths",
|
||||
87: "Lead-8-bass+lead",
|
||||
|
||||
# Synth Pad
|
||||
88: "Pad-1-new-age",
|
||||
89: "Pad-2-warm",
|
||||
90: "Pad-3-polysynth",
|
||||
91: "Pad-4-choir",
|
||||
92: "Pad-5-bowed",
|
||||
93: "Pad-6-metallic",
|
||||
94: "Pad-7-halo",
|
||||
95: "Pad-8-sweep",
|
||||
|
||||
# Effects
|
||||
96: "FX-1-rain",
|
||||
97: "FX-2-soundtrack",
|
||||
98: "FX-3-crystal",
|
||||
99: "FX-4-atmosphere",
|
||||
100: "FX-5-brightness",
|
||||
101: "FX-6-goblins",
|
||||
102: "FX-7-echoes",
|
||||
103: "FX-8-sci-fi",
|
||||
|
||||
# Ethnic & Percussion
|
||||
104: "Sitar",
|
||||
105: "Banjo",
|
||||
106: "Shamisen",
|
||||
107: "Koto",
|
||||
108: "Kalimba",
|
||||
109: "Bag-pipe",
|
||||
110: "Fiddle",
|
||||
111: "Shanai",
|
||||
|
||||
# Percussive
|
||||
112: "Tinkle-Bell",
|
||||
113: "Agogo",
|
||||
114: "Steel-Drums",
|
||||
115: "Woodblock",
|
||||
116: "Taiko-Drum",
|
||||
117: "Melodic-Tom",
|
||||
118: "Synth-Drum",
|
||||
119: "Reverse-Cymbal",
|
||||
|
||||
# Sound Effects
|
||||
120: "Guitar-Fret-Noise",
|
||||
121: "Breath-Noise",
|
||||
122: "Seashore",
|
||||
123: "Bird-Tweet",
|
||||
124: "Telephone-Ring",
|
||||
125: "Helicopter",
|
||||
126: "Applause",
|
||||
127: "Gunshot"
|
||||
}
|
||||
|
||||
|
||||
REGULAR_NUM_DENOM = [(1, 1), (1, 2), (2, 2), (3, 2), (4, 2),
|
||||
(1, 4), (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), (7, 4), (8, 4),
|
||||
(1, 8), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8), (9, 8), (11, 8), (12, 8)]
|
||||
CORE_NUM_DENOM = [(1, 1), (1, 2), (2, 2), (4, 2),
|
||||
(1, 4), (2, 4), (3, 4), (4, 4), (5, 4),
|
||||
(1, 8), (2, 8), (3, 8), (6, 8), (9, 8), (12, 8)]
|
||||
VALID_TIME_SIGNATURES = ['time_signature_' + str(x[0]) + '/' + str(x[1]) for x in REGULAR_NUM_DENOM]
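# e.g., VALID_TIME_SIGNATURES contains entries such as 'time_signature_4/4' and 'time_signature_6/8'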
|
||||
|
||||
# cover possible time signatures
|
||||
REGULAR_TICKS_PER_BEAT = [48, 96, 192, 384, 120, 240, 480, 960, 256, 512, 1024]
|
||||
data_representation/encoding_utils.py (new file, 879 lines)
@@ -0,0 +1,879 @@
|
||||
from typing import Any
|
||||
from fractions import Fraction
|
||||
from collections import defaultdict
|
||||
|
||||
from miditoolkit import TimeSignature
|
||||
|
||||
from constants import *
|
||||
|
||||
'''
|
||||
This script contains the encoding-specific conversion classes for the different encoding schemes (REMI, CP, NB).
|
||||
'''
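# Illustrative usage sketch: each class below converts one corpus entry (the structure
# produced by step1_midi2corpus.py) into an event sequence, e.g.:
#
#     encoder = Corpus2event_remi(num_features=5)
#     events = encoder(song_data, in_beat_resolution=24)  # 24 positions per quarter note (SOD)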
|
||||
|
||||
def frange(start, stop, step):
|
||||
while start < stop:
|
||||
yield start
|
||||
start += step
|
||||
|
||||
################################# for REMI style encoding #################################
|
||||
|
||||
class Corpus2event_remi():
|
||||
def __init__(self, num_features:int):
|
||||
self.num_features = num_features
|
||||
|
||||
def _create_event(self, name, value):
|
||||
event = dict()
|
||||
event['name'] = name
|
||||
event['value'] = value
|
||||
return event
|
||||
def _break_down_numerator(self, numerator, possible_time_signatures):
|
||||
"""Break down a numerator into smaller time signatures.
|
||||
|
||||
Args:
|
||||
numerator: Target numerator to decompose (must be > 0).
|
||||
possible_time_signatures: List of (numerator, denominator) tuples,
|
||||
sorted in descending order (e.g., [(4,4), (3,4)]).
|
||||
|
||||
Returns:
|
||||
List of decomposed time signatures (e.g., [(4,4), (3,4)]).
|
||||
|
||||
Raises:
|
||||
ValueError: If decomposition is impossible.
|
||||
"""
|
||||
if numerator <= 0:
|
||||
raise ValueError("Numerator must be positive.")
|
||||
if not possible_time_signatures:
|
||||
raise ValueError("No possible time signatures provided.")
|
||||
|
||||
result = []
|
||||
original_numerator = numerator # For error message
|
||||
|
||||
# Sort signatures in descending order to prioritize larger chunks
|
||||
possible_time_signatures = sorted(possible_time_signatures, key=lambda x: -x[0])
|
||||
|
||||
while numerator > 0:
|
||||
subtracted = False # Track if any subtraction occurred in this iteration
|
||||
|
||||
for sig in possible_time_signatures:
|
||||
sig_numerator, _ = sig
|
||||
if sig_numerator <= 0:
|
||||
continue # Skip invalid signatures
|
||||
|
||||
while numerator >= sig_numerator:
|
||||
result.append(sig)
|
||||
numerator -= sig_numerator
|
||||
subtracted = True
|
||||
|
||||
# If no progress was made, decomposition failed
|
||||
if not subtracted:
|
||||
raise ValueError(
|
||||
f"Cannot decompose numerator {original_numerator} "
|
||||
f"with given time signatures {possible_time_signatures}. "
|
||||
f"Remaining: {numerator}"
|
||||
)
|
||||
|
||||
return result
|
||||
def _normalize_time_signature(self, time_signature, ticks_per_beat, next_change_point):
|
||||
"""
|
||||
Normalize irregular time signatures to standard ones by breaking them down
|
||||
into common time signatures, and adjusting their durations to fit the given
|
||||
musical structure.
|
||||
|
||||
Parameters:
|
||||
- time_signature: TimeSignature object with numerator, denominator, and start time.
|
||||
- ticks_per_beat: Number of ticks per beat, representing the resolution of the timing.
|
||||
- next_change_point: Tick position where the next time signature change occurs.
|
||||
|
||||
Returns:
|
||||
- A list of TimeSignature objects, normalized to fit within regular time signatures.
|
||||
|
||||
Procedure:
|
||||
1. If the time signature is already a standard one (in REGULAR_NUM_DENOM), return it.
|
||||
2. For non-standard signatures, break them down into simpler, well-known signatures.
|
||||
- For unusual denominators (e.g., 16th, 32nd, or 64th notes), normalize to 4/4.
|
||||
- For 6/4 signatures, break it into two 3/4 measures.
|
||||
3. If the time signature has a non-standard numerator and denominator, break it down
|
||||
into the largest possible numerators that still fit within the denominator.
|
||||
This ensures that the final measure fits within the regular time signature format.
|
||||
4. Calculate the resolution (duration in ticks) for each bar and ensure the bars
|
||||
fit within the time until the next change point.
|
||||
- Adjust the number of bars if they exceed the available space.
|
||||
- If the total length is too short, repeat the first (largest) bar to fill the gap.
|
||||
5. Convert the breakdown into TimeSignature objects and return the normalized result.
|
||||
"""
|
||||
|
||||
# Check if the time signature is a regular one, return it if so
|
||||
if (time_signature.numerator, time_signature.denominator) in REGULAR_NUM_DENOM:
|
||||
return [time_signature]
|
||||
|
||||
# Extract time signature components
|
||||
numerator, denominator, bar_start_tick = time_signature.numerator, time_signature.denominator, time_signature.time
|
||||
|
||||
# Normalize time signatures with 16th, 32nd, or 64th note denominators to 4/4
|
||||
if denominator in [16, 32, 64]:
|
||||
return [TimeSignature(4, 4, time_signature.time)]
|
||||
|
||||
# Special case for 6/4, break it into two 3/4 bars
|
||||
elif denominator == 6 and numerator == 4:
|
||||
return [TimeSignature(3, 4, time_signature.time), TimeSignature(3, 4, time_signature.time)]
|
||||
|
||||
# Determine possible regular signatures for the given denominator
|
||||
possible_time_signatures = [sig for sig in CORE_NUM_DENOM if sig[1] == denominator]
|
||||
|
||||
# Sort by numerator in descending order to prioritize larger numerators
|
||||
possible_time_signatures.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
result = []
|
||||
|
||||
# Break down the numerator into smaller regular numerators
|
||||
max_iterations = 100 # Prevent infinite loops
|
||||
original_numerator = numerator # Store original for error message
|
||||
|
||||
# Break down the numerator into smaller regular numerators
|
||||
iteration_count = 0
|
||||
while numerator > 0:
|
||||
iteration_count += 1
|
||||
if iteration_count > max_iterations:
|
||||
raise ValueError(
|
||||
f"Failed to normalize time signature {original_numerator}/{denominator}. "
|
||||
f"Could not break down numerator {original_numerator} with available signatures: "
|
||||
f"{possible_time_signatures}"
|
||||
)
|
||||
|
||||
for sig in possible_time_signatures:
|
||||
# Subtract numerators and add to the result
|
||||
while numerator >= sig[0]:
|
||||
result.append(sig)
|
||||
numerator -= sig[0]
|
||||
|
||||
|
||||
|
||||
# Calculate the resolution (length in ticks) of each bar
|
||||
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
|
||||
|
||||
# Adjust bars to fit within the remaining ticks before the next change point
|
||||
total_length = 0
|
||||
for idx, bar_resol in enumerate(bar_resol_list):
|
||||
total_length += bar_resol
|
||||
if total_length > next_change_point - bar_start_tick:
|
||||
result = result[:idx+1]
|
||||
break
|
||||
|
||||
# If the total length is too short, repeat the first (largest) bar until the gap is filled
|
||||
while total_length < next_change_point - bar_start_tick:
|
||||
result.append(result[0])
|
||||
total_length += int(ticks_per_beat * result[0][0] * (4 / result[0][1]))
|
||||
|
||||
# Recalculate bar resolutions for the final result
|
||||
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
|
||||
|
||||
# Insert a starting resolution of 0 and calculate absolute tick positions for each TimeSignature
|
||||
bar_resol_list.insert(0, 0)
|
||||
total_length = bar_start_tick
|
||||
normalized_result = []
|
||||
for sig, length in zip(result, bar_resol_list):
|
||||
total_length += length
|
||||
normalized_result.append(TimeSignature(sig[0], sig[1], total_length))
|
||||
|
||||
return normalized_result
|
||||
|
||||
def _process_time_signature(self, time_signature_changes, ticks_per_beat, first_note_tick, global_end):
|
||||
"""
|
||||
Process and normalize time signature changes for a given musical piece.
|
||||
|
||||
Parameters:
|
||||
- time_signature_changes: A list of TimeSignature objects representing time signature changes in the music.
|
||||
- ticks_per_beat: The resolution of timing in ticks per beat.
|
||||
- first_note_tick: The tick position of the first note in the piece.
|
||||
- global_end: The tick position where the piece ends.
|
||||
|
||||
Returns:
|
||||
- A list of processed and normalized time signature changes. If no valid time signature
|
||||
changes are found, returns None.
|
||||
|
||||
Procedure:
|
||||
1. Check the validity of the time signature changes:
|
||||
- Ensure there is at least one time signature change.
|
||||
- Ensure the first time signature change occurs at the beginning (before the first note).
|
||||
2. Remove duplicate consecutive time signatures:
|
||||
- Only add time signatures that differ from the previous one (de-duplication).
|
||||
3. Normalize the time signatures:
|
||||
- For each time signature, determine its duration by calculating the time until the
|
||||
next change point or the end of the piece.
|
||||
- Use the _normalize_time_signature method to break down non-standard signatures into
|
||||
simpler, well-known signatures that fit within the musical structure.
|
||||
4. Return the processed and normalized time signature changes.
|
||||
|
||||
"""
|
||||
|
||||
# Check if there are any time signature changes
|
||||
if len(time_signature_changes) == 0:
|
||||
print("No time signature change in this tune, default to 4/4 time signature")
|
||||
# default to 4/4 time signature if none are found
|
||||
return [TimeSignature(4, 4, 0)]
|
||||
|
||||
# Ensure the first time signature change is at the start of the piece (before the first note)
|
||||
if time_signature_changes[0].time != 0 and time_signature_changes[0].time > first_note_tick:
|
||||
print("The first time signature change is not at the beginning of the tune")
|
||||
return None
|
||||
|
||||
# Remove consecutive duplicate time signatures (de-duplication)
|
||||
processed_time_signature_changes = []
|
||||
for idx, time_sig in enumerate(time_signature_changes):
|
||||
if idx == 0:
|
||||
processed_time_signature_changes.append(time_sig)
|
||||
else:
|
||||
prev_time_sig = time_signature_changes[idx-1]
|
||||
# Only add time signature if it's different from the previous one
|
||||
if not (prev_time_sig.numerator == time_sig.numerator and prev_time_sig.denominator == time_sig.denominator):
|
||||
processed_time_signature_changes.append(time_sig)
|
||||
|
||||
# Normalize the time signatures to standard formats
|
||||
normalized_time_signature_changes = []
|
||||
for idx, time_signature in enumerate(processed_time_signature_changes):
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
# If it's the last time signature change, set the next change point as the end of the piece
|
||||
next_change_point = global_end
|
||||
else:
|
||||
# Otherwise, set the next change point as the next time signature's start time
|
||||
next_change_point = time_signature_changes[idx+1].time
|
||||
|
||||
# Normalize the current time signature and extend the result
|
||||
normalized_time_signature_changes.extend(self._normalize_time_signature(time_signature, ticks_per_beat, next_change_point))
|
||||
|
||||
# Return the list of processed and normalized time signatures
|
||||
time_signature_changes = normalized_time_signature_changes
|
||||
return time_signature_changes
|
||||
|
||||
def _half_step_interval_gap_check_across_instruments(self, instrument_note_dict):
|
||||
'''
|
||||
This function checks for half-step interval gaps between notes across different instruments.
|
||||
It will avoid half-step intervals by keeping one note from any pair of notes that are a half-step apart,
|
||||
regardless of which instrument they belong to.
|
||||
'''
|
||||
# Sort instrument_note_dict by instrument index (ascending order)
|
||||
instrument_note_dict = dict(sorted(instrument_note_dict.items()))
|
||||
|
||||
# Create a dictionary to store all pitches across instruments
|
||||
all_pitches = {}
|
||||
|
||||
# Collect all pitches from each instrument and sort them in descending order
|
||||
for instrument, notes in instrument_note_dict.items():
|
||||
for pitch, durations in notes.items():
|
||||
all_pitches[pitch] = all_pitches.get(pitch, []) + [(instrument, durations)]
|
||||
|
||||
# Sort the pitches in descending order
|
||||
sorted_pitches = sorted(all_pitches.keys(), reverse=True)
|
||||
|
||||
# Create a new list to store the final pitches after comparison
|
||||
final_pitch_list = []
|
||||
|
||||
# Use an index pointer to control the sliding window
|
||||
idx = 0
|
||||
while idx < len(sorted_pitches) - 1:
|
||||
current_pitch = sorted_pitches[idx]
|
||||
next_pitch = sorted_pitches[idx + 1]
|
||||
|
||||
if current_pitch - next_pitch == 1: # Check for a half-step interval gap
|
||||
current_max_duration = max(duration for _, durations in all_pitches[current_pitch] for duration, _ in durations)
|
||||
next_max_duration = max(duration for _, durations in all_pitches[next_pitch] for duration, _ in durations)
|
||||
|
||||
if current_max_duration < next_max_duration:
|
||||
# Keep the higher pitch (next_pitch) and skip the current_pitch
|
||||
final_pitch_list.append(next_pitch)
|
||||
else:
|
||||
# Keep the lower pitch (current_pitch) and skip the next_pitch
|
||||
final_pitch_list.append(current_pitch)
|
||||
|
||||
# Skip the next pitch because we already handled it
|
||||
idx += 2
|
||||
else:
|
||||
# No half-step gap, keep the current pitch and move to the next one
|
||||
final_pitch_list.append(current_pitch)
|
||||
idx += 1
|
||||
|
||||
# Ensure the last pitch is added if it's not part of a half-step interval
|
||||
if idx == len(sorted_pitches) - 1:
|
||||
final_pitch_list.append(sorted_pitches[-1])
|
||||
|
||||
# Filter out notes not in the final pitch list and update the instrument_note_dict
|
||||
for instrument in instrument_note_dict.keys():
|
||||
instrument_note_dict[instrument] = {
|
||||
pitch: instrument_note_dict[instrument][pitch]
|
||||
for pitch in sorted(instrument_note_dict[instrument].keys(), reverse=True) if pitch in final_pitch_list
|
||||
}
|
||||
|
||||
return instrument_note_dict
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution):
|
||||
'''
|
||||
Process a song's data to generate a sequence of musical events, including bars, chords, tempo,
|
||||
and notes, similar to the approach used in the CP paper (corpus2event_remi_v2).
|
||||
|
||||
Parameters:
|
||||
- song_data: A dictionary containing metadata, notes, chords, and tempos of the song.
|
||||
- in_beat_resolution: The resolution of timing in beats (how many divisions per beat).
|
||||
|
||||
Returns:
|
||||
- A sequence of musical events including start (SOS), bars, chords, tempo, instruments, notes,
|
||||
and an end (EOS) event. If the time signature is invalid, returns None.
|
||||
|
||||
Procedure:
|
||||
1. **Global Setup**:
|
||||
- Extract global metadata like first and last note ticks, time signature changes, and ticks
|
||||
per beat.
|
||||
- Compute `in_beat_tick_resol`, the ratio of ticks per beat to the input beat resolution,
|
||||
to assist in dividing bars later.
|
||||
- Get a sorted list of instruments in the song.
|
||||
|
||||
2. **Time Signature Processing**:
|
||||
- Call `_process_time_signature` to clean up and normalize the time signatures in the song.
|
||||
- If the time signatures are invalid (e.g., no time signature changes or missing at the
|
||||
start), the function exits early with None.
|
||||
|
||||
3. **Sequence Generation**:
|
||||
- Initialize the sequence with a start token (SOS) and prepare variables for tracking
|
||||
previous chord, tempo, and instrument states.
|
||||
- Loop through each time signature change, dividing the song into measures based on the
|
||||
current time signature's numerator and denominator.
|
||||
- For each measure, append "Bar" tokens to mark measure boundaries, while ensuring that no
|
||||
more than four consecutive empty bars are added.
|
||||
- For each step within a measure, process the following:
|
||||
- **Chords**: If there is a chord change, add a corresponding chord event.
|
||||
- **Tempo**: If the tempo changes, add a tempo event.
|
||||
- **Notes**: Iterate over each instrument, adding notes and checking for half-step
|
||||
intervals, deduplicating notes, and choosing the longest duration for each pitch.
|
||||
- Append a "Beat" event for each step with musical events.
|
||||
|
||||
4. **End Sequence**:
|
||||
- Conclude the sequence by appending a final "Bar" token followed by an end token (EOS).
|
||||
'''
|
||||
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # Starting tick of the first note
|
||||
global_end = song_data['metadata']['last_note'] # Ending tick of the last note
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat resolution
|
||||
# Resolution for dividing beats within measures, expressed as a fraction
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Example: 1024/12 -> (256, 3)
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # Get a sorted list of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Normalize and process the time signatures in the song
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if time signature is invalid
|
||||
|
||||
# --- create sequence --- #
|
||||
prev_instr_idx = None # Track the previously processed instrument
|
||||
final_sequence = []
|
||||
final_sequence.append(self._create_event('SOS', None)) # Add Start of Sequence (SOS) token
|
||||
prev_chord = None # Track the previous chord
|
||||
prev_tempo = None # Track the previous tempo
|
||||
chord_value = None
|
||||
tempo_value = None
|
||||
|
||||
# Process each time signature change
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to indicate a time signature change
|
||||
# Calculate bar resolution based on the current time signature
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format time signature name
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate bar resolution in ticks
|
||||
bar_start_tick = time_signature_changes[idx].time # Start tick of the current bar
|
||||
# Determine the next time signature change point or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx+1].time
|
||||
|
||||
# Process each measure within the current time signature
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
empty_bar_token = self._create_event('Bar', None) # Token for empty bars
|
||||
|
||||
# Ensure no more than 4 consecutive empty bars are added
|
||||
if len(final_sequence) >= 4:
|
||||
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and
|
||||
final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name)) # Mark new bar with time signature
|
||||
else:
|
||||
final_sequence.append(self._create_event('Bar', None))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name))
|
||||
else:
|
||||
final_sequence.append(self._create_event('Bar', None))
|
||||
|
||||
time_sig_change_flag = False # Reset time signature change flag
|
||||
|
||||
# Process events within each beat
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
events_list = []
|
||||
# Retrieve chords and tempos at the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# Process chord and tempo if the number of features allows for it
|
||||
if self.num_features in {8, 7}:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_') # Extract chord info
|
||||
chord_value = root + '_' + quality
|
||||
if t_tempos is not None:
|
||||
tempo_value = t_tempos[-1].tempo # Extract tempo value
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# Process notes for each instrument at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If there are notes at this beat step, process them.
|
||||
if t_notes is not None:
|
||||
# Track notes to avoid duplicates and check for half-step intervals
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if prev_chord != chord_value:
|
||||
events_list.append(self._create_event('Chord', chord_value))
|
||||
prev_chord = chord_value
|
||||
if prev_tempo != tempo_value:
|
||||
events_list.append(self._create_event('Tempo', tempo_value))
|
||||
prev_tempo = tempo_value
|
||||
|
||||
# add instrument and note
|
||||
for instrument in pruned_instrument_note_dict:
|
||||
if self.num_features in {5, 8}:
|
||||
events_list.append(self._create_event('Instrument', instrument))
|
||||
|
||||
for pitch in pruned_instrument_note_dict[instrument]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument][pitch], key=lambda x: x[0])
|
||||
note_event = [
|
||||
self._create_event('Note_Pitch', pitch),
|
||||
self._create_event('Note_Duration', max_duration[0])
|
||||
]
|
||||
if self.num_features in {7, 8}:
|
||||
note_event.append(self._create_event('Note_Velocity', max_duration[1]))
|
||||
events_list.extend(note_event)
|
||||
|
||||
# If there are events in this step, add a "Beat" event and the collected events
|
||||
if len(events_list):
|
||||
final_sequence.append(self._create_event('Beat', in_beat_off_idx))
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self._create_event('Bar', None)) # Add final bar token
|
||||
final_sequence.append(self._create_event('EOS', None)) # Add End of Sequence (EOS) token
|
||||
return final_sequence
|
||||
|
||||
################################# for CP style encoding #################################
|
||||
|
||||
class Corpus2event_cp(Corpus2event_remi):
|
||||
def __init__(self, num_features):
|
||||
super().__init__(num_features)
|
||||
self.num_features = num_features
|
||||
self._init_event_template()
|
||||
|
||||
def _init_event_template(self):
|
||||
'''
|
||||
The order of musical features is Type, Beat, Chord, Tempo, Instrument, Pitch, Duration, Velocity
|
||||
'''
|
||||
self.event_template = {}
|
||||
if self.num_features == 8:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 7:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 5:
|
||||
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
|
||||
elif self.num_features == 4:
|
||||
feature_names = ['type', 'beat', 'pitch', 'duration']
|
||||
for feature_name in feature_names:
|
||||
self.event_template[feature_name] = 0
|
||||
|
||||
def create_cp_sos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'SOS'
|
||||
return total_event
|
||||
|
||||
def create_cp_eos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'EOS'
|
||||
return total_event
|
||||
|
||||
def create_cp_metrical_event(self, pos, chord, tempo):
|
||||
'''
|
||||
when the compound token is related to metrical information
|
||||
'''
|
||||
meter_event = self.event_template.copy()
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = pos
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
meter_event['chord'] = chord
|
||||
meter_event['tempo'] = tempo
|
||||
return meter_event
|
||||
|
||||
def create_cp_note_event(self, instrument_name, pitch, duration, velocity):
|
||||
'''
|
||||
when the compound token is related to note information
|
||||
'''
|
||||
note_event = self.event_template.copy()
|
||||
note_event['type'] = 'Note'
|
||||
note_event['pitch'] = pitch
|
||||
note_event['duration'] = duration
|
||||
if self.num_features == 5 or self.num_features == 8:
|
||||
note_event['instrument'] = instrument_name
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
note_event['velocity'] = velocity
|
||||
return note_event
|
||||
|
||||
def create_cp_bar_event(self, time_sig_change_flag=False, time_sig_name=None):
|
||||
meter_event = self.event_template.copy()
|
||||
if time_sig_change_flag:
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = f'Bar_{time_sig_name}'
|
||||
else:
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = 'Bar'
|
||||
return meter_event
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution):
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
|
||||
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Process time signature changes and adjust them for the given song structure
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if no valid time signature changes found
|
||||
|
||||
# --- create sequence --- #
|
||||
final_sequence = [] # Initialize the final sequence to store the events
|
||||
final_sequence.append(self.create_cp_sos_event()) # Add the Start-of-Sequence (SOS) event
|
||||
chord_text = None # Placeholder for the current chord
|
||||
tempo_text = None # Placeholder for the current tempo
|
||||
|
||||
# Loop through each time signature change and process the corresponding measures
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to track when time signature changes
|
||||
# Calculate bar resolution (number of ticks per bar based on the time signature)
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
|
||||
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
|
||||
|
||||
# Determine the point for the next time signature change or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx + 1].time
|
||||
|
||||
# Iterate over each measure (bar) between the current and next time signature change
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
empty_bar_token = self.create_cp_bar_event() # Create an empty bar event
|
||||
|
||||
# Check if the last four events in the sequence are consecutive empty bars
|
||||
if len(final_sequence) >= 4:
|
||||
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
else:
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
|
||||
# Reset the time signature change flag after handling the bar event
|
||||
time_sig_change_flag = False
|
||||
|
||||
# Loop through beats in each measure based on the in-beat resolution
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
chord_tempo_flag = False # Flag to track if chord and tempo events are added
|
||||
events_list = [] # List to hold events for the current beat
|
||||
pos_text = 'Beat_' + str(in_beat_off_idx) # Create a beat event label
|
||||
|
||||
# --- chord & tempo processing --- #
|
||||
# Unpack chords and tempos for the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# If a chord is present, extract its root, quality, and bass
|
||||
if self.num_features in {7, 8}:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_')
|
||||
chord_text = 'Chord_' + root + '_' + quality
|
||||
|
||||
# If a tempo is present, format it as a string
|
||||
if t_tempos is not None:
|
||||
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# --- instrument & note processing --- #
|
||||
# Loop through each instrument and process its notes at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If notes are present, process them
|
||||
if t_notes != None:
|
||||
# Track notes and their properties (duration and velocity) for the current instrument
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if not chord_tempo_flag:
|
||||
if chord_text == None:
|
||||
chord_text = 'Chord_N_N'
|
||||
if tempo_text == None:
|
||||
tempo_text = 'Tempo_N_N'
|
||||
chord_tempo_flag = True
|
||||
|
||||
events_list.append(self.create_cp_metrical_event(pos_text, chord_text, tempo_text))
|
||||
|
||||
# add instrument and note
|
||||
for instrument_idx in pruned_instrument_note_dict:
|
||||
instrument_name = 'Instrument_' + str(instrument_idx)
|
||||
for pitch in pruned_instrument_note_dict[instrument_idx]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
|
||||
note_pitch_text = 'Note_Pitch_' + str(pitch)
|
||||
note_duration_text = 'Note_Duration_' + str(max_duration[0])
|
||||
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
|
||||
events_list.append(self.create_cp_note_event(instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
|
||||
|
||||
# If there are any events for this beat, add them to the final sequence
|
||||
if len(events_list) > 0:
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self.create_cp_bar_event()) # Add the final bar event
|
||||
final_sequence.append(self.create_cp_eos_event()) # Add the End-of-Sequence (EOS) event
|
||||
return final_sequence # Return the final sequence of events
|
||||
|
||||
################################# for NB style encoding #################################
|
||||
|
||||
class Corpus2event_nb(Corpus2event_cp):
|
||||
def __init__(self, num_features):
|
||||
'''
|
||||
For convenience in logging, we use the word "type" for the "metric" sub-token in the code, so it can be compared easily with the other encoding schemes
|
||||
'''
|
||||
super().__init__(num_features)
|
||||
self.num_features = num_features
|
||||
self._init_event_template()
|
||||
|
||||
def _init_event_template(self):
|
||||
self.event_template = {}
|
||||
if self.num_features == 8:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 7:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 5:
|
||||
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
|
||||
elif self.num_features == 4:
|
||||
feature_names = ['type', 'beat', 'pitch', 'duration']
|
||||
for feature_name in feature_names:
|
||||
self.event_template[feature_name] = 0
|
||||
|
||||
def create_nb_sos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'SOS'
|
||||
return total_event
|
||||
|
||||
def create_nb_eos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'EOS'
|
||||
return total_event
|
||||
|
||||
def create_nb_event(self, bar_beat_type, pos, chord, tempo, instrument_name, pitch, duration, velocity):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = bar_beat_type
|
||||
total_event['beat'] = pos
|
||||
total_event['pitch'] = pitch
|
||||
total_event['duration'] = duration
|
||||
if self.num_features in {5, 8}:
|
||||
total_event['instrument'] = instrument_name
|
||||
if self.num_features in {7, 8}:
|
||||
total_event['chord'] = chord
|
||||
total_event['tempo'] = tempo
|
||||
total_event['velocity'] = velocity
|
||||
return total_event
|
||||
|
||||
def create_nb_empty_bar_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'Empty_Bar'
|
||||
return total_event
|
||||
|
||||
def get_bar_beat_idx(self, bar_flag, beat_flag, time_sig_name, time_sig_change_flag):
|
||||
'''
|
||||
This function is to get the metric information for the current bar and beat
|
||||
There are four types of metric information: NNN, SNN, SSN, SSS
|
||||
Each letter represents the change of time signature, bar, and beat (new or same)
|
||||
'''
|
||||
if time_sig_change_flag: # new time signature
|
||||
return "NNN_" + time_sig_name
|
||||
else:
|
||||
if bar_flag and beat_flag: # same time sig & new bar & new beat
|
||||
return "SNN"
|
||||
elif not bar_flag and beat_flag: # same time sig & same bar & new beat
|
||||
return "SSN"
|
||||
elif not bar_flag and not beat_flag: # same time sig & same bar & same beat
|
||||
return "SSS"
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution:int):
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
|
||||
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Process time signature changes and adjust them for the given song structure
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if no valid time signature changes found
|
||||
|
||||
# --- create sequence --- #
|
||||
final_sequence = [] # Initialize the final sequence to store the events
|
||||
final_sequence.append(self.create_nb_sos_event()) # Add the Start-of-Sequence (SOS) event
|
||||
chord_text = None # Placeholder for the current chord
|
||||
tempo_text = None # Placeholder for the current tempo
|
||||
|
||||
# Loop through each time signature change and process the corresponding measures
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to track when time signature changes
|
||||
# Calculate bar resolution (number of ticks per bar based on the time signature)
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
|
||||
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
|
||||
|
||||
# Determine the point for the next time signature change or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx + 1].time
|
||||
|
||||
# Iterate over each measure (bar) between the current and next time signature change
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
bar_flag = True
|
||||
note_flag = False
|
||||
|
||||
# Loop through beats in each measure based on the in-beat resolution
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
beat_flag = True
|
||||
events_list = []
|
||||
pos_text = 'Beat_' + str(in_beat_off_idx)
|
||||
|
||||
# --- chord & tempo processing --- #
|
||||
# Unpack chords and tempos for the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# If a chord is present, extract its root, quality, and bass
|
||||
if self.num_features == 8 or self.num_features == 7:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_')
|
||||
chord_text = 'Chord_' + root + '_' + quality
|
||||
|
||||
# If a tempo is present, format it as a string
|
||||
if t_tempos is not None:
|
||||
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# --- instrument & note processing --- #
|
||||
# Loop through each instrument and process its notes at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If notes are present, process them
|
||||
if t_notes != None:
|
||||
note_flag = True
|
||||
|
||||
# Track notes and their properties (duration and velocity) for the current instrument
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
# # Check for half-step interval gaps and handle them accordingly
|
||||
# self._half_step_interval_gap_check(instrument_note_dict, instrument_idx)
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if chord_text is None:
|
||||
chord_text = 'Chord_N_N'
|
||||
if tempo_text is None:
|
||||
tempo_text = 'Tempo_N_N'
|
||||
|
||||
# add instrument and note
|
||||
for instrument_idx in pruned_instrument_note_dict:
|
||||
instrument_name = 'Instrument_' + str(instrument_idx)
|
||||
for pitch in pruned_instrument_note_dict[instrument_idx]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
|
||||
note_pitch_text = 'Note_Pitch_' + str(pitch)
|
||||
note_duration_text = 'Note_Duration_' + str(max_duration[0])
|
||||
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
|
||||
bar_beat_type = self.get_bar_beat_idx(bar_flag, beat_flag, time_sig_name, time_sig_change_flag)
|
||||
events_list.append(self.create_nb_event(bar_beat_type, pos_text, chord_text, tempo_text, instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
|
||||
bar_flag = False
|
||||
beat_flag = False
|
||||
time_sig_change_flag = False
|
||||
|
||||
# If there are any events for this beat, add them to the final sequence
|
||||
if events_list:
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# when there is no note in this bar
|
||||
if not note_flag:
|
||||
# avoid consecutive empty bars (more than 4 is not allowed)
|
||||
empty_bar_token = self.create_nb_empty_bar_event()
|
||||
if len(final_sequence) >= 4:
|
||||
if all(token == empty_bar_token for token in final_sequence[-4:]):
|
||||
continue
|
||||
final_sequence.append(empty_bar_token)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self.create_nb_eos_event())
|
||||
return final_sequence
|
||||
650
data_representation/step1_midi2corpus.py
Normal file
@ -0,0 +1,650 @@
|
||||
import argparse
|
||||
import time
|
||||
import itertools
|
||||
import copy
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from collections import defaultdict
|
||||
from fractions import Fraction
|
||||
from typing import List
|
||||
import os
|
||||
from muspy import sort
|
||||
import numpy as np
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
import miditoolkit
|
||||
from miditoolkit.midi.containers import Marker, Instrument
|
||||
from chorder import Dechorder
|
||||
|
||||
from constants import NUM2PITCH, PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP
|
||||
|
||||
'''
|
||||
This script is designed to preprocess MIDI files and convert them into a structured corpus suitable for symbolic music analysis or model training.
|
||||
It handles various tasks, including setting beat resolution, calculating duration, velocity, and tempo bins, and processing MIDI data into quantized musical events.
|
||||
'''
|
||||
|
||||
def get_tempo_bin(max_tempo:int, ratio:float=1.1):
|
||||
bpm = 30
|
||||
regular_tempo_bins = [bpm]
|
||||
while bpm < max_tempo:
|
||||
bpm *= ratio
|
||||
bpm = round(bpm)
|
||||
if bpm > max_tempo:
|
||||
break
|
||||
regular_tempo_bins.append(bpm)
|
||||
return np.array(regular_tempo_bins)
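# Hedged sketch (not part of the original pipeline): shows how a raw BPM value would be snapped to the
# nearest bin produced by get_tempo_bin, mirroring the np.argmin(...) lookup used in CorpusMaker._midi2corpus.
# The function name and default arguments are illustrative only.
def _example_snap_tempo(raw_bpm: float, max_tempo: int = 240, ratio: float = 1.04) -> int:
    bins = get_tempo_bin(max_tempo=max_tempo, ratio=ratio)
    # pick the bin with the smallest absolute distance to the raw tempo
    return int(bins[np.argmin(abs(bins - raw_bpm))])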
|
||||
|
||||
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
|
||||
'''
|
||||
keep only chord markers, filtering out 'global' and 'Boundary' markers
|
||||
'''
|
||||
chords = []
|
||||
for marker in markers:
|
||||
splitted_text = marker.text.split('_')
|
||||
if splitted_text[0] != 'global' and 'Boundary' not in splitted_text[0]:
|
||||
chords.append(marker)
|
||||
return chords
|
||||
|
||||
class CorpusMaker():
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name:str,
|
||||
num_features:int,
|
||||
in_dir:Path,
|
||||
out_dir:Path,
|
||||
debug:bool
|
||||
):
|
||||
'''
|
||||
Initialize the CorpusMaker with dataset information and directory paths.
|
||||
It sets up MIDI paths, output directories, and debug mode, then
|
||||
retrieves the beat resolution, duration bins, velocity/tempo bins, and prepares the MIDI file list.
|
||||
'''
|
||||
self.dataset_name = dataset_name
|
||||
self.num_features = num_features
|
||||
self.midi_path = in_dir / f"{dataset_name}"
|
||||
self.out_dir = out_dir
|
||||
self.debug = debug
|
||||
self._get_in_beat_resolution()
|
||||
self._get_duration_bins()
|
||||
self._get_velocity_tempo_bins()
|
||||
self._get_min_max_last_time()
|
||||
self._prepare_midi_list()
|
||||
|
||||
def _get_in_beat_resolution(self):
|
||||
# Retrieve the per-quarter-note resolution for the dataset (e.g., 4 means the minimum resolution is a 16th note)
|
||||
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
|
||||
try:
|
||||
self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
|
||||
|
||||
def _get_duration_bins(self):
|
||||
# Set up regular duration bins for quantizing note lengths, based on the beat resolution.
|
||||
base_duration = {4:[1,2,3,4,5,6,8,10,12,16,20,24,28,32],
|
||||
8:[1,2,3,4,6,8,10,12,14,16,20,24,28,32,36,40,48,56,64],
|
||||
12:[1,2,3,4,6,9,12,15,18,24,30,36,42,48,54,60,72,84,96]}
|
||||
base_duration_list = base_duration[self.in_beat_resolution]
|
||||
self.regular_duration_bins = np.array(base_duration_list)
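# Illustration (hedged, assumed values): with in_beat_resolution = 4, a relative duration of 9 sixteenth-note
# steps is later snapped to the closest bin, here 8, via
# regular_duration_bins[np.argmin(abs(regular_duration_bins - 9))] in _midi2corpus.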
|
||||
|
||||
def _get_velocity_tempo_bins(self):
|
||||
# Define velocity and tempo bins based on whether the dataset is a performance or score type.
|
||||
midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'SymphonyMIDI': 'score'}
|
||||
try:
|
||||
midi_type = midi_type_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
midi_type = midi_type_dict['LakhClean']
|
||||
# For performance-type datasets, set finer granularity of velocity and tempo bins.
|
||||
if midi_type == 'perform':
|
||||
self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
|
||||
self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
|
||||
# For score-type datasets, use coarser velocity and tempo bins.
|
||||
elif midi_type == 'score':
|
||||
self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
|
||||
self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)
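# Illustration (hedged): for a score-type dataset a raw velocity of 67 is snapped to the closest bin in
# [40, 60, 80, 100, 120], i.e. 60, via regular_velocity_bins[np.argmin(abs(regular_velocity_bins - 67))].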
|
||||
|
||||
def _get_min_max_last_time(self):
|
||||
'''
|
||||
Set the minimum and maximum allowed length (in seconds) of a MIDI file, depending on the dataset.
A range of (0, 2000) effectively means no limit.
|
||||
'''
|
||||
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
|
||||
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'SymphonyMIDI': (60, 1500)}
|
||||
try:
|
||||
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
self.min_last_time, self.max_last_time = last_time_dict['LakhClean']
|
||||
|
||||
def _prepare_midi_list(self):
|
||||
midi_path = Path(self.midi_path)
|
||||
# detect subdirectories and get all midi files
|
||||
if not midi_path.exists():
|
||||
raise ValueError(f"midi_path {midi_path} does not exist")
|
||||
# go through all subdirectories and collect all MIDI files
|
||||
midi_files = []
|
||||
for root, _, files in os.walk(midi_path):
|
||||
for file in files:
|
||||
if file.endswith('.mid'):
|
||||
# print(Path(root) / file)
|
||||
midi_files.append(Path(root) / file)
|
||||
self.midi_list = midi_files
|
||||
print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")
|
||||
|
||||
def make_corpus(self) -> None:
|
||||
'''
|
||||
Main method to process the MIDI files and create the corpus data.
|
||||
It supports both single-processing (debug mode) and multi-processing for large datasets.
|
||||
'''
|
||||
print("preprocessing midi data to corpus data")
|
||||
# check whether the corpus folders already exist and create them if not
|
||||
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
|
||||
Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
|
||||
start_time = time.time()
|
||||
if self.debug:
|
||||
# single processing for debugging
|
||||
broken_counter = 0
|
||||
success_counter = 0
|
||||
for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
|
||||
message = self._mp_midi2corpus(file_path)
|
||||
if message == "error":
|
||||
broken_counter += 1
|
||||
elif message == "success":
|
||||
success_counter += 1
|
||||
else:
|
||||
# Multi-process processing for faster corpus generation.
|
||||
broken_counter = 0
|
||||
success_counter = 0
|
||||
# filter out processed files
|
||||
print(self.out_dir)
|
||||
processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
|
||||
processed_files = [x.name for x in processed_files]
|
||||
print(f"processed files: {len(processed_files)}")
|
||||
print("length of midi list: ", len(self.midi_list))
|
||||
# Use set for faster lookup (O(1) per check)
|
||||
processed_files_set = set(processed_files)
|
||||
self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
|
||||
# reverse the list to process the latest files first
|
||||
self.midi_list.reverse()
|
||||
print(f"length of midi list after filtering: ", len(self.midi_list))
|
||||
with Pool(16) as p:
|
||||
for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, 1000), total=len(self.midi_list)):
|
||||
if message == "error":
|
||||
broken_counter += 1
|
||||
elif message == "success":
|
||||
success_counter += 1
|
||||
# for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
|
||||
# message = self._mp_midi2corpus(file_path)
|
||||
# if message == "error":
|
||||
# broken_counter += 1
|
||||
# elif message == "success":
|
||||
# success_counter += 1
|
||||
print(f"Making corpus takes: {time.time() - start_time}s, success: {success_counter}, broken: {broken_counter}")
|
||||
|
||||
def _mp_midi2corpus(self, file_path: Path):
|
||||
"""Convert MIDI to corpus format and save both corpus (.pkl) and MIDI (.mid)."""
|
||||
try:
|
||||
midi_obj = self._analyze(file_path)
|
||||
corpus, midi_obj = self._midi2corpus(midi_obj)
|
||||
# --- 1. Save corpus (.pkl) ---
|
||||
relative_path = file_path.relative_to(self.midi_path) # Get relative path from input dir
|
||||
safe_name = str(relative_path).replace("/", "_").replace("\\", "_").replace(".mid", ".pkl")
|
||||
save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
|
||||
with save_path.open("wb") as f:
|
||||
pickle.dump(corpus, f)
|
||||
|
||||
# --- 2. Save MIDI (.mid) ---
|
||||
midi_save_dir = Path("../dataset/represented_data/corpus") / f"midi_{self.dataset_name}"
|
||||
midi_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
midi_save_path = midi_save_dir / file_path.name # Keep original MIDI filename
|
||||
midi_obj.dump(midi_save_path)
|
||||
|
||||
del midi_obj, corpus
|
||||
return "success"
|
||||
|
||||
except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
|
||||
print(f"Error processing {file_path.name}: {e}")
|
||||
return "error"
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in {file_path.name}: {e}")
|
||||
return "error"
|
||||
def _check_length(self, last_time:float):
|
||||
if last_time < self.min_last_time:
|
||||
raise ValueError(f"last time {last_time} is out of range")
|
||||
|
||||
def _analyze(self, midi_path:Path):
|
||||
# Loads and analyzes a MIDI file, performing various checks and extracting chords.
|
||||
midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
|
||||
|
||||
# check length
|
||||
mapping = midi_obj.get_tick_to_time_mapping()
|
||||
last_time = mapping[midi_obj.max_tick]
|
||||
self._check_length(last_time)
|
||||
|
||||
# delete instruments with no notes; iterate over a copy so removal does not skip entries
for ins in list(midi_obj.instruments):
    if len(ins.notes) == 0:
        midi_obj.instruments.remove(ins)
        continue
    # sort the remaining notes in place by start time, then pitch
    ins.notes.sort(key=lambda x: (x.start, x.pitch))
|
||||
|
||||
# three steps to merge instruments
|
||||
self._merge_percussion(midi_obj)
|
||||
self._pruning_instrument(midi_obj)
|
||||
self._limit_max_track(midi_obj)
|
||||
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
# in case of 7 or 8 features, we need to extract chords
|
||||
new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
|
||||
chords = Dechorder.dechord(new_midi_obj)
|
||||
markers = []
|
||||
for cidx, chord in enumerate(chords):
|
||||
if chord.is_complete():
|
||||
chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
|
||||
else:
|
||||
chord_text = 'N_N_N'
|
||||
markers.append(Marker(time=int(cidx*new_midi_obj.ticks_per_beat), text=chord_text))
|
||||
|
||||
# de-duplication
|
||||
prev_chord = None
|
||||
dedup_chords = []
|
||||
for m in markers:
|
||||
if m.text != prev_chord:
|
||||
prev_chord = m.text
|
||||
dedup_chords.append(m)
|
||||
|
||||
# return midi
|
||||
midi_obj.markers = dedup_chords
|
||||
return midi_obj
|
||||
|
||||
def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
|
||||
'''
|
||||
When notes end up in the same quant_time but have different start times, unintended chords can be created.
rule1: if two notes are a unison or half step apart, keep only the longer one
rule2: if two notes share less than 80% of the shorter note's duration, keep only the longer one
|
||||
'''
|
||||
for instr in instr_grid.keys():
|
||||
time_list = sorted(list(instr_grid[instr].keys()))
|
||||
for time in time_list:
|
||||
notes = instr_grid[instr][time]
|
||||
if len(notes) == 1:
|
||||
continue
|
||||
else:
|
||||
new_notes = []
|
||||
# sort in pitch with ascending order
|
||||
notes.sort(key=lambda x: x.pitch)
|
||||
for i in range(len(notes)-1):
|
||||
# if start time is same add to new_notes
|
||||
if notes[i].start == notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
new_notes.append(notes[i+1])
|
||||
continue
|
||||
if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
|
||||
# select longer note
|
||||
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
# check how much duration they share
|
||||
shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
|
||||
shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
|
||||
# unless they share more than 80% of duration, select longer note (pruning shorter note)
|
||||
if shared_duration / shorter_duration < 0.8:
|
||||
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
if len(new_notes) == 0:
|
||||
new_notes.append(notes[i])
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
instr_grid[instr][time] = new_notes
|
||||
|
||||
def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
# Checks if the ticks per beat in the MIDI file is lower than the expected resolution.
|
||||
# If it is, raise an error.
|
||||
if midi_obj.ticks_per_beat < self.in_beat_resolution:
|
||||
raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')
|
||||
|
||||
# Ensure there is at least one time signature change in the MIDI file.
|
||||
# if len(midi_obj.time_signature_changes) == 0:
|
||||
# raise ValueError('[x] No time_signature_changes')
|
||||
|
||||
# Ensure there are no duplicated time signature changes.
|
||||
# time_list = [ts.time for ts in midi_obj.time_signature_changes]
|
||||
# if len(time_list) != len(set(time_list)):
|
||||
# raise ValueError('[x] Duplicated time_signature_changes')
|
||||
|
||||
# If the dataset is 'LakhClean' or 'SymphonyMIDI', verify there are at least 4 tracks.
|
||||
# if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
|
||||
# if len(midi_obj.instruments) < 4:
|
||||
# raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')
|
||||
|
||||
# Calculate the resolution of ticks per beat as a fraction.
|
||||
in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)
|
||||
|
||||
# Extract the initial time signature (numerator and denominator) and calculate the number of ticks for the first bar.
|
||||
if len(midi_obj.time_signature_changes) != 0:
|
||||
initial_numerator = midi_obj.time_signature_changes[0].numerator
|
||||
initial_denominator = midi_obj.time_signature_changes[0].denominator
|
||||
else:
|
||||
# If no time signature changes, set default values
|
||||
initial_numerator = 4
|
||||
initial_denominator = 4
|
||||
first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))
|
||||
|
||||
# --- load notes --- #
|
||||
instr_notes = self._make_instr_notes(midi_obj)
|
||||
# --- load information --- #
|
||||
# load chords, labels
|
||||
chords = split_markers(midi_obj.markers)
|
||||
chords.sort(key=lambda x: x.time)
|
||||
|
||||
|
||||
# load tempos
|
||||
tempos = midi_obj.tempo_changes if len(midi_obj.tempo_changes) > 0 else []
|
||||
if len(tempos) == 0:
|
||||
# if no tempo changes, set the default tempo to 120 BPM
|
||||
tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
|
||||
tempos.sort(key=lambda x: x.time)
|
||||
|
||||
# --- process items to grid --- #
|
||||
# compute empty bar offset at head
|
||||
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
|
||||
last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
|
||||
|
||||
quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
offset = quant_time_first // first_bar_resol # empty bar
|
||||
offset_by_resol = offset * first_bar_resol
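# Worked example (hedged, illustrative numbers): with ticks_per_beat = 480 and in_beat_resolution = 4,
# in_beat_tick_resol = Fraction(480, 4) = 120 ticks. A first note at tick 1990 quantizes to
# round(1990 / 120) * 120 = 2040; with a 4/4 first bar (first_bar_resol = 1920) the offset is
# 2040 // 1920 = 1 empty bar, so offset_by_resol = 1920 ticks are trimmed from the head of the song.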
|
||||
# --- process notes --- #
|
||||
instr_grid = dict()
|
||||
for key in instr_notes.keys():
|
||||
notes = instr_notes[key]
|
||||
note_grid = defaultdict(list)
|
||||
for note in notes:
|
||||
# skip notes out of range, below C-1 and above C8
|
||||
if note.pitch < 12 or note.pitch >= 120:
|
||||
continue
|
||||
|
||||
# in case when the first note starts at slightly before the first bar
|
||||
note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
|
||||
note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0
|
||||
|
||||
# relative duration
|
||||
# skip note with 0 duration
|
||||
note_duration = note.end - note.start
|
||||
relative_duration = round(note_duration / in_beat_tick_resol)
|
||||
if relative_duration == 0:
|
||||
continue
|
||||
if relative_duration > self.in_beat_resolution * 8: # 8 beats
|
||||
relative_duration = self.in_beat_resolution * 8
|
||||
|
||||
# use regular duration bins
|
||||
note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins-relative_duration))]
|
||||
|
||||
# quantize start time
|
||||
quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
|
||||
# velocity
|
||||
note.velocity = self.regular_velocity_bins[
|
||||
np.argmin(abs(self.regular_velocity_bins-note.velocity))]
|
||||
|
||||
# append
|
||||
note_grid[quant_time].append(note)
|
||||
|
||||
# set to track
|
||||
instr_grid[key] = note_grid
|
||||
|
||||
# --- pruning grouped notes --- #
|
||||
self._pruning_grouped_notes_from_quantization(instr_grid)
|
||||
|
||||
# --- process chords --- #
|
||||
chord_grid = defaultdict(list)
|
||||
for chord in chords:
|
||||
# quantize
|
||||
chord.time = chord.time - offset_by_resol
|
||||
chord.time = 0 if chord.time < 0 else chord.time
|
||||
quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
chord_grid[quant_time].append(chord)
|
||||
|
||||
# --- process tempos --- #
|
||||
|
||||
first_notes_list = []
|
||||
for instr in instr_grid.keys():
|
||||
time_list = sorted(list(instr_grid[instr].keys()))
|
||||
if len(time_list) == 0:  # skip empty tracks
|
||||
continue
|
||||
first_quant_time = time_list[0]
|
||||
first_notes_list.append(first_quant_time)
|
||||
|
||||
# handle the case where every track is empty
|
||||
if not first_notes_list:
|
||||
raise ValueError("[x] No valid notes found in any instrument track.")
|
||||
quant_first_note_time = min(first_notes_list)
|
||||
tempo_grid = defaultdict(list)
|
||||
for tempo in tempos:
|
||||
# quantize
|
||||
tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
|
||||
quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
tempo.tempo = self.regular_tempo_bins[
|
||||
np.argmin(abs(self.regular_tempo_bins-tempo.tempo))]
|
||||
if quant_time < quant_first_note_time:
|
||||
tempo_grid[quant_first_note_time].append(tempo)
|
||||
else:
|
||||
tempo_grid[quant_time].append(tempo)
|
||||
if len(tempo_grid[quant_first_note_time]) > 1:
|
||||
tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]
|
||||
# --- process time signature --- #
|
||||
quant_time_signature = deepcopy(midi_obj.time_signature_changes)
|
||||
quant_time_signature.sort(key=lambda x: x.time)
|
||||
for ts in quant_time_signature:
|
||||
ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
|
||||
ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
|
||||
# --- make new midi object to check processed values --- #
|
||||
new_midi_obj = miditoolkit.midi.parser.MidiFile()
|
||||
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
|
||||
new_midi_obj.max_tick = midi_obj.max_tick
|
||||
for instr_idx in instr_grid.keys():
|
||||
new_instrument = Instrument(program=instr_idx)
|
||||
new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
|
||||
new_midi_obj.instruments.append(new_instrument)
|
||||
new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
|
||||
new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
|
||||
new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
|
||||
|
||||
# make corpus
|
||||
song_data = {
|
||||
'notes': instr_grid,
|
||||
'chords': chord_grid,
|
||||
'tempos': tempo_grid,
|
||||
'metadata': {
|
||||
'first_note': first_note_time,
|
||||
'last_note': last_note_time,
|
||||
'time_signature': quant_time_signature,
|
||||
'ticks_per_beat': midi_obj.ticks_per_beat,
|
||||
}
|
||||
}
|
||||
return song_data, new_midi_obj
|
||||
|
||||
def _make_instr_notes(self, midi_obj):
|
||||
'''
|
||||
This part is important; there are three different ways to merge instruments:
1st option: compare the number of notes and keep only the tracks with more notes
2nd option: merge all tracks that map to the same instrument program
3rd option: leave all tracks as they are and differentiate them by track number

In this version we use the 2nd option, since it reduces both the number of tracks and the sequence length.
|
||||
'''
|
||||
instr_notes = defaultdict(list)
|
||||
for instr in midi_obj.instruments:
|
||||
instr_idx = instr.program
|
||||
# change instrument idx
|
||||
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
|
||||
if instr_name is None:
|
||||
continue
|
||||
new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
|
||||
instr_notes[new_instr_idx].extend(instr.notes)
|
||||
instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
|
||||
return instr_notes
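# Illustration (hedged, assuming PROGRAM_INSTRUMENT_MAP maps both program 0 and program 1 to the same
# instrument name): two piano tracks with programs 0 and 1 end up under a single key of instr_notes,
# with their notes interleaved and sorted by (start, -pitch).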
|
||||
|
||||
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
|
||||
def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
merge drum track to one track
|
||||
'''
|
||||
drum_0_lst = []
|
||||
new_instruments = []
|
||||
for instrument in midi_obj.instruments:
|
||||
if len(instrument.notes) == 0:
|
||||
continue
|
||||
if instrument.is_drum:
|
||||
drum_0_lst.extend(instrument.notes)
|
||||
else:
|
||||
new_instruments.append(instrument)
|
||||
if len(drum_0_lst) > 0:
|
||||
drum_0_lst.sort(key=lambda x: x.start)
|
||||
# remove duplicate
|
||||
drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
|
||||
drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
|
||||
drum_0_instrument.notes = drum_0_lst
|
||||
new_instruments.append(drum_0_instrument)
|
||||
midi_obj.instruments = new_instruments
|
||||
|
||||
# referred to mmt "https://github.com/salu133445/mmt"
|
||||
def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
merge instrument numbers within the same instrument category,
e.g. 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, 2: Electric Grand Piano are all mapped to 0: Acoustic Grand Piano
|
||||
'''
|
||||
new_instruments = []
|
||||
for instr in midi_obj.instruments:
|
||||
instr_idx = instr.program
|
||||
# change instrument idx
|
||||
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
|
||||
if instr_name != None:
|
||||
new_instruments.append(instr)
|
||||
midi_obj.instruments = new_instruments
|
||||
|
||||
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
|
||||
def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
|
||||
'''
|
||||
merge the tracks with the fewest notes into other tracks that share the same program,
and limit the maximum number of tracks to 16
|
||||
'''
|
||||
if len(midi_obj.instruments) == 1:
|
||||
if midi_obj.instruments[0].is_drum:
|
||||
midi_obj.instruments[0].program = 114
|
||||
midi_obj.instruments[0].is_drum = False
|
||||
return midi_obj
|
||||
good_instruments = midi_obj.instruments
|
||||
good_instruments.sort(
|
||||
key=lambda x: (not x.is_drum, -len(x.notes)))  # place the drum track, or otherwise the track with the most notes, first
|
||||
assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
|
||||
good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
|
||||
# assert good_instruments[0].is_drum == False, (, len(good_instruments[2]))
|
||||
track_idx_lst = list(range(len(good_instruments)))
|
||||
if len(good_instruments) > MAX_TRACK:
|
||||
new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
|
||||
# print(midi_file_path)
|
||||
for id in track_idx_lst[MAX_TRACK:]:
|
||||
cur_ins = good_instruments[id]
|
||||
merged = False
|
||||
new_good_instruments.sort(key=lambda x: len(x.notes))
|
||||
for nid, ins in enumerate(new_good_instruments):
|
||||
if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
|
||||
new_good_instruments[nid].notes.extend(cur_ins.notes)
|
||||
merged = True
|
||||
break
|
||||
if not merged:
    # no remaining track shares this program, so the extra track's notes are dropped
    pass
|
||||
good_instruments = new_good_instruments
|
||||
|
||||
assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
|
||||
for idx, good_instrument in enumerate(good_instruments):
|
||||
if good_instrument.is_drum:
|
||||
good_instruments[idx].program = 114
|
||||
good_instruments[idx].is_drum = False
|
||||
midi_obj.instruments = good_instruments
|
||||
|
||||
def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
extract notes for chord extraction
|
||||
'''
|
||||
new_midi_obj = miditoolkit.midi.parser.MidiFile()
|
||||
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
|
||||
new_midi_obj.max_tick = midi_obj.max_tick
|
||||
new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
|
||||
new_instruments = []
|
||||
new_notes = []
|
||||
for instrument in midi_obj.instruments:
|
||||
if instrument.program == 114 or instrument.is_drum: # pass drum track
|
||||
continue
|
||||
valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
|
||||
new_notes.extend(valid_notes)
|
||||
new_notes.sort(key=lambda x: x.start)
|
||||
new_instrument.notes = new_notes
|
||||
new_instruments.append(new_instrument)
|
||||
new_midi_obj.instruments = new_instruments
|
||||
return new_midi_obj
|
||||
|
||||
def get_argument_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
required=True,
|
||||
# choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
|
||||
type=str,
|
||||
help="dataset names",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--num_features",
|
||||
required=True,
|
||||
choices=(4, 5, 7, 8),
|
||||
type=int,
|
||||
help="number of features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--in_dir",
|
||||
default="../dataset/",
|
||||
type=Path,
|
||||
help="input data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--out_dir",
|
||||
default="../dataset/represented_data/corpus/",
|
||||
type=Path,
|
||||
help="output data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="enable debug mode",
|
||||
)
|
||||
return parser
|
||||
|
||||
def main():
|
||||
parser = get_argument_parser()
|
||||
args = parser.parse_args()
|
||||
corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
|
||||
corpus_maker.make_corpus()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# python3 step1_midi2corpus.py --dataset SOD --num_features 5
|
||||
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
|
||||
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
|
||||
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb
|
||||
654
data_representation/step1_midi2corpus_fined.py
Normal file
@ -0,0 +1,654 @@
|
||||
import argparse
|
||||
import time
|
||||
import itertools
|
||||
import copy
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from collections import defaultdict
|
||||
from fractions import Fraction
|
||||
from typing import List
|
||||
import os
|
||||
from muspy import sort
|
||||
import numpy as np
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
import miditoolkit
|
||||
from miditoolkit.midi.containers import Marker, Instrument
|
||||
from chorder import Dechorder
|
||||
|
||||
from constants import NUM2PITCH, FINED_PROGRAM_INSTRUMENT_MAP, INSTRUMENT_PROGRAM_MAP, PROGRAM_INSTRUMENT_MAP
|
||||
|
||||
'''
|
||||
This script is designed to preprocess MIDI files and convert them into a structured corpus suitable for symbolic music analysis or model training.
|
||||
It handles various tasks, including setting beat resolution, calculating duration, velocity, and tempo bins, and processing MIDI data into quantized musical events.
|
||||
We don't do instrument category merging here.
|
||||
'''
|
||||
|
||||
def get_tempo_bin(max_tempo:int, ratio:float=1.1):
|
||||
bpm = 30
|
||||
regular_tempo_bins = [bpm]
|
||||
while bpm < max_tempo:
|
||||
bpm *= ratio
|
||||
bpm = round(bpm)
|
||||
if bpm > max_tempo:
|
||||
break
|
||||
regular_tempo_bins.append(bpm)
|
||||
return np.array(regular_tempo_bins)
|
||||
|
||||
def split_markers(markers:List[miditoolkit.midi.containers.Marker]):
|
||||
'''
|
||||
keep only chord markers, filtering out 'global' and 'Boundary' markers
|
||||
'''
|
||||
chords = []
|
||||
for marker in markers:
|
||||
splitted_text = marker.text.split('_')
|
||||
if splitted_text[0] != 'global' and 'Boundary' not in splitted_text[0]:
|
||||
chords.append(marker)
|
||||
return chords
|
||||
|
||||
class CorpusMaker():
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name:str,
|
||||
num_features:int,
|
||||
in_dir:Path,
|
||||
out_dir:Path,
|
||||
debug:bool
|
||||
):
|
||||
'''
|
||||
Initialize the CorpusMaker with dataset information and directory paths.
|
||||
It sets up MIDI paths, output directories, and debug mode, then
|
||||
retrieves the beat resolution, duration bins, velocity/tempo bins, and prepares the MIDI file list.
|
||||
'''
|
||||
self.dataset_name = dataset_name
|
||||
self.num_features = num_features
|
||||
self.midi_path = in_dir / f"{dataset_name}"
|
||||
self.out_dir = out_dir
|
||||
self.debug = debug
|
||||
self._get_in_beat_resolution()
|
||||
self._get_duration_bins()
|
||||
self._get_velocity_tempo_bins()
|
||||
self._get_min_max_last_time()
|
||||
self._prepare_midi_list()
|
||||
|
||||
def _get_in_beat_resolution(self):
|
||||
# Retrieve the per-quarter-note resolution for the dataset (e.g., 4 means the minimum resolution is a 16th note)
|
||||
in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
|
||||
try:
|
||||
self.in_beat_resolution = in_beat_resolution_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
|
||||
|
||||
def _get_duration_bins(self):
|
||||
# Set up regular duration bins for quantizing note lengths, based on the beat resolution.
|
||||
base_duration = {4:[1,2,3,4,5,6,8,10,12,16,20,24,28,32],
|
||||
8:[1,2,3,4,6,8,10,12,14,16,20,24,28,32,36,40,48,56,64],
|
||||
12:[1,2,3,4,6,9,12,15,18,24,30,36,42,48,54,60,72,84,96]}
|
||||
base_duration_list = base_duration[self.in_beat_resolution]
|
||||
self.regular_duration_bins = np.array(base_duration_list)
|
||||
|
||||
def _get_velocity_tempo_bins(self):
|
||||
# Define velocity and tempo bins based on whether the dataset is a performance or score type.
|
||||
midi_type_dict = {'BachChorale': 'score', 'Pop1k7': 'perform', 'Pop909': 'score', 'SOD': 'score', 'LakhClean': 'score', 'SymphonyMIDI': 'score'}
|
||||
try:
|
||||
midi_type = midi_type_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
midi_type = midi_type_dict['LakhClean']
|
||||
# For performance-type datasets, set finer granularity of velocity and tempo bins.
|
||||
if midi_type == 'perform':
|
||||
self.regular_velocity_bins = np.array(list(range(40, 128, 8)) + [127])
|
||||
self.regular_tempo_bins = get_tempo_bin(max_tempo=240, ratio=1.04)
|
||||
# For score-type datasets, use coarser velocity and tempo bins.
|
||||
elif midi_type == 'score':
|
||||
self.regular_velocity_bins = np.array([40, 60, 80, 100, 120])
|
||||
self.regular_tempo_bins = get_tempo_bin(max_tempo=390, ratio=1.04)
|
||||
|
||||
def _get_min_max_last_time(self):
|
||||
'''
|
||||
Set the minimum and maximum allowed length (in seconds) of a MIDI file, depending on the dataset.
A range of (0, 2000) effectively means no limit.
|
||||
'''
|
||||
# last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (60, 600), 'Symphony': (60, 1500)}
|
||||
last_time_dict = {'BachChorale': (0, 2000), 'Pop1k7': (0, 2000), 'Pop909': (0, 2000), 'SOD': (60, 1000), 'LakhClean': (0, 2000), 'SymphonyMIDI': (60, 1500)}
|
||||
try:
|
||||
self.min_last_time, self.max_last_time = last_time_dict[self.dataset_name]
|
||||
except KeyError:
|
||||
print(f"Dataset {self.dataset_name} is not supported. use the setting of LakhClean")
|
||||
self.min_last_time, self.max_last_time = last_time_dict['LakhClean']
|
||||
|
||||
def _prepare_midi_list(self):
|
||||
midi_path = Path(self.midi_path)
|
||||
# detect subdirectories and get all midi files
|
||||
if not midi_path.exists():
|
||||
raise ValueError(f"midi_path {midi_path} does not exist")
|
||||
# go through all subdirectories and collect all MIDI files
|
||||
midi_files = []
|
||||
for root, _, files in os.walk(midi_path):
|
||||
for file in files:
|
||||
if file.endswith(('.mid', '.midi', '.MID')):
|
||||
# print(Path(root) / file)
|
||||
midi_files.append(Path(root) / file)
|
||||
self.midi_list = midi_files
|
||||
print(f"Found {len(self.midi_list)} MIDI files in {midi_path}")
|
||||
|
||||
def make_corpus(self) -> None:
|
||||
'''
|
||||
Main method to process the MIDI files and create the corpus data.
|
||||
It supports both single-processing (debug mode) and multi-processing for large datasets.
|
||||
'''
|
||||
print("preprocessing midi data to corpus data")
|
||||
# check whether the corpus folders already exist and create them if not
|
||||
Path(self.out_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(self.out_dir / f"corpus_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
|
||||
Path(self.out_dir / f"midi_{self.dataset_name}").mkdir(parents=True, exist_ok=True)
|
||||
start_time = time.time()
|
||||
if self.debug:
|
||||
# single processing for debugging
|
||||
broken_counter = 0
|
||||
success_counter = 0
|
||||
for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
|
||||
message = self._mp_midi2corpus(file_path)
|
||||
if message == "error":
|
||||
broken_counter += 1
|
||||
elif message == "success":
|
||||
success_counter += 1
|
||||
else:
|
||||
# Multi-process processing for faster corpus generation.
|
||||
broken_counter = 0
|
||||
success_counter = 0
|
||||
# filter out processed files
|
||||
print(self.out_dir)
|
||||
processed_files = list(Path(self.out_dir).glob(f"midi_{self.dataset_name}/*.mid"))
|
||||
processed_files = [x.name for x in processed_files]
|
||||
print(f"processed files: {len(processed_files)}")
|
||||
print("length of midi list: ", len(self.midi_list))
|
||||
# Use set for faster lookup (O(1) per check)
|
||||
processed_files_set = set(processed_files)
|
||||
# self.midi_list = [x for x in self.midi_list if x.name not in processed_files_set]
|
||||
# reverse the list to process the latest files first
|
||||
self.midi_list.reverse()
|
||||
print(f"length of midi list after filtering: ", len(self.midi_list))
|
||||
with Pool(16) as p:
|
||||
for message in tqdm(p.imap(self._mp_midi2corpus, self.midi_list, 500), total=len(self.midi_list)):
|
||||
if message == "error":
|
||||
broken_counter += 1
|
||||
elif message == "success":
|
||||
success_counter += 1
|
||||
# for file_path in tqdm(self.midi_list, total=len(self.midi_list)):
|
||||
# message = self._mp_midi2corpus(file_path)
|
||||
# if message == "error":
|
||||
# broken_counter += 1
|
||||
# elif message == "success":
|
||||
# success_counter += 1
|
||||
print(f"Making corpus takes: {time.time() - start_time}s, success: {success_counter}, broken: {broken_counter}")
|
||||
|
||||
def _mp_midi2corpus(self, file_path: Path):
|
||||
"""Convert MIDI to corpus format and save both corpus (.pkl) and MIDI (.mid)."""
|
||||
try:
|
||||
midi_obj = self._analyze(file_path)
|
||||
corpus, midi_obj = self._midi2corpus(midi_obj)
|
||||
# --- 1. Save corpus (.pkl) ---
|
||||
relative_path = file_path.relative_to(self.midi_path) # Get relative path from input dir
|
||||
safe_name = str(relative_path).replace("/", "_").replace("\\", "_").replace(".mid", ".pkl")
|
||||
save_path = Path(self.out_dir) / f"corpus_{self.dataset_name}" / safe_name
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
|
||||
with save_path.open("wb") as f:
|
||||
pickle.dump(corpus, f)
|
||||
|
||||
# --- 2. Save MIDI (.mid) ---
|
||||
midi_save_dir = Path("../dataset/represented_data/corpus") / f"midi_{self.dataset_name}"
|
||||
midi_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
midi_save_path = midi_save_dir / file_path.name # Keep original MIDI filename
|
||||
midi_obj.dump(midi_save_path)
|
||||
|
||||
del midi_obj, corpus
|
||||
return "success"
|
||||
|
||||
except (OSError, EOFError, ValueError, KeyError, AssertionError) as e:
|
||||
print(f"Error processing {file_path.name}: {e}")
|
||||
return "error"
|
||||
except Exception as e:
|
||||
print(f"Unexpected error in {file_path.name}: {e}")
|
||||
return "error"
|
||||
def _check_length(self, last_time:float):
|
||||
if last_time < self.min_last_time:
|
||||
raise ValueError(f"last time {last_time} is out of range")
|
||||
|
||||
def _analyze(self, midi_path:Path):
|
||||
# Loads and analyzes a MIDI file, performing various checks and extracting chords.
|
||||
midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
|
||||
|
||||
# check length
|
||||
mapping = midi_obj.get_tick_to_time_mapping()
|
||||
last_time = mapping[midi_obj.max_tick]
|
||||
self._check_length(last_time)
|
||||
|
||||
# delete instruments with no notes; iterate over a copy so removal does not skip entries
for ins in list(midi_obj.instruments):
    if len(ins.notes) == 0:
        midi_obj.instruments.remove(ins)
        continue
    # sort the remaining notes in place by start time, then pitch
    ins.notes.sort(key=lambda x: (x.start, x.pitch))
|
||||
|
||||
# merge instruments: percussion merging and track limiting (instrument category pruning is skipped in this fine-grained version)
|
||||
self._merge_percussion(midi_obj)
|
||||
# self._pruning_instrument(midi_obj)
|
||||
self._limit_max_track(midi_obj)
|
||||
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
# in case of 7 or 8 features, we need to extract chords
|
||||
new_midi_obj = self._pruning_notes_for_chord_extraction(midi_obj)
|
||||
chords = Dechorder.dechord(new_midi_obj)
|
||||
markers = []
|
||||
for cidx, chord in enumerate(chords):
|
||||
if chord.is_complete():
|
||||
chord_text = NUM2PITCH[chord.root_pc] + '_' + chord.quality + '_' + NUM2PITCH[chord.bass_pc]
|
||||
else:
|
||||
chord_text = 'N_N_N'
|
||||
markers.append(Marker(time=int(cidx*new_midi_obj.ticks_per_beat), text=chord_text))
|
||||
|
||||
# de-duplication
|
||||
prev_chord = None
|
||||
dedup_chords = []
|
||||
for m in markers:
|
||||
if m.text != prev_chord:
|
||||
prev_chord = m.text
|
||||
dedup_chords.append(m)
|
||||
|
||||
# return midi
|
||||
midi_obj.markers = dedup_chords
|
||||
return midi_obj
|
||||
|
||||
def _pruning_grouped_notes_from_quantization(self, instr_grid:dict):
|
||||
'''
|
||||
When notes end up in the same quant_time but have different start times, unintended chords can be created.
rule1: if two notes are a unison or half step apart, keep only the longer one
rule2: if two notes share less than 80% of the shorter note's duration, keep only the longer one
|
||||
'''
|
||||
for instr in instr_grid.keys():
|
||||
time_list = sorted(list(instr_grid[instr].keys()))
|
||||
for time in time_list:
|
||||
notes = instr_grid[instr][time]
|
||||
if len(notes) == 1:
|
||||
continue
|
||||
else:
|
||||
new_notes = []
|
||||
# sort in pitch with ascending order
|
||||
notes.sort(key=lambda x: x.pitch)
|
||||
for i in range(len(notes)-1):
|
||||
# if start time is same add to new_notes
|
||||
if notes[i].start == notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
new_notes.append(notes[i+1])
|
||||
continue
|
||||
if notes[i].pitch == notes[i+1].pitch or notes[i].pitch + 1 == notes[i+1].pitch:
|
||||
# select longer note
|
||||
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
# check how much duration they share
|
||||
shared_duration = min(notes[i].end, notes[i+1].end) - max(notes[i].start, notes[i+1].start)
|
||||
shorter_duration = min(notes[i].end - notes[i].start, notes[i+1].end - notes[i+1].start)
|
||||
# unless they share more than 80% of duration, select longer note (pruning shorter note)
|
||||
if shared_duration / shorter_duration < 0.8:
|
||||
if notes[i].end - notes[i].start > notes[i+1].end - notes[i+1].start:
|
||||
new_notes.append(notes[i])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
if len(new_notes) == 0:
|
||||
new_notes.append(notes[i])
|
||||
new_notes.append(notes[i+1])
|
||||
else:
|
||||
new_notes.append(notes[i+1])
|
||||
instr_grid[instr][time] = new_notes
|
||||
|
||||
def _midi2corpus(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
# Checks if the ticks per beat in the MIDI file is lower than the expected resolution.
|
||||
# If it is, raise an error.
|
||||
if midi_obj.ticks_per_beat < self.in_beat_resolution:
|
||||
raise ValueError(f'[x] Irregular ticks_per_beat. {midi_obj.ticks_per_beat}')
|
||||
|
||||
# Ensure there is at least one time signature change in the MIDI file.
|
||||
# if len(midi_obj.time_signature_changes) == 0:
|
||||
# raise ValueError('[x] No time_signature_changes')
|
||||
|
||||
# Ensure there are no duplicated time signature changes.
|
||||
# time_list = [ts.time for ts in midi_obj.time_signature_changes]
|
||||
# if len(time_list) != len(set(time_list)):
|
||||
# raise ValueError('[x] Duplicated time_signature_changes')
|
||||
|
||||
# If the dataset is 'LakhClean' or 'SymphonyMIDI', verify there are at least 4 tracks.
|
||||
# if self.dataset_name == 'LakhClean' or self.dataset_name == 'SymphonyMIDI':
|
||||
# if len(midi_obj.instruments) < 4:
|
||||
# raise ValueError('[x] We will use more than 4 tracks in Lakh Clean dataset.')
|
||||
|
||||
# Calculate the resolution of ticks per beat as a fraction.
|
||||
in_beat_tick_resol = Fraction(midi_obj.ticks_per_beat, self.in_beat_resolution)
|
||||
|
||||
# Extract the initial time signature (numerator and denominator) and calculate the number of ticks for the first bar.
|
||||
if len(midi_obj.time_signature_changes) != 0:
|
||||
initial_numerator = midi_obj.time_signature_changes[0].numerator
|
||||
initial_denominator = midi_obj.time_signature_changes[0].denominator
|
||||
else:
|
||||
# If no time signature changes, set default values
|
||||
initial_numerator = 4
|
||||
initial_denominator = 4
|
||||
first_bar_resol = int(midi_obj.ticks_per_beat * initial_numerator * (4 / initial_denominator))
|
||||
|
||||
# --- load notes --- #
|
||||
instr_notes = self._make_instr_notes(midi_obj)
|
||||
# --- load information --- #
|
||||
# load chords, labels
|
||||
chords = split_markers(midi_obj.markers)
|
||||
chords.sort(key=lambda x: x.time)
|
||||
|
||||
|
||||
# load tempos
|
||||
tempos = midi_obj.tempo_changes if len(midi_obj.tempo_changes) > 0 else []
|
||||
if len(tempos) == 0:
|
||||
# if no tempo changes, set the default tempo to 120 BPM
|
||||
tempos = [miditoolkit.midi.containers.TempoChange(time=0, tempo=120)]
|
||||
tempos.sort(key=lambda x: x.time)
|
||||
|
||||
# --- process items to grid --- #
|
||||
# compute empty bar offset at head
|
||||
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
|
||||
last_note_time = max([instr_notes[k][-1].start for k in instr_notes.keys()])
|
||||
|
||||
quant_time_first = int(round(first_note_time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
offset = quant_time_first // first_bar_resol # empty bar
|
||||
offset_by_resol = offset * first_bar_resol
|
||||
# --- process notes --- #
|
||||
instr_grid = dict()
|
||||
for key in instr_notes.keys():
|
||||
notes = instr_notes[key]
|
||||
note_grid = defaultdict(list)
|
||||
for note in notes:
|
||||
# skip notes out of range, below C-1 and above C8
|
||||
if note.pitch < 12 or note.pitch >= 120:
|
||||
continue
|
||||
|
||||
# in case when the first note starts at slightly before the first bar
|
||||
note.start = note.start - offset_by_resol if note.start - offset_by_resol > 0 else 0
|
||||
note.end = note.end - offset_by_resol if note.end - offset_by_resol > 0 else 0
|
||||
|
||||
# relative duration
|
||||
# skip note with 0 duration
|
||||
note_duration = note.end - note.start
|
||||
relative_duration = round(note_duration / in_beat_tick_resol)
|
||||
if relative_duration == 0:
|
||||
continue
|
||||
if relative_duration > self.in_beat_resolution * 8: # 8 beats
|
||||
relative_duration = self.in_beat_resolution * 8
|
||||
|
||||
# use regular duration bins
|
||||
note.quantized_duration = self.regular_duration_bins[np.argmin(abs(self.regular_duration_bins-relative_duration))]
|
||||
|
||||
# quantize start time
|
||||
quant_time = int(round(note.start / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
|
||||
# velocity
|
||||
note.velocity = self.regular_velocity_bins[
|
||||
np.argmin(abs(self.regular_velocity_bins-note.velocity))]
|
||||
|
||||
# append
|
||||
note_grid[quant_time].append(note)
|
||||
|
||||
# set to track
|
||||
instr_grid[key] = note_grid
|
||||
|
||||
# --- pruning grouped notes --- #
|
||||
self._pruning_grouped_notes_from_quantization(instr_grid)
|
||||
|
||||
# --- process chords --- #
|
||||
chord_grid = defaultdict(list)
|
||||
for chord in chords:
|
||||
# quantize
|
||||
chord.time = chord.time - offset_by_resol
|
||||
chord.time = 0 if chord.time < 0 else chord.time
|
||||
quant_time = int(round(chord.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
chord_grid[quant_time].append(chord)
|
||||
|
||||
# --- process tempos --- #
|
||||
|
||||
first_notes_list = []
|
||||
for instr in instr_grid.keys():
|
||||
time_list = sorted(list(instr_grid[instr].keys()))
|
||||
if len(time_list) == 0:  # skip empty tracks
|
||||
continue
|
||||
first_quant_time = time_list[0]
|
||||
first_notes_list.append(first_quant_time)
|
||||
|
||||
# handle the case where every track is empty
|
||||
if not first_notes_list:
|
||||
raise ValueError("[x] No valid notes found in any instrument track.")
|
||||
quant_first_note_time = min(first_notes_list)
|
||||
tempo_grid = defaultdict(list)
|
||||
for tempo in tempos:
|
||||
# quantize
|
||||
tempo.time = tempo.time - offset_by_resol if tempo.time - offset_by_resol > 0 else 0
|
||||
quant_time = int(round(tempo.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
tempo.tempo = self.regular_tempo_bins[
|
||||
np.argmin(abs(self.regular_tempo_bins-tempo.tempo))]
|
||||
if quant_time < quant_first_note_time:
|
||||
tempo_grid[quant_first_note_time].append(tempo)
|
||||
else:
|
||||
tempo_grid[quant_time].append(tempo)
|
||||
if len(tempo_grid[quant_first_note_time]) > 1:
|
||||
tempo_grid[quant_first_note_time] = [tempo_grid[quant_first_note_time][-1]]
|
||||
# --- process time signature --- #
|
||||
quant_time_signature = deepcopy(midi_obj.time_signature_changes)
|
||||
quant_time_signature.sort(key=lambda x: x.time)
|
||||
for ts in quant_time_signature:
|
||||
ts.time = ts.time - offset_by_resol if ts.time - offset_by_resol > 0 else 0
|
||||
ts.time = int(round(ts.time / in_beat_tick_resol)) * in_beat_tick_resol
|
||||
|
||||
# --- make new midi object to check processed values --- #
|
||||
new_midi_obj = miditoolkit.midi.parser.MidiFile()
|
||||
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
|
||||
new_midi_obj.max_tick = midi_obj.max_tick
|
||||
for instr_idx in instr_grid.keys():
|
||||
new_instrument = Instrument(program=instr_idx)
|
||||
new_instrument.notes = [y for x in instr_grid[instr_idx].values() for y in x]
|
||||
new_midi_obj.instruments.append(new_instrument)
|
||||
new_midi_obj.markers = [y for x in chord_grid.values() for y in x]
|
||||
new_midi_obj.tempo_changes = [y for x in tempo_grid.values() for y in x]
|
||||
new_midi_obj.time_signature_changes = midi_obj.time_signature_changes
|
||||
|
||||
# make corpus
|
||||
song_data = {
|
||||
'notes': instr_grid,
|
||||
'chords': chord_grid,
|
||||
'tempos': tempo_grid,
|
||||
'metadata': {
|
||||
'first_note': first_note_time,
|
||||
'last_note': last_note_time,
|
||||
'time_signature': quant_time_signature,
|
||||
'ticks_per_beat': midi_obj.ticks_per_beat,
|
||||
}
|
||||
}
|
||||
return song_data, new_midi_obj
|
||||
|
||||
def _make_instr_notes(self, midi_obj):
|
||||
'''
|
||||
This part is important; there are three different ways to merge instruments:
1st option: compare the number of notes and keep only the tracks with more notes
2nd option: merge all tracks that map to the same instrument program
3rd option: leave all tracks as they are and differentiate them by track number

In this fine-grained version, tracks sharing the same raw program number are merged, but similar programs are not collapsed into one category.
|
||||
'''
|
||||
instr_notes = defaultdict(list)
|
||||
for instr in midi_obj.instruments:
|
||||
instr_idx = instr.program
|
||||
# change instrument idx
|
||||
instr_name = FINED_PROGRAM_INSTRUMENT_MAP.get(instr_idx)
|
||||
if instr_name is None:
|
||||
continue
|
||||
# new_instr_idx = INSTRUMENT_PROGRAM_MAP[instr_name]
|
||||
new_instr_idx = instr_idx
|
||||
if new_instr_idx not in instr_notes:
|
||||
instr_notes[new_instr_idx] = []
|
||||
instr_notes[new_instr_idx].extend(instr.notes)
|
||||
instr_notes[new_instr_idx].sort(key=lambda x: (x.start, -x.pitch))
|
||||
return instr_notes
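# Illustration (hedged): unlike step1_midi2corpus.py, raw program numbers are kept here, so two piano
# tracks with programs 0 and 1 remain separate keys of instr_notes; only tracks that share the exact
# same program number are merged.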
|
||||
|
||||
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
|
||||
def _merge_percussion(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
merge drum track to one track
|
||||
'''
|
||||
drum_0_lst = []
|
||||
new_instruments = []
|
||||
for instrument in midi_obj.instruments:
|
||||
if len(instrument.notes) == 0:
|
||||
continue
|
||||
if instrument.is_drum:
|
||||
drum_0_lst.extend(instrument.notes)
|
||||
else:
|
||||
new_instruments.append(instrument)
|
||||
if len(drum_0_lst) > 0:
|
||||
drum_0_lst.sort(key=lambda x: x.start)
|
||||
# remove duplicate
|
||||
drum_0_lst = list(k for k, _ in itertools.groupby(drum_0_lst))
|
||||
drum_0_instrument = Instrument(program=114, is_drum=True, name="percussion")
|
||||
drum_0_instrument.notes = drum_0_lst
|
||||
new_instruments.append(drum_0_instrument)
|
||||
midi_obj.instruments = new_instruments
|
||||
|
||||
# referred to mmt "https://github.com/salu133445/mmt"
|
||||
def _pruning_instrument(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
merge instrument numbers within the same instrument category,
e.g. 0: Acoustic Grand Piano, 1: Bright Acoustic Piano, 2: Electric Grand Piano are all mapped to 0: Acoustic Grand Piano
|
||||
'''
|
||||
new_instruments = []
|
||||
for instr in midi_obj.instruments:
|
||||
instr_idx = instr.program
|
||||
# change instrument idx
|
||||
instr_name = PROGRAM_INSTRUMENT_MAP.get(instr_idx)
|
||||
if instr_name != None:
|
||||
new_instruments.append(instr)
|
||||
midi_obj.instruments = new_instruments
|
||||
|
||||
# referred to SymphonyNet "https://github.com/symphonynet/SymphonyNet"
|
||||
def _limit_max_track(self, midi_obj:miditoolkit.midi.parser.MidiFile, MAX_TRACK:int=16):
|
||||
'''
|
||||
merge the tracks with the fewest notes into other tracks that share the same program,
and limit the maximum number of tracks to 16
|
||||
'''
|
||||
if len(midi_obj.instruments) == 1:
|
||||
if midi_obj.instruments[0].is_drum:
|
||||
midi_obj.instruments[0].program = 114
|
||||
midi_obj.instruments[0].is_drum = False
|
||||
return midi_obj
|
||||
good_instruments = midi_obj.instruments
|
||||
good_instruments.sort(
|
||||
key=lambda x: (not x.is_drum, -len(x.notes)))  # place the drum track, or otherwise the track with the most notes, first
|
||||
assert good_instruments[0].is_drum == True or len(good_instruments[0].notes) >= len(
|
||||
good_instruments[1].notes), tuple(len(x.notes) for x in good_instruments[:3])
|
||||
# assert good_instruments[0].is_drum == False, (, len(good_instruments[2]))
|
||||
track_idx_lst = list(range(len(good_instruments)))
|
||||
if len(good_instruments) > MAX_TRACK:
|
||||
new_good_instruments = copy.deepcopy(good_instruments[:MAX_TRACK])
|
||||
# print(midi_file_path)
|
||||
for id in track_idx_lst[MAX_TRACK:]:
|
||||
cur_ins = good_instruments[id]
|
||||
merged = False
|
||||
new_good_instruments.sort(key=lambda x: len(x.notes))
|
||||
for nid, ins in enumerate(new_good_instruments):
|
||||
if cur_ins.program == ins.program and cur_ins.is_drum == ins.is_drum:
|
||||
new_good_instruments[nid].notes.extend(cur_ins.notes)
|
||||
merged = True
|
||||
break
|
||||
if not merged:
    # no remaining track shares this program, so the extra track's notes are dropped
    pass
|
||||
good_instruments = new_good_instruments
|
||||
|
||||
assert len(good_instruments) <= MAX_TRACK, len(good_instruments)
|
||||
for idx, good_instrument in enumerate(good_instruments):
|
||||
if good_instrument.is_drum:
|
||||
good_instruments[idx].program = 114
|
||||
good_instruments[idx].is_drum = False
|
||||
midi_obj.instruments = good_instruments
|
||||
|
||||
def _pruning_notes_for_chord_extraction(self, midi_obj:miditoolkit.midi.parser.MidiFile):
|
||||
'''
|
||||
extract notes for chord extraction
|
||||
'''
|
||||
new_midi_obj = miditoolkit.midi.parser.MidiFile()
|
||||
new_midi_obj.ticks_per_beat = midi_obj.ticks_per_beat
|
||||
new_midi_obj.max_tick = midi_obj.max_tick
|
||||
new_instrument = Instrument(program=0, is_drum=False, name="for_chord")
|
||||
new_instruments = []
|
||||
new_notes = []
|
||||
for instrument in midi_obj.instruments:
|
||||
if instrument.program == 114 or instrument.is_drum: # pass drum track
|
||||
continue
|
||||
valid_notes = [note for note in instrument.notes if note.pitch >= 21 and note.pitch <= 108]
|
||||
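            # Note: the filter above keeps MIDI pitches 21-108, i.e. the piano range
            # (A0-C8); notes outside this range are ignored when extracting chords.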
            new_notes.extend(valid_notes)
        new_notes.sort(key=lambda x: x.start)
        new_instrument.notes = new_notes
        new_instruments.append(new_instrument)
        new_midi_obj.instruments = new_instruments
        return new_midi_obj

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset names",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../dataset/represented_data/corpus/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    return parser


def main():
    parser = get_argument_parser()
    args = parser.parse_args()
    corpus_maker = CorpusMaker(args.dataset, args.num_features, args.in_dir, args.out_dir, args.debug)
    corpus_maker.make_corpus()


if __name__ == "__main__":
    main()

# python3 step1_midi2corpus.py --dataset SOD --num_features 5
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb
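# Note: the four steps above are intended to be run in this order; with the default
# arguments each step consumes the previous step's output
# (corpus -> events -> vocab -> tuneidx).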
147
data_representation/step2_corpus2event.py
Normal file
@ -0,0 +1,147 @@
import argparse
import time
from pathlib import Path

import pickle
from tqdm import tqdm
from multiprocessing import Pool

import encoding_utils


'''
This script converts corpus data into event data.
'''


class Corpus2Event():
    def __init__(
        self,
        dataset: str,
        encoding_scheme: str,
        num_features: int,
        in_dir: Path,
        out_dir: Path,
        debug: bool,
        cache: bool,
    ):
        self.dataset = dataset
        self.encoding_name = encoding_scheme + str(num_features)
        self.in_dir = in_dir / f"corpus_{self.dataset}"
        self.out_dir = out_dir / f"events_{self.dataset}" / self.encoding_name
        self.debug = debug
        self.cache = cache
        self.encoding_function = getattr(encoding_utils, f'Corpus2event_{encoding_scheme}')(num_features)
        self._get_in_beat_resolution()

    def _get_in_beat_resolution(self):
        # Retrieve the per-quarter-note resolution for the dataset (e.g., 4 means the finest grid is a 16th note)
        in_beat_resolution_dict = {'BachChorale': 4, 'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4, 'SymphonyMIDI': 8}
        try:
            self.in_beat_resolution = in_beat_resolution_dict[self.dataset]
        except KeyError:
            print(f"Dataset {self.dataset} is not supported; falling back to the LakhClean setting.")
            self.in_beat_resolution = in_beat_resolution_dict['LakhClean']
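    # For illustration (assumed numbers, not tied to any particular file): with
    # in_beat_resolution = 4, one quarter note spans 4 grid steps, so a note starting
    # 1.5 quarter notes into a bar lands on grid index 1.5 * 4 = 6; with
    # in_beat_resolution = 12 the same onset lands on index 18.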
    def make_events(self):
        '''
        Preprocess corpus data to events data.
        The process differs for each encoding scheme;
        please refer to encoding_utils.py for more details.
        '''
        print("preprocessing corpus data to events data")
        # check output directory exists
        self.out_dir.mkdir(parents=True, exist_ok=True)
        start_time = time.time()
        # single-processing
        broken_count = 0
        success_count = 0
        corpus_list = sorted(list(self.in_dir.rglob("*.pkl")))
        if corpus_list == []:
            print(f"No corpus files found in {self.in_dir}. Please check the directory.")
            corpus_list = sorted(list(self.in_dir.glob("*.pkli")))
        # remove the corpus files that are already in the out_dir
        # Use set for faster existence checks
        existing_files = set(f.name for f in self.out_dir.glob("*.pkl"))
        # corpus_list = [corpus for corpus in corpus_list if corpus.name not in existing_files]
        for filepath_name, event in tqdm(map(self._load_single_corpus_and_make_event, corpus_list), total=len(corpus_list)):
            if event is None:
                broken_count += 1
                continue
            # if using cache, check if the event file already exists
            if self.cache and (self.out_dir / filepath_name).exists():
                # print(f"event file {filepath_name} already exists, skipping")
                continue
            with open(self.out_dir / filepath_name, 'wb') as f:
                pickle.dump(event, f)
            success_count += 1
            del event
        print(f"time taken for making events is {time.time()-start_time}s, success: {success_count}, broken: {broken_count}")

    def _load_single_corpus_and_make_event(self, file_path):
        try:
            with open(file_path, 'rb') as f:
                corpus = pickle.load(f)
            event = self.encoding_function(corpus, self.in_beat_resolution)
        except Exception as e:
            print(f"error in encoding {file_path}: {e}")
            event = None
        return file_path.name, event

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset names",
    )
    parser.add_argument(
        "-e",
        "--encoding",
        required=True,
        choices=("remi", "cp", "nb", "remi_pos"),
        type=str,
        help="encoding scheme",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/represented_data/corpus/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../dataset/represented_data/events/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    parser.add_argument(
        "--cache",
        action="store_true",
        help="enable cache mode",
    )
    return parser


def main():
    args = get_argument_parser().parse_args()
    corpus2event = Corpus2Event(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug, args.cache)
    corpus2event.make_events()


if __name__ == "__main__":
    main()
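# Example invocation (taken from the command list at the end of step1_midi2corpus.py);
# add --cache to skip event files that already exist in the output directory:
# python3 step2_corpus2event.py --dataset LakhClean --num_features 5 --encoding nb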
84
data_representation/step3_creating_vocab.py
Normal file
@ -0,0 +1,84 @@
import argparse
from pathlib import Path

import vocab_utils


'''
This script creates the vocabulary file.
'''


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset names",
    )
    parser.add_argument(
        "-e",
        "--encoding",
        required=True,
        choices=("remi", "cp", "nb"),
        type=str,
        help="encoding scheme",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/represented_data/events/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../vocab/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    return parser


def main():
    args = get_argument_parser().parse_args()
    encoding_scheme = args.encoding
    num_features = args.num_features
    dataset = args.dataset

    out_vocab_path = args.out_dir / f"vocab_{dataset}"
    out_vocab_path.mkdir(parents=True, exist_ok=True)
    out_vocab_file_path = out_vocab_path / f"vocab_{dataset}_{encoding_scheme}{num_features}.json"

    events_path = Path(args.in_dir / f"events_{dataset}" / f"{encoding_scheme}{num_features}")
    vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
    selected_vocab_name = vocab_name[encoding_scheme]
    event_data = sorted(list(events_path.rglob("*.pkl")))
    if event_data == []:
        print(f"No event files found in {events_path}. Please check the directory.")
        event_data = sorted(list(events_path.glob("*.pkli")))
    vocab = getattr(vocab_utils, selected_vocab_name)(
        in_vocab_file_path=None,
        event_data=event_data,
        encoding_scheme=encoding_scheme,
        num_features=num_features
    )
    vocab.save_vocab(out_vocab_file_path)
    print(f"Vocab file saved at {out_vocab_file_path}")


if __name__ == "__main__":
    main()
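# Example invocation (taken from the command list at the end of step1_midi2corpus.py):
# python3 step3_creating_vocab.py --dataset SOD --num_features 5 --encoding nb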
127
data_representation/step4_event2tuneidx.py
Normal file
@ -0,0 +1,127 @@
import argparse
import time
from pathlib import Path

import numpy as np
import pickle
from tqdm import tqdm

import vocab_utils


class Event2tuneidx():
    def __init__(
        self,
        dataset: str,
        encoding_scheme: str,
        num_features: int,
        in_dir: Path,
        out_dir: Path,
        debug: bool
    ):
        self.dataset = dataset
        self.encoding_scheme = encoding_scheme
        self.encoding_name = encoding_scheme + str(num_features)
        self.in_dir = in_dir / f"events_{self.dataset}" / self.encoding_name
        self.out_dir = out_dir / f"tuneidx_{self.dataset}" / self.encoding_name
        self.debug = debug

        vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
        selected_vocab_name = vocab_name[encoding_scheme]
        in_vocab_file_path = Path(f"../vocab/vocab_{dataset}/vocab_{dataset}_{encoding_scheme}{num_features}.json")
        self.vocab = getattr(vocab_utils, selected_vocab_name)(in_vocab_file_path=in_vocab_file_path, event_data=None,
                                                               encoding_scheme=encoding_scheme, num_features=num_features)

    def _convert_event_to_tune_in_idx(self, tune_in_event):
        tune_in_idx = []
        for event in tune_in_event:
            event_in_idx = self.vocab(event)
            if event_in_idx is not None:
                tune_in_idx.append(event_in_idx)
        return tune_in_idx

    def _load_single_event_and_make_tune_in_idx(self, file_path):
        with open(file_path, 'rb') as f:
            tune_in_event = pickle.load(f)
        tune_in_idx = self._convert_event_to_tune_in_idx(tune_in_event)
        return file_path.name, tune_in_idx

    def make_tune_in_idx(self):
        print("preprocessing events data to tune_in_idx data")
        # check output directory exists
        self.out_dir.mkdir(parents=True, exist_ok=True)
        start_time = time.time()
        event_list = sorted(list(self.in_dir.rglob("*.pkl")))
        if event_list == []:
            event_list = sorted(list(self.in_dir.glob("*.pkli")))
        for filepath_name, tune_in_idx in tqdm(map(self._load_single_event_and_make_tune_in_idx, event_list), total=len(event_list)):
            # save tune_in_idx as an npz file; remi uses int16 because its vocabulary has more than 256 tokens,
            # while the other schemes are downcast to uint8 when every index fits below 256
            if self.encoding_scheme == 'remi':
                tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
            else:
                tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
                if np.max(tune_in_idx) < 256:
                    tune_in_idx = np.array(tune_in_idx, dtype=np.uint8)
            if filepath_name.endswith('.pkli'):
                file_name = filepath_name.replace('.pkli', '.npz')
            else:
                file_name = filepath_name.replace('.pkl', '.npz')
            np.savez_compressed(self.out_dir / file_name, tune_in_idx)
            del tune_in_idx
        print(f"time taken for making tune_in_idx is {time.time()-start_time}")


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset names",
    )
    parser.add_argument(
        "-e",
        "--encoding",
        required=True,
        choices=("remi", "cp", "nb"),
        type=str,
        help="encoding scheme",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/represented_data/events/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../dataset/represented_data/tuneidx/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    return parser


def main():
    parser = get_argument_parser()
    args = parser.parse_args()

    event2tuneidx = Event2tuneidx(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug)
    event2tuneidx.make_tune_in_idx()


if __name__ == "__main__":
    main()
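# Example invocation (taken from the command list at the end of step1_midi2corpus.py):
# python3 step4_event2tuneidx.py --dataset SOD --num_features 5 --encoding nb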
122
data_representation/step4_event2tuneidx_addprompt.py
Normal file
@ -0,0 +1,122 @@
import argparse
import time
from pathlib import Path

import numpy as np
import pickle
from tqdm import tqdm

import vocab_utils


class Event2tuneidx():
    def __init__(
        self,
        dataset: str,
        encoding_scheme: str,
        num_features: int,
        in_dir: Path,
        out_dir: Path,
        debug: bool
    ):
        self.dataset = dataset
        self.encoding_scheme = encoding_scheme
        self.encoding_name = encoding_scheme + str(num_features)
        self.in_dir = in_dir / f"events_{self.dataset}" / self.encoding_name
        self.out_dir = out_dir / f"tuneidx_{self.dataset}" / self.encoding_name
        self.debug = debug

        vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
        selected_vocab_name = vocab_name[encoding_scheme]
        in_vocab_file_path = Path(f"../vocab/vocab_{dataset}/vocab_{dataset}_{encoding_scheme}{num_features}.json")
        self.vocab = getattr(vocab_utils, selected_vocab_name)(in_vocab_file_path=in_vocab_file_path, event_data=None,
                                                               encoding_scheme=encoding_scheme, num_features=num_features)

    def _convert_event_to_tune_in_idx(self, tune_in_event):
        tune_in_idx = []
        for event in tune_in_event:
            event_in_idx = self.vocab(event)
            if event_in_idx is not None:
                tune_in_idx.append(event_in_idx)
        return tune_in_idx

    def _load_single_event_and_make_tune_in_idx(self, file_path):
        with open(file_path, 'rb') as f:
            tune_in_event = pickle.load(f)
        tune_in_idx = self._convert_event_to_tune_in_idx(tune_in_event)
        return file_path.name, tune_in_idx

    def make_tune_in_idx(self):
        print("preprocessing events data to tune_in_idx data")
        # check output directory exists
        self.out_dir.mkdir(parents=True, exist_ok=True)
        start_time = time.time()
        event_list = sorted(list(self.in_dir.rglob("*.pkl")))
        for filepath_name, tune_in_idx in tqdm(map(self._load_single_event_and_make_tune_in_idx, event_list), total=len(event_list)):
            # save tune_in_idx as an npz file; remi uses int16 because its vocabulary has more than 256 tokens,
            # while the other schemes are downcast to uint8 when every index fits below 256
            if self.encoding_scheme == 'remi':
                tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
            else:
                tune_in_idx = np.array(tune_in_idx, dtype=np.int16)
                if np.max(tune_in_idx) < 256:
                    tune_in_idx = np.array(tune_in_idx, dtype=np.uint8)
            file_name = filepath_name.replace('.pkl', '.npz')
            np.savez_compressed(self.out_dir / file_name, tune_in_idx)
            del tune_in_idx
        print(f"time taken for making tune_in_idx is {time.time()-start_time}")


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dataset",
        required=True,
        # choices=("BachChorale", "Pop1k7", "Pop909", "SOD", "LakhClean", "SymphonyMIDI"),
        type=str,
        help="dataset names",
    )
    parser.add_argument(
        "-e",
        "--encoding",
        required=True,
        choices=("remi", "cp", "nb"),
        type=str,
        help="encoding scheme",
    )
    parser.add_argument(
        "-f",
        "--num_features",
        required=True,
        choices=(4, 5, 7, 8),
        type=int,
        help="number of features",
    )
    parser.add_argument(
        "-i",
        "--in_dir",
        default="../dataset/represented_data/events/",
        type=Path,
        help="input data directory",
    )
    parser.add_argument(
        "-o",
        "--out_dir",
        default="../dataset/represented_data/tuneidx_withcaption/",
        type=Path,
        help="output data directory",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="enable debug mode",
    )
    return parser


def main():
    parser = get_argument_parser()
    args = parser.parse_args()

    event2tuneidx = Event2tuneidx(args.dataset, args.encoding, args.num_features, args.in_dir, args.out_dir, args.debug)
    event2tuneidx.make_tune_in_idx()


if __name__ == "__main__":
    main()
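# This variant mirrors step4_event2tuneidx.py but writes its output to
# ../dataset/represented_data/tuneidx_withcaption/ by default. Hypothetical
# invocation, following the same pattern as the other step scripts:
# python3 step4_event2tuneidx_addprompt.py --dataset SOD --num_features 5 --encoding nb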
395
data_representation/vocab_utils.py
Normal file
@ -0,0 +1,395 @@
import pickle
from pathlib import Path
from typing import Union
from multiprocessing import Pool, cpu_count
from collections import defaultdict
from fractions import Fraction

import torch

import json
from tqdm import tqdm


def sort_key(s):
    fraction_part = s.split('_')[-1]
    numerator, denominator = map(int, fraction_part.split('/'))
    # Return a tuple with the denominator first, then the numerator, both negated for descending order
    return (-denominator, -numerator)

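# For example (hypothetical token names, for illustration only):
# sorted(['Beat_1/4', 'Beat_3/8', 'Beat_1/2'], key=sort_key)
# returns ['Beat_3/8', 'Beat_1/4', 'Beat_1/2']: larger denominators come first,
# then larger numerators within the same denominator.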
class LangTokenVocab:
    def __init__(
        self,
        in_vocab_file_path: Union[Path, None],
        event_data: list,
        encoding_scheme: str,
        num_features: int
    ):
        '''
        Initializes the LangTokenVocab class.

        Args:
            in_vocab_file_path (Union[Path, None]): Path to the pre-made vocabulary file (optional).
            event_data (list): List of event data to create a vocabulary if no pre-made vocab is provided.
            encoding_scheme (str): Encoding scheme to be used (e.g., 'remi', 'cp', 'nb').
            num_features (int): Number of features to be used (e.g., 4, 5, 7, 8).

        Summary:
            This class is responsible for handling vocabularies used in language models, especially for REMI encoding.
            It supports multiple encoding schemes, creates vocabularies based on event data, handles special tokens
            (e.g., start/end of sequence), and manages feature-specific masks. It provides methods for saving, loading,
            and decoding vocabularies. It also supports vocabulary augmentation for pitch, instrument, beat, and chord
            features, ensuring that these are arranged and ordered appropriately.

            For all encoding schemes, the metric and special tokens are named 'type',
            so that we can easily handle and compare different encoding schemes.
        '''
        self.encoding_scheme = encoding_scheme
        self.num_features = num_features
        self._prepare_in_vocab(in_vocab_file_path, event_data)  # Prepares initial vocab based on the input file or event data
        self._get_features()  # Extracts relevant features based on num_features
        self.idx2event, self.event2idx = self._get_vocab(event_data, unique_vocabs=self.idx2event)  # Creates vocab or loads premade vocab
        if self.encoding_scheme == 'remi':
            self._make_mask()  # Generates masks for the 'remi' encoding scheme
        self._get_sos_eos_token()  # Retrieves special tokens (Start of Sequence, End of Sequence)

    # Prepares vocabulary if a pre-made vocab file exists or handles cases with no input file.
    def _prepare_in_vocab(self, in_vocab_file_path, event_data):
        if in_vocab_file_path is not None:
            with open(in_vocab_file_path, 'r') as f:
                idx2event_temp = json.load(f)
            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
                for key in idx2event_temp.keys():
                    idx2event_temp[key] = {int(idx): tok for idx, tok in idx2event_temp[key].items()}
            elif self.encoding_scheme == 'remi':
                idx2event_temp = {int(idx): tok for idx, tok in idx2event_temp.items()}
            self.idx2event = idx2event_temp
        elif in_vocab_file_path is None and event_data is None:
            raise NotImplementedError('either premade vocab or event_data should be given')
        else:
            self.idx2event = None

    # Extracts features depending on the number of features chosen (4, 5, 7, 8).
    def _get_features(self):
        feature_args = {
            4: ["type", "beat", "pitch", "duration"],
            5: ["type", "beat", "instrument", "pitch", "duration"],
            7: ["type", "beat", "chord", "tempo", "pitch", "duration", "velocity"],
            8: ["type", "beat", "chord", "tempo", "instrument", "pitch", "duration", "velocity"]}
        self.feature_list = feature_args[self.num_features]

    # Saves the current vocabulary to a specified JSON path.
    def save_vocab(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.idx2event, f, indent=2, ensure_ascii=False)

    # Returns the size of the current vocabulary.
    def get_vocab_size(self):
        return len(self.idx2event)

    # Handles Start of Sequence (SOS) and End of Sequence (EOS) tokens based on the encoding scheme.
    def _get_sos_eos_token(self):
        if self.encoding_scheme == 'remi':
            self.sos_token = [self.event2idx['SOS_None']]
            self.eos_token = [[self.event2idx['EOS_None']]]
        else:
            self.sos_token = [[self.event2idx['type']['SOS']] + [0] * (self.num_features - 1)]
            self.eos_token = [[self.event2idx['type']['EOS']] + [0] * (self.num_features - 1)]

    # Generates vocabularies by either loading from a file or creating them based on the event data.
    def _get_vocab(self, event_data, unique_vocabs=None):
        # make a new vocab from the given event_data
        if event_data is not None:
            unique_char_list = list(set([f'{event["name"]}_{event["value"]}' for tune_path in event_data for event in pickle.load(open(tune_path, 'rb'))]))
            unique_vocabs = sorted(unique_char_list)
            unique_vocabs.remove('SOS_None')
            unique_vocabs.remove('EOS_None')
            unique_vocabs.remove('Bar_None')
            new_unique_vocab = self._augment_pitch_vocab(unique_vocabs)
            if self.num_features == 5 or self.num_features == 8:
                new_unique_vocab = self._arange_instrument_vocab(new_unique_vocab)
            if self.num_features == 7 or self.num_features == 8:
                new_unique_vocab = self._arange_chord_vocab(new_unique_vocab)
            new_unique_vocab = self._arange_beat_vocab(new_unique_vocab)
            new_unique_vocab.insert(0, 'SOS_None')
            new_unique_vocab.insert(1, 'EOS_None')
            new_unique_vocab.insert(2, 'Bar_None')
            idx2event = {int(idx): tok for idx, tok in enumerate(new_unique_vocab)}
            event2idx = {tok: int(idx) for idx, tok in idx2event.items()}
        # load premade vocab
        else:
            idx2event = unique_vocabs
            event2idx = {tok: int(idx) for idx, tok in unique_vocabs.items()}
        return idx2event, event2idx

    # Augments the pitch vocabulary by expanding the range of pitch values.
    def _augment_pitch_vocab(self, unique_vocabs):
        pitch_vocab = [x for x in unique_vocabs if 'Note_Pitch_' in x]
        pitch_int = [int(x.replace('Note_Pitch_', '')) for x in pitch_vocab if x.replace('Note_Pitch_', '').isdigit()]
        min_pitch = min(pitch_int)
        max_pitch = max(pitch_int)
        min_pitch_margin = max(min_pitch - 6, 0)
        max_pitch_margin = min(max_pitch + 7, 127)
        new_pitch_vocab = sorted([f'Note_Pitch_{x}' for x in range(min_pitch_margin, max_pitch_margin + 1)], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
        new_unique_vocab = [x for x in unique_vocabs if x not in new_pitch_vocab] + new_pitch_vocab
        return new_unique_vocab
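    # For illustration (assumed pitch range, not taken from a real corpus): if the
    # event data only contains Note_Pitch_30 ... Note_Pitch_100, the margins above
    # extend the vocabulary to Note_Pitch_24 ... Note_Pitch_107, presumably so that
    # pitch-transposed data can still be encoded without out-of-vocabulary tokens.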
    # Orders and arranges the instrument vocabulary.
    def _arange_instrument_vocab(self, unique_vocabs):
        instrument_vocab = [x for x in unique_vocabs if 'Instrument_' in x]
        new_instrument_vocab = sorted(instrument_vocab, key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
        new_unique_vocab = [x for x in unique_vocabs if x not in new_instrument_vocab] + new_instrument_vocab
        return new_unique_vocab

    # Orders and arranges the chord vocabulary, ensuring 'Chord_N_N' is the last token.
    def _arange_chord_vocab(self, unique_vocabs):
        '''
        For chord augmentation:
        'Chord_N_N' should be the last token in the list so that chord augmentation is easy to implement.
        '''
        chord_vocab = [x for x in unique_vocabs if 'Chord_' in x]
        chord_vocab.remove('Chord_N_N')
        new_chord_vocab = sorted(chord_vocab, key=lambda x: (not isinstance(x, int), x.split('_')[-1] if isinstance(x, str) else x, x.split('_')[1] if isinstance(x, str) else x))
        new_chord_vocab.append('Chord_N_N')
        new_unique_vocab = [x for x in unique_vocabs if x not in new_chord_vocab] + new_chord_vocab
        return new_unique_vocab

    # Orders and arranges the beat vocabulary.
    def _arange_beat_vocab(self, unique_vocabs):
        beat_vocab = [x for x in unique_vocabs if 'Beat_' in x]
        new_beat_vocab = sorted(beat_vocab, key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
        count = 0
        for idx, token in enumerate(unique_vocabs):
            if 'Beat_' in token:
                unique_vocabs[idx] = new_beat_vocab[count]
                count += 1
        return unique_vocabs

    # Generates masks for the 'remi' encoding scheme.
    def _make_mask(self):
        '''
        This function is used to extract the target musical features for validation.
        '''
        idx2feature = {}
        for idx, feature in self.idx2event.items():
            if feature.startswith('SOS') or feature.startswith('EOS') or feature.startswith('Bar'):
                idx2feature[idx] = 'type'
            elif feature.startswith('Beat'):
                idx2feature[idx] = 'beat'
            elif feature.startswith('Chord'):
                idx2feature[idx] = 'chord'
            elif feature.startswith('Tempo'):
                idx2feature[idx] = 'tempo'
            elif feature.startswith('Note_Pitch'):
                idx2feature[idx] = 'pitch'
            elif feature.startswith('Note_Duration'):
                idx2feature[idx] = 'duration'
            elif feature.startswith('Note_Velocity'):
                idx2feature[idx] = 'velocity'
            elif feature.startswith('Instrument'):
                idx2feature[idx] = 'instrument'

        self.total_mask = {}
        self.remi_vocab_boundaries_by_key = {}
        for target in self.feature_list:
            mask = [0] * len(idx2feature)  # Initialize an all-zero list with the same length as the dictionary
            for key, value in idx2feature.items():
                if value == target:
                    mask[int(key)] = 1  # If the value equals the target, set the corresponding position in the mask to 1
            mask = torch.LongTensor(mask)
            self.total_mask[target] = mask
            start_idx, end_idx = torch.argwhere(mask == 1).flatten().tolist()[0], torch.argwhere(mask == 1).flatten().tolist()[-1]
            self.remi_vocab_boundaries_by_key[target] = (start_idx, end_idx + 1)

    def decode(self, events: torch.Tensor):
        '''
        Used for checking events during evaluation.
        events: 1-d tensor
        '''
        decoded_list = []
        for event in events:
            decoded_list.append(self.idx2event[event.item()])
        return decoded_list

    def __call__(self, word):
        '''
        For remi-style encoding.
        '''
        return self.event2idx[f"{word['name']}_{word['value']}"]
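    # A minimal usage sketch for the remi vocabulary above (event_paths is a
    # hypothetical list of event pickle files, as produced by step2):
    #   vocab = LangTokenVocab(in_vocab_file_path=None, event_data=event_paths,
    #                          encoding_scheme='remi', num_features=5)
    #   idx = vocab({'name': 'Note_Pitch', 'value': 60})  # index of 'Note_Pitch_60'
    #   vocab.decode(torch.tensor([idx]))                 # ['Note_Pitch_60']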
class MusicTokenVocabCP(LangTokenVocab):
    def __init__(
        self,
        in_vocab_file_path: Union[Path, None],
        event_data: list,
        encoding_scheme: str,
        num_features: int
    ):
        # Initialize the vocabulary class with vocab file path, event data, encoding scheme, and feature count
        super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)

    def _augment_pitch_vocab(self, unique_vocabs):
        # Extract pitch-related vocabularies and adjust pitch range
        pitch_total_vocab = unique_vocabs['pitch']
        pitch_vocab = [x for x in pitch_total_vocab if 'Note_Pitch_' in str(x)]
        pitch_int = [int(x.replace('Note_Pitch_', '')) for x in pitch_vocab if x.replace('Note_Pitch_', '').isdigit()]
        # Determine the min and max pitch values and extend the pitch range slightly
        min_pitch = min(pitch_int)
        max_pitch = max(pitch_int)
        min_pitch_margin = max(min_pitch - 6, 0)
        max_pitch_margin = min(max_pitch + 7, 127)
        # Create new pitch vocab and ensure new entries do not overlap with existing ones
        new_pitch_vocab = [f'Note_Pitch_{x}' for x in range(min_pitch_margin, max_pitch_margin + 1)]
        new_pitch_vocab = [x for x in pitch_total_vocab if str(x) not in new_pitch_vocab] + new_pitch_vocab
        unique_vocabs['pitch'] = new_pitch_vocab
        return unique_vocabs

    def _mp_get_unique_vocab(self, tune, features):
        # Read event data from a file and collect unique vocabularies for specified features
        with open(tune, 'rb') as f:
            events_list = pickle.load(f)
        unique_vocabs = defaultdict(set)
        for event in events_list:
            for key in features:
                unique_vocabs[key].add(event[key])
        return unique_vocabs

    def _get_chord_vocab(self):
        '''
        Manually define the chord vocabulary by combining roots and qualities
        from a predefined list. This is used for chord augmentation.
        '''
        root_list = ['A', 'A#', 'B', 'C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#']
        quality_list = ['+', '/o7', '7', 'M', 'M7', 'm', 'm7', 'o', 'o7', 'sus2', 'sus4']
        chord_vocab = [f'Chord_{root}_{quality}' for root in root_list for quality in quality_list]
        # Sort the chord vocabulary based on the root and quality
        chord_vocab = sorted(chord_vocab, key=lambda x: (not isinstance(x, int), x.split('_')[-1] if isinstance(x, str) else x, x.split('_')[0] if isinstance(x, str) else x))
        return chord_vocab
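    # With the 12 roots and 11 qualities above, this yields 12 * 11 = 132 chord
    # tokens; the 'Chord_N_N' (no chord) token is inserted separately in _get_vocab.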
    def _cp_sort_type(self, unique_vocabs):
        # Similar to _nb_sort_type but used for the 'cp' encoding scheme, sorting vocabularies in a different order
        unique_vocabs.remove('SOS')
        unique_vocabs.remove('EOS')
        unique_vocabs.remove('Metrical')
        unique_vocabs.remove('Note')
        vocab_list = list(unique_vocabs)
        unique_vocabs = sorted(vocab_list, key=sort_key)
        unique_vocabs.insert(0, 'SOS')
        unique_vocabs.insert(1, 'EOS')
        unique_vocabs.insert(2, 'Metrical')
        unique_vocabs.insert(3, 'Note')
        return unique_vocabs

    # Define custom sorting function
    def sort_type_cp(self, item):
        if item == 0:
            return (0, 0)  # Move 0 to the beginning
        elif isinstance(item, str):
            if item.startswith("Bar"):
                return (1, item)  # "Bar" items come next, sorted lexicographically
            elif item.startswith("Beat"):
                # Extract numeric part of "Beat_x" to sort numerically
                beat_number = int(item.split('_')[1])
                return (2, beat_number)  # "Beat" items come last, sorted by number
        return (3, item)  # Catch-all for anything unexpected (shouldn't be necessary here)

    def _get_vocab(self, event_data, unique_vocabs=None):
        if event_data is not None:
            # Create vocab mappings (event2idx, idx2event) from the provided event data
            print('start to get unique vocab')
            event2idx = {}
            idx2event = {}
            unique_vocabs = defaultdict(set)
            # Use multiprocessing to extract unique vocabularies for each event
            with Pool(16) as p:
                results = p.starmap(self._mp_get_unique_vocab, tqdm([(tune, self.feature_list) for tune in event_data]))
            # Combine results from different processes
            for result in results:
                for key in self.feature_list:
                    if key == 'chord':  # Chords are handled separately
                        continue
                    unique_vocabs[key].update(result[key])
            # Augment pitch vocab and add manually defined chord vocab
            unique_vocabs = self._augment_pitch_vocab(unique_vocabs)
            unique_vocabs['chord'] = self._get_chord_vocab()
            # Process each feature type, handling special cases like 'tempo' and 'chord'
            for key in self.feature_list:
                if key == 'tempo':
                    remove_nn_flag = False
                    if 'Tempo_N_N' in unique_vocabs[key]:
                        unique_vocabs[key].remove('Tempo_N_N')
                        remove_nn_flag = True
                    unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
                    if remove_nn_flag:
                        unique_vocabs[key].insert(1, 'Tempo_N_N')
                elif key == 'chord':
                    unique_vocabs[key].insert(0, 0)
                    unique_vocabs[key].insert(1, 'Chord_N_N')
                elif key == 'type':  # Sort 'type' vocab depending on the encoding scheme
                    if self.encoding_scheme == 'cp':
                        unique_vocabs[key] = self._cp_sort_type(unique_vocabs[key])
                    else:  # NB encoding scheme
                        unique_vocabs[key] = self._nb_sort_type(unique_vocabs[key])
                elif key == 'beat' and self.encoding_scheme == 'cp':  # Handle 'beat' vocab with 'cp' scheme
                    # unique_vocabs[key].remove('Bar')
                    # unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), Fraction(x.split('_')[-1] if isinstance(x, str) else x)))
                    # unique_vocabs[key].insert(1, 'Bar')
                    unique_vocabs[key] = sorted(unique_vocabs[key], key=self.sort_type_cp)
                elif key == 'beat' and self.encoding_scheme == 'nb':  # Handle 'beat' vocab with 'nb' scheme
                    unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
                elif key == 'instrument':  # Sort 'instrument' vocab by integer values
                    unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
                else:  # Default case: sort by integer values for other keys
                    unique_vocabs[key] = sorted(unique_vocabs[key], key=lambda x: (not isinstance(x, int), int(x.split('_')[-1] if isinstance(x, str) else x)))
                # Create event2idx and idx2event mappings for each feature
                event2idx[key] = {tok: int(idx) for idx, tok in enumerate(unique_vocabs[key])}
                idx2event[key] = {int(idx): tok for idx, tok in enumerate(unique_vocabs[key])}
            return idx2event, event2idx
        else:
            # If no event data, simply map unique vocab to indexes
            event2idx = {}
            for key in self.feature_list:
                event2idx[key] = {tok: int(idx) for idx, tok in unique_vocabs[key].items()}
            return unique_vocabs, event2idx

    def get_vocab_size(self):
        # Return the size of the vocabulary for each feature
        return {key: len(self.idx2event[key]) for key in self.feature_list}

    def __call__(self, event):
        # Convert an event to its corresponding indices
        return [self.event2idx[key][event[key]] for key in self.feature_list]

    def decode(self, events: torch.Tensor):
        decoded_list = []
        for event in events:
            decoded_list.append([self.idx2event[key][event[idx].item()] for idx, key in enumerate(self.feature_list)])
        return decoded_list
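    # Unlike the remi vocabulary, the cp/nb vocabularies are kept per feature:
    # __call__ maps one compound event dict to a list with one index per feature,
    # e.g. a 5-feature event yields 5 integers (type, beat, instrument, pitch, duration).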
class MusicTokenVocabNB(MusicTokenVocabCP):
    def __init__(
        self,
        in_vocab_file_path: Union[Path, None],
        event_data: list,
        encoding_scheme: str,
        num_features: int
    ):
        super().__init__(in_vocab_file_path, event_data, encoding_scheme, num_features)

    def _nb_sort_type(self, unique_vocabs):
        # Remove special tokens and sort the remaining vocab list, then re-insert the special tokens in order
        unique_vocabs.remove('SOS')
        unique_vocabs.remove('EOS')
        unique_vocabs.remove('Empty_Bar')
        unique_vocabs.remove('SSS')
        unique_vocabs.remove('SSN')
        unique_vocabs.remove('SNN')
        vocab_list = list(unique_vocabs)
        unique_vocabs = sorted(vocab_list, key=sort_key)
        unique_vocabs.insert(0, 'SOS')
        unique_vocabs.insert(1, 'EOS')
        unique_vocabs.insert(2, 'Empty_Bar')
        unique_vocabs.insert(3, 'SSS')
        unique_vocabs.insert(4, 'SSN')
        unique_vocabs.insert(5, 'SNN')
        return unique_vocabs